package agent

import (
	"context"
	"fmt"
	"sync"
	"testing"

	"github.com/cloudwego/eino/schema"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNewSimpleMemory verifies that the constructor succeeds and that a
// freshly created memory reports no interactions.
func TestNewSimpleMemory(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)
	assert.NotNil(t, memory)

	// A new memory must start out empty.
	ctx := context.Background()
	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	assert.Empty(t, interactions)
}

// TestSimpleMemory_AddInteraction verifies that a single added message is
// returned unchanged by GetInteractions.
func TestSimpleMemory_AddInteraction(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()
	message := &schema.Message{
		Role:    schema.User,
		Content: "test message",
	}

	err = memory.AddInteraction(ctx, message)
	assert.NoError(t, err)

	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	// require (not assert) so a wrong length stops the test instead of
	// panicking on the index expression below.
	require.Len(t, interactions, 1)
	assert.Equal(t, message, interactions[0])
}

// TestSimpleMemory_AddMultipleInteractions verifies that several messages are
// all stored and returned in insertion order.
func TestSimpleMemory_AddMultipleInteractions(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()
	messages := []*schema.Message{
		{Role: schema.User, Content: "first message"},
		{Role: schema.Assistant, Content: "second message"},
		{Role: schema.User, Content: "third message"},
	}

	for _, msg := range messages {
		err = memory.AddInteraction(ctx, msg)
		assert.NoError(t, err)
	}

	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	// Guard the indexed comparisons below against an out-of-range panic.
	require.Len(t, interactions, len(messages))

	for i, msg := range messages {
		assert.Equal(t, msg, interactions[i])
	}
}

// TestSimpleMemory_GetInteractions verifies GetInteractions both on an empty
// memory (non-nil, zero-length result) and after two messages were added.
func TestSimpleMemory_GetInteractions(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()

	// Empty memory: the result must be non-nil and have no elements.
	interactions, err := memory.GetInteractions(ctx)
	assert.NoError(t, err)
	assert.NotNil(t, interactions)
	assert.Empty(t, interactions)

	// Add some interactions.
	message1 := &schema.Message{Role: schema.User, Content: "message 1"}
	message2 := &schema.Message{Role: schema.Assistant, Content: "message 2"}

	err = memory.AddInteraction(ctx, message1)
	require.NoError(t, err)
	err = memory.AddInteraction(ctx, message2)
	require.NoError(t, err)

	interactions, err = memory.GetInteractions(ctx)
	assert.NoError(t, err)
	// Guard the indexed comparisons below against an out-of-range panic.
	require.Len(t, interactions, 2)
	assert.Equal(t, message1, interactions[0])
	assert.Equal(t, message2, interactions[1])
}

// TestSimpleMemory_Clear verifies that Clear removes every stored interaction.
func TestSimpleMemory_Clear(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()

	// Add some interactions.
	message1 := &schema.Message{Role: schema.User, Content: "message 1"}
	message2 := &schema.Message{Role: schema.Assistant, Content: "message 2"}

	err = memory.AddInteraction(ctx, message1)
	require.NoError(t, err)
	err = memory.AddInteraction(ctx, message2)
	require.NoError(t, err)

	// Verify interactions exist.
	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	assert.Len(t, interactions, 2)

	// Clear interactions.
	err = memory.Clear(ctx)
	assert.NoError(t, err)

	// Verify interactions are cleared.
	interactions, err = memory.GetInteractions(ctx)
	assert.NoError(t, err)
	assert.Empty(t, interactions)
}

// TestSimpleMemory_NilInteraction pins the current behavior for a nil
// message: AddInteraction accepts it and it is stored as a nil entry.
func TestSimpleMemory_NilInteraction(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()

	// Adding nil is expected to succeed.
	err = memory.AddInteraction(ctx, nil)
	assert.NoError(t, err)

	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	// Guard the index expression below against an out-of-range panic.
	require.Len(t, interactions, 1)
	assert.Nil(t, interactions[0])
}

