From 0dd5001dfce9f00259d7ceb74a5529c134f1f6fa Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Sat, 16 Nov 2024 19:01:05 +0530 Subject: [PATCH 1/7] feat(eval): notify subscribers when SADD is performed in evalSADD function --- internal/eval/store_eval.go | 97 ++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index 7d6f2c6db..bd41b77f1 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -4753,50 +4753,59 @@ func evalHDEL(args []string, store *dstore.Store) *EvalResponse { // Returns an integer which represents the number of members that were added to the set, not including // the members that were already present func evalSADD(args []string, store *dstore.Store) *EvalResponse { - if len(args) < 2 { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongArgumentCount("SADD"), - } - } - key := args[0] - - // Get the set object from the store. - obj := store.Get(key) - lengthOfItems := len(args[1:]) - - var count = 0 - if obj == nil { - var exDurationMs int64 = -1 - var keepttl = false - // If the object does not exist, create a new set object. - value := make(map[string]struct{}, lengthOfItems) - // Create a new object. - obj = store.NewObj(value, exDurationMs, object.ObjTypeSet) - store.Put(key, obj, dstore.WithKeepTTL(keepttl)) - } - - if err := object.AssertType(obj.Type, object.ObjTypeSet); err != nil { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongTypeOperation, - } - } - - // Get the set object. - set := obj.Value.(map[string]struct{}) - - for _, arg := range args[1:] { - if _, ok := set[arg]; !ok { - set[arg] = struct{}{} - count++ - } - } - - return &EvalResponse{ - Result: count, - Error: nil, - } + if len(args) < 2 { + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongArgumentCount("SADD"), + } + } + key := args[0] + + // Get the set object from the store. + obj := store.Get(key) + lengthOfItems := len(args[1:]) + + var count = 0 + var set map[string]struct{} + + if obj == nil { + // If the object does not exist, create a new set + set = make(map[string]struct{}, lengthOfItems) + } else { + // Type and encoding checks + if err := object.AssertType(obj.TypeEncoding, object.ObjTypeSet); err != nil { + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongTypeOperation, + } + } + + if err := object.AssertEncoding(obj.TypeEncoding, object.ObjEncodingSetStr); err != nil { + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongTypeOperation, + } + } + + set = obj.Value.(map[string]struct{}) + } + + // Add elements to the set + for _, arg := range args[1:] { + if _, ok := set[arg]; !ok { + set[arg] = struct{}{} + count++ + } + } + + // Single Put operation at the end + obj = store.NewObj(set, -1, object.ObjTypeSet, object.ObjEncodingSetStr) + store.Put(key, obj, dstore.WithKeepTTL(false), dstore.WithPutCmd(dstore.SADD)) + + return &EvalResponse{ + Result: count, + Error: nil, + } } // evalSREM removes one or more members from a set From 66a128bd640c902494f37eee3992ab4c4f4c1078 Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Sat, 16 Nov 2024 19:01:22 +0530 Subject: [PATCH 2/7] feat(store): add SADD and SMEMBERS constants --- internal/store/constants.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/store/constants.go b/internal/store/constants.go index 1c2e8d95a..26e1140ea 100644 --- a/internal/store/constants.go +++ b/internal/store/constants.go @@ -31,6 +31,8 @@ const ( PFMERGE string = "PFMERGE" KEYSPERSHARD string = "KEYSPERSHARD" Evict string = "EVICT" + SADD string = "SADD" + SMEMBERS string = "SMEMBERS" SingleShardSize string = "SINGLEDBSIZE" SingleShardTouch string = "SINGLETOUCH" SingleShardKeys string = "SINGLEKEYS" From 2a3d8b01d0947140be884429ef6c5d2ab30fe242 Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Sat, 16 Nov 2024 19:01:39 +0530 Subject: [PATCH 3/7] feat(watchmanager): include dstore.SADD in affectedCmdMap --- internal/watchmanager/watch_manager.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/watchmanager/watch_manager.go b/internal/watchmanager/watch_manager.go index 99eda75c9..37aee3efb 100644 --- a/internal/watchmanager/watch_manager.go +++ b/internal/watchmanager/watch_manager.go @@ -50,6 +50,7 @@ var ( dstore.ZAdd: {dstore.ZRange: struct{}{}}, dstore.PFADD: {dstore.PFCOUNT: struct{}{}}, dstore.PFMERGE: {dstore.PFCOUNT: struct{}{}}, + dstore.SADD: {dstore.SMEMBERS: struct{}{}}, } ) From cc20c66f33432c429c88fa97434b6340cbb041fc Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Sat, 16 Nov 2024 19:01:56 +0530 Subject: [PATCH 4/7] feat(worker): add CmdSMembersWatch to CommandsMeta and watch commands constants --- internal/commandhandler/cmd_meta.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/commandhandler/cmd_meta.go b/internal/commandhandler/cmd_meta.go index cae1bb6b7..3cb4d7753 100644 --- a/internal/commandhandler/cmd_meta.go +++ b/internal/commandhandler/cmd_meta.go @@ -125,6 +125,7 @@ const ( CmdSrem = "SREM" CmdScard = "SCARD" CmdSmembers = "SMEMBERS" + CmdSMembersWatch = "SMEMBERS.WATCH" CmdDump = "DUMP" CmdRestore = "RESTORE" CmdGeoAdd = "GEOADD" @@ -676,6 +677,9 @@ var CommandsMeta = map[string]CmdMeta{ CmdPFCountWatch: { CmdType: Watch, }, + CmdSMembersWatch: { + CmdType: Watch, + }, // Unwatch commands CmdGetUnWatch: { From e9b356a162cda79bc18686a8e76d3658a8223cf2 Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Sat, 16 Nov 2024 19:02:14 +0530 Subject: [PATCH 5/7] test(integration): add tests for SMEMBERSWATCH and SMEMBERSWATCHWithSDK --- .../commands/resp/smemberswatch_test.go | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 integration_tests/commands/resp/smemberswatch_test.go diff --git a/integration_tests/commands/resp/smemberswatch_test.go b/integration_tests/commands/resp/smemberswatch_test.go new file mode 100644 index 000000000..90b5ad603 --- /dev/null +++ b/integration_tests/commands/resp/smemberswatch_test.go @@ -0,0 +1,197 @@ +package resp + +import ( + "context" + "fmt" + "net" + "sort" + "testing" + "time" + + "github.com/dicedb/dice/internal/clientio" + dicedb "github.com/dicedb/dicedb-go" + "gotest.tools/v3/assert" +) + +type smembersWatchTestCase struct { + key string + val string + result any +} + +const ( + smembersCommand = "SMEMBERS" + smembersWatchKey = "smemberswatchkey" + smembersWatchQuery = "SMEMBERS.WATCH %s" + smembersWatchFingerPrint = "3660753939" +) + +var smembersWatchTestCases = []smembersWatchTestCase{ + {smembersWatchKey, "member1", []any{"member1"}}, + {smembersWatchKey, "member2", []any{"member1", "member2"}}, + {smembersWatchKey, "member3", []any{"member1", "member2", "member3"}}, +} + +func TestSMEMBERSWATCH(t *testing.T) { + publisher := getLocalConnection() + subscribers := setupSubscribers(3) + + FireCommand(publisher, fmt.Sprintf("DEL %s", smembersWatchKey)) + + defer func() { + if err := publisher.Close(); err != nil { + t.Errorf("Error closing publisher connection: %v", err) + } + for _, sub := range subscribers { + time.Sleep(100 * time.Millisecond) + if err := sub.Close(); err != nil { + t.Errorf("Error closing subscriber connection: %v", err) + } + } + }() + + respParsers := setUpSmembersRespParsers(t, subscribers) + + t.Run("Basic Set Operations", func(t *testing.T) { + testSetOperations(t, publisher, respParsers) + }) +} + +func sortSlice(v any) any { + switch v := v.(type) { + case []interface{}: + sorted := make([]interface{}, len(v)) + copy(sorted, v) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].(string) < sorted[j].(string) + }) + return sorted + case []string: + sorted := make([]string, len(v)) + copy(sorted, v) + sort.Strings(sorted) + return sorted + default: + return v + } +} + +func setUpSmembersRespParsers(t *testing.T, subscribers []net.Conn) []*clientio.RESPParser { + respParsers := make([]*clientio.RESPParser, len(subscribers)) + for i, subscriber := range subscribers { + rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf(smembersWatchQuery, smembersWatchKey)) + assert.Assert(t, rp != nil) + respParsers[i] = rp + + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + } + return respParsers +} + +func testSetOperations(t *testing.T, publisher net.Conn, respParsers []*clientio.RESPParser) { + for _, tc := range smembersWatchTestCases { + res := FireCommand(publisher, fmt.Sprintf("SADD %s %s", tc.key, tc.val)) + assert.Equal(t, int64(1), res) + verifySmembersWatchResults(t, respParsers, tc.result) + } +} + +func verifySmembersWatchResults(t *testing.T, respParsers []*clientio.RESPParser, expected any) { + for _, rp := range respParsers { + v, err := rp.DecodeOne() + assert.NilError(t, err) + castedValue, ok := v.([]interface{}) + if !ok { + t.Errorf("Type assertion to []interface{} failed for value: %v", v) + } + assert.Equal(t, 3, len(castedValue)) + assert.Equal(t, smembersCommand, castedValue[0]) + assert.Equal(t, smembersWatchFingerPrint, castedValue[1]) + assert.DeepEqual(t, sortSlice(expected), sortSlice(castedValue[2])) + } +} + +type smembersWatchSDKTestCase struct { + key string + val string + result []string +} + +var smembersWatchSDKTestCases = []smembersWatchSDKTestCase{ + {smembersWatchKey, "member1", []string{"member1"}}, + {smembersWatchKey, "member2", []string{"member1", "member2"}}, + {smembersWatchKey, "member3", []string{"member1", "member2", "member3"}}, +} + +func TestSMEMBERSWATCHWithSDK(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + + publisher := getLocalSdk() + subscribers := setupSubscribersSDK(3) + defer cleanupSubscribersSDK(subscribers) + + publisher.Del(ctx, smembersWatchKey) + + channels := setUpSmembersWatchChannelsSDK(t, ctx, subscribers) + + t.Run("Basic Set Operations", func(t *testing.T) { + testSetOperationsSDK(t, ctx, channels, publisher) + }) +} + +func setUpSmembersWatchChannelsSDK(t *testing.T, ctx context.Context, subscribers []WatchSubscriber) []<-chan *dicedb.WatchResult { + channels := make([]<-chan *dicedb.WatchResult, len(subscribers)) + for i, subscriber := range subscribers { + watch := subscriber.client.WatchConn(ctx) + subscribers[i].watch = watch + assert.Assert(t, watch != nil) + firstMsg, err := watch.Watch(ctx, smembersCommand, smembersWatchKey) + assert.NilError(t, err) + assert.Equal(t, firstMsg.Command, smembersCommand) + channels[i] = watch.Channel() + } + return channels +} + +func testSetOperationsSDK(t *testing.T, ctx context.Context, channels []<-chan *dicedb.WatchResult, publisher *dicedb.Client) { + for _, tc := range smembersWatchSDKTestCases { + err := publisher.SAdd(ctx, tc.key, tc.val).Err() + assert.NilError(t, err) + verifySmembersWatchResultsSDK(t, channels, tc.result) + } +} + +func verifySmembersWatchResultsSDK(t *testing.T, channels []<-chan *dicedb.WatchResult, expected []string) { + for _, channel := range channels { + select { + case v := <-channel: + assert.Equal(t, smembersCommand, v.Command) + assert.Equal(t, smembersWatchFingerPrint, v.Fingerprint) + + received, ok := v.Data.([]interface{}) + if !ok { + t.Fatalf("Expected []interface{}, got %T", v.Data) + } + + receivedStrings := make([]string, len(received)) + for i, item := range received { + str, ok := item.(string) + if !ok { + t.Fatalf("Expected string, got %T", item) + } + receivedStrings[i] = str + } + + assert.DeepEqual(t, sortSlice(expected), sortSlice(receivedStrings)) + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for watch result") + } + } +} \ No newline at end of file From 142308058b835b1f38b011c4b0fdef1147d5471f Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Mon, 18 Nov 2024 22:57:05 +0530 Subject: [PATCH 6/7] test(integration): remove sortSlice function, introduce testifyAssert.ElementMatch --- .../commands/resp/smemberswatch_test.go | 25 +++---------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/integration_tests/commands/resp/smemberswatch_test.go b/integration_tests/commands/resp/smemberswatch_test.go index 90b5ad603..5aa4dd7f6 100644 --- a/integration_tests/commands/resp/smemberswatch_test.go +++ b/integration_tests/commands/resp/smemberswatch_test.go @@ -4,13 +4,13 @@ import ( "context" "fmt" "net" - "sort" "testing" "time" "github.com/dicedb/dice/internal/clientio" dicedb "github.com/dicedb/dicedb-go" "gotest.tools/v3/assert" + testifyAssert "github.com/stretchr/testify/assert" ) type smembersWatchTestCase struct { @@ -57,25 +57,6 @@ func TestSMEMBERSWATCH(t *testing.T) { }) } -func sortSlice(v any) any { - switch v := v.(type) { - case []interface{}: - sorted := make([]interface{}, len(v)) - copy(sorted, v) - sort.Slice(sorted, func(i, j int) bool { - return sorted[i].(string) < sorted[j].(string) - }) - return sorted - case []string: - sorted := make([]string, len(v)) - copy(sorted, v) - sort.Strings(sorted) - return sorted - default: - return v - } -} - func setUpSmembersRespParsers(t *testing.T, subscribers []net.Conn) []*clientio.RESPParser { respParsers := make([]*clientio.RESPParser, len(subscribers)) for i, subscriber := range subscribers { @@ -113,7 +94,7 @@ func verifySmembersWatchResults(t *testing.T, respParsers []*clientio.RESPParser assert.Equal(t, 3, len(castedValue)) assert.Equal(t, smembersCommand, castedValue[0]) assert.Equal(t, smembersWatchFingerPrint, castedValue[1]) - assert.DeepEqual(t, sortSlice(expected), sortSlice(castedValue[2])) + testifyAssert.ElementsMatch(t, expected, castedValue[2]) } } @@ -189,7 +170,7 @@ func verifySmembersWatchResultsSDK(t *testing.T, channels []<-chan *dicedb.Watch receivedStrings[i] = str } - assert.DeepEqual(t, sortSlice(expected), sortSlice(receivedStrings)) + testifyAssert.ElementsMatch(t, expected, receivedStrings) case <-time.After(defaultTimeout): t.Fatal("timeout waiting for watch result") } From 9edd2a5b3d0887cb385d52d12cca0bc5435a6e14 Mon Sep 17 00:00:00 2001 From: superiorsd10 Date: Tue, 24 Dec 2024 22:20:01 +0530 Subject: [PATCH 7/7] feat(eval): made changes according to the latest rebase --- internal/eval/store_eval.go | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index bd41b77f1..8c1566dd4 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -4772,15 +4772,8 @@ func evalSADD(args []string, store *dstore.Store) *EvalResponse { // If the object does not exist, create a new set set = make(map[string]struct{}, lengthOfItems) } else { - // Type and encoding checks - if err := object.AssertType(obj.TypeEncoding, object.ObjTypeSet); err != nil { - return &EvalResponse{ - Result: nil, - Error: diceerrors.ErrWrongTypeOperation, - } - } - - if err := object.AssertEncoding(obj.TypeEncoding, object.ObjEncodingSetStr); err != nil { + // Type checks + if err := object.AssertType(obj.Type, object.ObjTypeSet); err != nil { return &EvalResponse{ Result: nil, Error: diceerrors.ErrWrongTypeOperation, @@ -4799,7 +4792,7 @@ func evalSADD(args []string, store *dstore.Store) *EvalResponse { } // Single Put operation at the end - obj = store.NewObj(set, -1, object.ObjTypeSet, object.ObjEncodingSetStr) + obj = store.NewObj(set, -1, object.ObjTypeSet) store.Put(key, obj, dstore.WithKeepTTL(false), dstore.WithPutCmd(dstore.SADD)) return &EvalResponse{