diff --git a/integration_tests/commands/resp/smemberswatch_test.go b/integration_tests/commands/resp/smemberswatch_test.go new file mode 100644 index 000000000..5aa4dd7f6 --- /dev/null +++ b/integration_tests/commands/resp/smemberswatch_test.go @@ -0,0 +1,178 @@ +package resp + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/dicedb/dice/internal/clientio" + dicedb "github.com/dicedb/dicedb-go" + "gotest.tools/v3/assert" + testifyAssert "github.com/stretchr/testify/assert" +) + +type smembersWatchTestCase struct { + key string + val string + result any +} + +const ( + smembersCommand = "SMEMBERS" + smembersWatchKey = "smemberswatchkey" + smembersWatchQuery = "SMEMBERS.WATCH %s" + smembersWatchFingerPrint = "3660753939" +) + +var smembersWatchTestCases = []smembersWatchTestCase{ + {smembersWatchKey, "member1", []any{"member1"}}, + {smembersWatchKey, "member2", []any{"member1", "member2"}}, + {smembersWatchKey, "member3", []any{"member1", "member2", "member3"}}, +} + +func TestSMEMBERSWATCH(t *testing.T) { + publisher := getLocalConnection() + subscribers := setupSubscribers(3) + + FireCommand(publisher, fmt.Sprintf("DEL %s", smembersWatchKey)) + + defer func() { + if err := publisher.Close(); err != nil { + t.Errorf("Error closing publisher connection: %v", err) + } + for _, sub := range subscribers { + time.Sleep(100 * time.Millisecond) + if err := sub.Close(); err != nil { + t.Errorf("Error closing subscriber connection: %v", err) + } + } + }() + + respParsers := setUpSmembersRespParsers(t, subscribers) + + t.Run("Basic Set Operations", func(t *testing.T) { + testSetOperations(t, publisher, respParsers) + }) +} + +func setUpSmembersRespParsers(t *testing.T, subscribers []net.Conn) []*clientio.RESPParser { + respParsers := make([]*clientio.RESPParser, len(subscribers)) + for i, subscriber := range subscribers { + rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf(smembersWatchQuery, smembersWatchKey)) + assert.Assert(t, rp != nil) + respParsers[i] = rp + + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + } + return respParsers +} + +func testSetOperations(t *testing.T, publisher net.Conn, respParsers []*clientio.RESPParser) { + for _, tc := range smembersWatchTestCases { + res := FireCommand(publisher, fmt.Sprintf("SADD %s %s", tc.key, tc.val)) + assert.Equal(t, int64(1), res) + verifySmembersWatchResults(t, respParsers, tc.result) + } +} + +func verifySmembersWatchResults(t *testing.T, respParsers []*clientio.RESPParser, expected any) { + for _, rp := range respParsers { + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + assert.Equal(t, smembersCommand, castedValue[0]) + assert.Equal(t, smembersWatchFingerPrint, castedValue[1]) + testifyAssert.ElementsMatch(t, expected, castedValue[2]) + } +} + +type smembersWatchSDKTestCase struct { + key string + val string + result []string +} + +var smembersWatchSDKTestCases = []smembersWatchSDKTestCase{ + {smembersWatchKey, "member1", []string{"member1"}}, + {smembersWatchKey, "member2", []string{"member1", "member2"}}, + {smembersWatchKey, "member3", []string{"member1", "member2", "member3"}}, +} + +func TestSMEMBERSWATCHWithSDK(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + + publisher := getLocalSdk() + subscribers := setupSubscribersSDK(3) + defer cleanupSubscribersSDK(subscribers) + + publisher.Del(ctx, smembersWatchKey) + + channels := setUpSmembersWatchChannelsSDK(t, ctx, subscribers) + + t.Run("Basic Set Operations", func(t *testing.T) { + testSetOperationsSDK(t, ctx, channels, publisher) + }) +} + +func setUpSmembersWatchChannelsSDK(t *testing.T, ctx context.Context, subscribers []WatchSubscriber) []<-chan *dicedb.WatchResult { + channels := make([]<-chan *dicedb.WatchResult, len(subscribers)) + for i, subscriber := range subscribers { + watch := subscriber.client.WatchConn(ctx) + subscribers[i].watch = watch + assert.Assert(t, watch != nil) + firstMsg, err := watch.Watch(ctx, smembersCommand, smembersWatchKey) + assert.NilError(t, err) + assert.Equal(t, firstMsg.Command, smembersCommand) + channels[i] = watch.Channel() + } + return channels +} + +func testSetOperationsSDK(t *testing.T, ctx context.Context, channels []<-chan *dicedb.WatchResult, publisher *dicedb.Client) { + for _, tc := range smembersWatchSDKTestCases { + err := publisher.SAdd(ctx, tc.key, tc.val).Err() + assert.NilError(t, err) + verifySmembersWatchResultsSDK(t, channels, tc.result) + } +} + +func verifySmembersWatchResultsSDK(t *testing.T, channels []<-chan *dicedb.WatchResult, expected []string) { + for _, channel := range channels { + select { + case v := <-channel: + assert.Equal(t, smembersCommand, v.Command) + assert.Equal(t, smembersWatchFingerPrint, v.Fingerprint) + + received, ok := v.Data.([]interface{}) + if !ok { + t.Fatalf("Expected []interface{}, got %T", v.Data) + } + + receivedStrings := make([]string, len(received)) + for i, item := range received { + str, ok := item.(string) + if !ok { + t.Fatalf("Expected string, got %T", item) + } + receivedStrings[i] = str + } + + testifyAssert.ElementsMatch(t, expected, receivedStrings) + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for watch result") + } + } +} \ No newline at end of file diff --git a/internal/commandhandler/cmd_meta.go b/internal/commandhandler/cmd_meta.go index cae1bb6b7..3cb4d7753 100644 --- a/internal/commandhandler/cmd_meta.go +++ b/internal/commandhandler/cmd_meta.go @@ -125,6 +125,7 @@ const ( CmdSrem = "SREM" CmdScard = "SCARD" CmdSmembers = "SMEMBERS" + CmdSMembersWatch = "SMEMBERS.WATCH" CmdDump = "DUMP" CmdRestore = "RESTORE" CmdGeoAdd = "GEOADD" @@ -676,6 +677,9 @@ var CommandsMeta = map[string]CmdMeta{ CmdPFCountWatch: { CmdType: Watch, }, + CmdSMembersWatch: { + CmdType: Watch, + }, // Unwatch commands CmdGetUnWatch: { diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index 7d6f2c6db..8c1566dd4 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -4753,50 +4753,52 @@ func evalHDEL(args []string, store *dstore.Store) *EvalResponse { // Returns an integer which represents the number of members that were added to the set, not including // the members that were already present func evalSADD(args []string, store *dstore.Store) *EvalResponse { - if len(args) < 2 { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongArgumentCount("SADD"), - } - } - key := args[0] - - // Get the set object from the store. - obj := store.Get(key) - lengthOfItems := len(args[1:]) - - var count = 0 - if obj == nil { - var exDurationMs int64 = -1 - var keepttl = false - // If the object does not exist, create a new set object. - value := make(map[string]struct{}, lengthOfItems) - // Create a new object. - obj = store.NewObj(value, exDurationMs, object.ObjTypeSet) - store.Put(key, obj, dstore.WithKeepTTL(keepttl)) - } - - if err := object.AssertType(obj.Type, object.ObjTypeSet); err != nil { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongTypeOperation, - } - } - - // Get the set object. - set := obj.Value.(map[string]struct{}) - - for _, arg := range args[1:] { - if _, ok := set[arg]; !ok { - set[arg] = struct{}{} - count++ - } - } - - return &EvalResponse{ - Result: count, - Error: nil, - } + if len(args) < 2 { + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongArgumentCount("SADD"), + } + } + key := args[0] + + // Get the set object from the store. + obj := store.Get(key) + lengthOfItems := len(args[1:]) + + var count = 0 + var set map[string]struct{} + + if obj == nil { + // If the object does not exist, create a new set + set = make(map[string]struct{}, lengthOfItems) + } else { + // Type checks + if err := object.AssertType(obj.Type, object.ObjTypeSet); err != nil { + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongTypeOperation, + } + } + + set = obj.Value.(map[string]struct{}) + } + + // Add elements to the set + for _, arg := range args[1:] { + if _, ok := set[arg]; !ok { + set[arg] = struct{}{} + count++ + } + } + + // Single Put operation at the end + obj = store.NewObj(set, -1, object.ObjTypeSet) + store.Put(key, obj, dstore.WithKeepTTL(false), dstore.WithPutCmd(dstore.SADD)) + + return &EvalResponse{ + Result: count, + Error: nil, + } } // evalSREM removes one or more members from a set diff --git a/internal/store/constants.go b/internal/store/constants.go index 1c2e8d95a..26e1140ea 100644 --- a/internal/store/constants.go +++ b/internal/store/constants.go @@ -31,6 +31,8 @@ const ( PFMERGE string = "PFMERGE" KEYSPERSHARD string = "KEYSPERSHARD" Evict string = "EVICT" + SADD string = "SADD" + SMEMBERS string = "SMEMBERS" SingleShardSize string = "SINGLEDBSIZE" SingleShardTouch string = "SINGLETOUCH" SingleShardKeys string = "SINGLEKEYS" diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go index 99eda75c9..37aee3efb 100644 --- a/internal/watchmanager/watch_manager.go +++ b/internal/watchmanager/watch_manager.go @@ -50,6 +50,7 @@ var ( dstore.ZAdd: {dstore.ZRange: struct{}{}}, dstore.PFADD: {dstore.PFCOUNT: struct{}{}}, dstore.PFMERGE: {dstore.PFCOUNT: struct{}{}}, + dstore.SADD: {dstore.SMEMBERS: struct{}{}}, } )