// TestSimpleMemory_ConcurrentAccess verifies that no writes are lost when many
// goroutines add interactions at the same time. Run with -race to also detect
// unsynchronized access.
func TestSimpleMemory_ConcurrentAccess(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()
	numGoroutines := 100
	messagesPerGoroutine := 10

	var wg sync.WaitGroup
	wg.Add(numGoroutines)

	// Concurrent writes. Only assert (never require) is used inside the
	// goroutines: FailNow must be called from the test goroutine.
	for i := 0; i < numGoroutines; i++ {
		go func(goroutineID int) {
			defer wg.Done()
			for j := 0; j < messagesPerGoroutine; j++ {
				message := &schema.Message{
					Role:    schema.User,
					Content: fmt.Sprintf("goroutine-%d-message-%d", goroutineID, j),
				}
				err := memory.AddInteraction(ctx, message)
				assert.NoError(t, err)
			}
		}(i)
	}

	wg.Wait()

	// Verify all interactions were added.
	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	assert.Len(t, interactions, numGoroutines*messagesPerGoroutine)
}

// TestSimpleMemory_ConcurrentReadWrite runs readers and writers in parallel
// and verifies that readers never observe more messages than were written and
// that every write is present once all goroutines have finished.
func TestSimpleMemory_ConcurrentReadWrite(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()
	numReaders := 10
	numWriters := 10
	messagesPerWriter := 5
	totalMessages := numWriters * messagesPerWriter

	var wg sync.WaitGroup
	wg.Add(numReaders + numWriters)

	// Concurrent writers.
	for i := 0; i < numWriters; i++ {
		go func(writerID int) {
			defer wg.Done()
			for j := 0; j < messagesPerWriter; j++ {
				message := &schema.Message{
					Role:    schema.User,
					Content: fmt.Sprintf("writer-%d-message-%d", writerID, j),
				}
				err := memory.AddInteraction(ctx, message)
				assert.NoError(t, err)
			}
		}(i)
	}

	// Concurrent readers.
	for i := 0; i < numReaders; i++ {
		go func() {
			defer wg.Done()
			for j := 0; j < messagesPerWriter; j++ {
				interactions, err := memory.GetInteractions(ctx)
				assert.NoError(t, err)
				assert.NotNil(t, interactions)
				// The observed length varies with scheduling, but it can
				// never exceed the total number of messages being written.
				assert.LessOrEqual(t, len(interactions), totalMessages)
			}
		}()
	}

	wg.Wait()

	// Final verification: every write must be visible.
	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	assert.Len(t, interactions, totalMessages)
}

// TestSimpleMemory_ClearDuringConcurrentAccess issues a Clear while writers
// are active and verifies the memory ends up in a consistent state.
func TestSimpleMemory_ClearDuringConcurrentAccess(t *testing.T) {
	memory, err := NewSimpleMemory()
	require.NoError(t, err)

	ctx := context.Background()
	numWriters := 5
	messagesPerWriter := 10

	// Seed some initial interactions before any goroutine is started.
	for i := 0; i < 5; i++ {
		message := &schema.Message{
			Role:    schema.User,
			Content: fmt.Sprintf("initial-message-%d", i),
		}
		err := memory.AddInteraction(ctx, message)
		require.NoError(t, err)
	}

	var wg sync.WaitGroup
	wg.Add(numWriters + 1) // +1 for the clearer

	// Concurrent writers.
	for i := 0; i < numWriters; i++ {
		go func(writerID int) {
			defer wg.Done()
			for j := 0; j < messagesPerWriter; j++ {
				message := &schema.Message{
					Role:    schema.User,
					Content: fmt.Sprintf("writer-%d-message-%d", writerID, j),
				}
				err := memory.AddInteraction(ctx, message)
				assert.NoError(t, err)
			}
		}(i)
	}

	// Clear operation. It races with the writers on purpose; nothing
	// guarantees how many writes have happened by the time it runs.
	go func() {
		defer wg.Done()
		err := memory.Clear(ctx)
		assert.NoError(t, err)
	}()

	wg.Wait()

	// Final state check - should be consistent.
	interactions, err := memory.GetInteractions(ctx)
	require.NoError(t, err)
	assert.NotNil(t, interactions)
	// The exact count depends on when Clear ran. The seed messages were all
	// stored before Clear started, so (assuming Clear removes everything
	// stored before it runs) at most the writers' messages can remain.
	assert.LessOrEqual(t, len(interactions), numWriters*messagesPerWriter)
}