From 1ad2cb55827afdb324c34ce78370f5c6a1046dfa Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 20 Jan 2025 20:41:46 +0300 Subject: [PATCH] [management] Refactor peers to use store methods (#2893) --- .github/workflows/golang-test-linux.yml | 6 +- go.mod | 2 +- go.sum | 4 +- management/server/account.go | 108 +- management/server/account_test.go | 43 +- management/server/ephemeral.go | 47 +- management/server/ephemeral_test.go | 18 +- management/server/groups/manager.go | 59 +- .../server/http/handlers/networks/handler.go | 4 +- .../handlers/networks/resources_handler.go | 25 +- .../http/handlers/peers/peers_handler.go | 72 +- .../http/handlers/peers/peers_handler_test.go | 141 +-- .../peers_handler_benchmark_test.go | 62 +- .../setupkeys_handler_benchmark_test.go | 68 +- .../users_handler_benchmark_test.go | 28 +- management/server/integrated_validator.go | 30 +- .../server/integrated_validator/interface.go | 2 +- management/server/management_proto_test.go | 2 +- management/server/mock_server/account_mock.go | 24 +- management/server/peer.go | 968 ++++++++++++------ management/server/peer/peer.go | 2 +- management/server/peer_test.go | 2 +- management/server/status/error.go | 5 + management/server/store/sql_store.go | 172 +++- management/server/store/sql_store_test.go | 444 ++++++-- management/server/store/store.go | 21 +- .../server/testdata/store_policy_migrate.sql | 1 + .../testdata/store_with_expired_peers.sql | 9 +- management/server/testdata/storev1.sql | 10 +- management/server/user.go | 38 +- 30 files changed, 1587 insertions(+), 830 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ba5f66746e7..a4a3da66c8f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -262,7 +262,7 @@ jobs: fail-fast: false matrix: arch: [ '386','amd64' ] - store: [ 'sqlite', 'postgres', 'mysql' ] + store: [ 'sqlite', 'postgres' ] runs-on: ubuntu-22.04 steps: - name: Install Go @@ -314,7 +314,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management) api_integration_test: needs: [ build-cache ] @@ -363,7 +363,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -tags=integration $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management) test_client_on_docker: needs: [ build-cache ] diff --git a/go.mod b/go.mod index 88bcada0745..fa573bb9cab 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 8ba94dd6af5..a099498fb11 100644 --- a/go.sum +++ b/go.sum @@ -527,8 +527,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= diff --git a/management/server/account.go b/management/server/account.go index eeb8b2fb81d..2c62a245360 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -45,6 +45,7 @@ import ( const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) @@ -85,7 +86,7 @@ type AccountManager interface { GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) @@ -105,6 +106,7 @@ type AccountManager interface { DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error @@ -126,8 +128,8 @@ type AccountManager interface { SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) - LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -138,7 +140,7 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *types.Account) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error @@ -379,14 +381,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } updateAccountPeers := false @@ -400,7 +402,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco account.Network.Serial++ } - err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err } @@ -437,13 +439,13 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con return nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } else { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { @@ -452,7 +454,7 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. event = activity.AccountPeerInactivityExpirationDisabled am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } @@ -466,33 +468,31 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextPeerExpiration() + return peerSchedulerRetryInterval, true } - expiredPeers := account.GetExpiredPeers() var peerIDs []string for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) } - log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextPeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) + return peerSchedulerRetryInterval, true } - return account.GetNextPeerExpiration() + return am.getNextPeerExpiration(ctx, accountID) } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *types.Account) { - am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { + go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } } @@ -502,34 +502,33 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", accountID) - return account.GetNextInactivePeerExpiration() + log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) + return peerSchedulerRetryInterval, true } - expiredPeers := account.GetInactivePeers() var peerIDs []string - for _, peer := range expiredPeers { + for _, peer := range inactivePeers { peerIDs = append(peerIDs, peer.ID) } - log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) - return account.GetNextInactivePeerExpiration() + if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", accountID) + return peerSchedulerRetryInterval, true } - return account.GetNextInactivePeerExpiration() + return am.getNextInactivePeerExpiration(ctx, accountID) } } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *types.Account) { - am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { - go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) { + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID)) } } @@ -665,7 +664,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -1450,7 +1449,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -1497,7 +1496,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -1559,17 +1558,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, status.NewGetAccountError(err) - } - - peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) + peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -1583,12 +1577,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return status.NewGetAccountError(err) - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) + err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } @@ -1609,12 +1598,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer unlockPeer() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) } @@ -1683,8 +1667,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } -func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -1695,7 +1679,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee } if peerLoginExpired(ctx, peer, settings) { - err = am.handleExpiredPeer(ctx, user, peer) + err = am.handleExpiredPeer(ctx, transaction, user, peer) if err != nil { return false, err } diff --git a/management/server/account_test.go b/management/server/account_test.go index e4f079507d2..57bc0c7571c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1450,7 +1450,6 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - userID := "account_creator" account, err := createAccount(manager, "test_account", userID, "netbird.cloud") if err != nil { t.Fatal(err) @@ -1479,7 +1478,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID) + err = manager.DeletePeer(context.Background(), account.Id, peer.ID, userID) if err != nil { return } @@ -1501,7 +1500,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { assert.Equal(t, peer.Name, ev.Meta["name"]) assert.Equal(t, peer.FQDN(account.Domain), ev.Meta["fqdn"]) assert.Equal(t, userID, ev.InitiatorID) - assert.Equal(t, peer.IP.String(), ev.TargetID) + assert.Equal(t, peer.ID, ev.TargetID) assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } @@ -1855,13 +1854,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ + account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1929,11 +1925,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "unable to get the account") - // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1964,7 +1957,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} @@ -3089,12 +3082,12 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 102, 110, 102, 130}, - {"Medium", 500, 100, 105, 140, 105, 190}, - {"Large", 5000, 200, 160, 200, 160, 320}, - {"Small single", 50, 10, 102, 110, 102, 130}, - {"Medium single", 500, 10, 105, 140, 105, 190}, - {"Large 5", 5000, 15, 160, 200, 160, 290}, + {"Small", 50, 5, 102, 110, 3, 20}, + {"Medium", 500, 100, 105, 140, 20, 110}, + {"Large", 5000, 200, 160, 200, 120, 260}, + {"Small single", 50, 10, 102, 110, 5, 40}, + {"Medium single", 500, 10, 105, 140, 10, 60}, + {"Large 5", 5000, 15, 160, 200, 60, 180}, } log.SetOutput(io.Discard) @@ -3163,12 +3156,12 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 107, 120, 107, 160}, - {"Medium", 500, 100, 105, 140, 105, 220}, - {"Large", 5000, 200, 180, 220, 180, 395}, - {"Small single", 50, 10, 107, 120, 105, 160}, - {"Medium single", 500, 10, 105, 140, 105, 170}, - {"Large 5", 5000, 15, 180, 220, 180, 340}, + {"Small", 50, 5, 107, 120, 10, 80}, + {"Medium", 500, 100, 105, 140, 30, 140}, + {"Large", 5000, 200, 180, 220, 140, 300}, + {"Small single", 50, 10, 107, 120, 10, 80}, + {"Medium single", 500, 10, 105, 140, 20, 60}, + {"Large 5", 5000, 15, 180, 220, 80, 200}, } log.SetOutput(io.Discard) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 3c629a0dbda..3d6d0143469 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -22,10 +21,10 @@ var ( ) type ephemeralPeer struct { - id string - account *types.Account - deadline time.Time - next *ephemeralPeer + id string + accountID string + deadline time.Time + next *ephemeralPeer } // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it @@ -106,12 +105,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) - a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err) - return - } - e.peersLock.Lock() defer e.peersLock.Unlock() @@ -119,7 +112,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. return } - e.addPeer(peer.ID, a, newDeadLine()) + e.addPeer(peer.AccountID, peer.ID, newDeadLine()) if e.timer == nil { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.cleanup(ctx) @@ -128,18 +121,18 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - accounts := e.store.GetAllAccounts(context.Background()) + peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) + if err != nil { + log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) + return + } + t := newDeadLine() - count := 0 - for _, a := range accounts { - for id, p := range a.Peers { - if p.Ephemeral { - count++ - e.addPeer(id, a, t) - } - } + for _, p := range peers { + e.addPeer(p.AccountID, p.ID, t) } - log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) + + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers)) } func (e *EphemeralManager) cleanup(ctx context.Context) { @@ -172,18 +165,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) + err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) } } } -func (e *EphemeralManager) addPeer(id string, account *types.Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { ep := &ephemeralPeer{ - id: id, - account: account, - deadline: deadline, + id: peerID, + accountID: accountID, + deadline: deadline, } if e.headPeer == nil { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index ac83724409d..df8fe98c372 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,7 +7,6 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" ) @@ -17,17 +16,14 @@ type MockStore struct { account *types.Account } -func (s *MockStore) GetAllAccounts(_ context.Context) []*types.Account { - return []*types.Account{s.account} -} - -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*types.Account, error) { - _, ok := s.account.Peers[peerId] - if ok { - return s.account, nil +func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + for _, v := range s.account.Peers { + if v.Ephemeral { + peers = append(peers, v) + } } - - return nil, status.NewPeerNotFoundError(peerId) + return peers, nil } type MocAccountManager struct { diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index f5abb212e13..cfc7ee57b72 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -13,7 +13,8 @@ import ( ) type Manager interface { - GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) @@ -37,7 +38,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou } } -func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) if err != nil { return nil, err @@ -51,6 +52,15 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string return nil, fmt.Errorf("error getting account groups: %w", err) } + return groups, nil +} + +func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + groups, err := m.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, err + } + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group @@ -130,44 +140,43 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) } -func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { - groupsInfo := []api.GroupMinimum{} - groupsChecked := make(map[string]struct{}) +func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum { + groupsInfoMap := make(map[string][]api.GroupMinimum, idCount) + groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation) for _, group := range groups { _, ok := groupsChecked[group.ID] if ok { continue } + groupsChecked[group.ID] = struct{}{} for _, pk := range group.Peers { - if pk == id { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - ResourcesCount: len(group.Resources), - } - groupsInfo = append(groupsInfo, info) - break + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), } + groupsInfoMap[pk] = append(groupsInfoMap[pk], info) } for _, rk := range group.Resources { - if rk.ID == id { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - ResourcesCount: len(group.Resources), - } - groupsInfo = append(groupsInfo, info) - break + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), } + groupsInfoMap[rk.ID] = append(groupsInfoMap[rk.ID], info) } } - return groupsInfo + return groupsInfoMap +} + +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{}, nil } -func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { return map[string]*types.Group{}, nil } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index 6b36a8fcecf..316b936115b 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -82,7 +82,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -267,7 +267,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) } - groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) } diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index 6499bd6521d..f2dc8e3b86d 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -66,10 +66,11 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt return } + grpsInfoMap := groups.ToGroupsInfoMap(grps, len(resources)) + var resourcesResponse []*api.NetworkResource for _, resource := range resources { - groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) - resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } util.WriteJSONObject(r.Context(), w, resourcesResponse) @@ -94,10 +95,11 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt return } + grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) + var resourcesResponse []*api.NetworkResource for _, resource := range resources { - groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) - resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } util.WriteJSONObject(r.Context(), w, resourcesResponse) @@ -136,8 +138,9 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) return } - groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) - util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) + grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) + + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { @@ -162,8 +165,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { return } - groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) - util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) + grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) + + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { @@ -199,8 +203,9 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) return } - groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) - util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) + grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) + + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 7eb8e215340..cdd8026f257 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -58,8 +58,8 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { return peerToReturn, nil } -func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) +func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { util.WriteError(ctx, err, w) return @@ -72,20 +72,21 @@ func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, u } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := groups.ToGroupsInfo(account.Groups, peer.ID) + grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) } -func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -109,16 +110,22 @@ func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := groups.ToGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + + grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -127,7 +134,7 @@ func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID _, valid := validPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -159,18 +166,11 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { case http.MethodDelete: h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodGet, http.MethodPut: - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if r.Method == http.MethodGet { - h.getPeer(r.Context(), account, peerID, userID, w) - } else { - h.updatePeer(r.Context(), account, userID, peerID, w, r) - } + case http.MethodGet: + h.getPeer(r.Context(), accountID, peerID, userID, w) + return + case http.MethodPut: + h.updatePeer(r.Context(), accountID, userID, peerID, w, r) return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -186,7 +186,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -194,18 +194,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - groupsMap := map[string]*types.Group{} grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } + grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) @@ -213,12 +204,11 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) - respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) + respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(account) + validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -281,16 +271,16 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - dnsDomain := h.accountManager.GetDNSDomain() - - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) + dnsDomain := h.accountManager.GetDNSDomain() + + customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 83abc1c400e..16065a677a7 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -38,6 +38,68 @@ const ( ) func initTestMetaData(peers ...*nbpeer.Peer) *Handler { + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &types.Policy{ + ID: "policy", + AccountID: "test_id", + Name: "policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, + }, + } + + srvUser := types.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &types.Account{ + Id: "test_id", + Domain: "hotmail.com", + Peers: peersMap, + Users: map[string]*types.User{ + adminUser: types.NewAdminUser(adminUser), + regularUser: types.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*types.Group{ + "group1": { + ID: "group1", + AccountID: "test_id", + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour, + }, + Policies: []*types.Policy{policy}, + Network: &types.Network{ + Identifier: "ciclqisab2ss43jdn8q0", + Net: net.IPNet{ + IP: net.ParseIP("100.67.0.0"), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + Serial: 51, + }, + } + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -66,74 +128,31 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { + peersID := make([]string, len(peers)) + for _, peer := range peers { + peersID = append(peersID, peer.ID) + } + return []*types.Group{ + { + ID: "group1", + AccountID: accountID, + Name: "group1", + Issued: "api", + Peers: peersID, + }, + }, nil + }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, + GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { + return account, nil + }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { - peersMap := make(map[string]*nbpeer.Peer) - for _, peer := range peers { - peersMap[peer.ID] = peer.Copy() - } - - policy := &types.Policy{ - ID: "policy", - AccountID: accountID, - Name: "policy", - Enabled: true, - Rules: []*types.PolicyRule{ - { - ID: "rule", - Name: "rule", - Enabled: true, - Action: "accept", - Destinations: []string{"group1"}, - Sources: []string{"group1"}, - Bidirectional: true, - Protocol: "all", - Ports: []string{"80"}, - }, - }, - } - - srvUser := types.NewRegularUser(serviceUser) - srvUser.IsServiceUser = true - - account := &types.Account{ - Id: accountID, - Domain: "hotmail.com", - Peers: peersMap, - Users: map[string]*types.User{ - adminUser: types.NewAdminUser(adminUser), - regularUser: types.NewRegularUser(regularUser), - serviceUser: srvUser, - }, - Groups: map[string]*types.Group{ - "group1": { - ID: "group1", - AccountID: accountID, - Name: "group1", - Issued: "api", - Peers: maps.Keys(peersMap), - }, - }, - Settings: &types.Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: time.Hour, - }, - Policies: []*types.Policy{policy}, - Network: &types.Network{ - Identifier: "ciclqisab2ss43jdn8q0", - Net: net.IPNet{ - IP: net.ParseIP("100.67.0.0"), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - Serial: 51, - }, - } - return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index e7637042657..2eb50e4b4e6 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdatePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 1300, MaxMsPerOpLocal: 1700, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 13000}, + "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 3500}, "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, - "Peers - M": {MinMsPerOpLocal: 160, MaxMsPerOpLocal: 190, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 500}, - "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 430, MinMsPerOpCICD: 450, MaxMsPerOpCICD: 1400}, - "Groups - L": {MinMsPerOpLocal: 1200, MaxMsPerOpLocal: 1500, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 13000}, - "Users - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2800}, - "Setup Keys - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 700, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1300}, - "Peers - XL": {MinMsPerOpLocal: 1400, MaxMsPerOpLocal: 1900, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 5000}, + "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 500}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 600}, + "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 2000}, } log.SetOutput(io.Discard) @@ -77,14 +77,14 @@ func BenchmarkUpdatePeer(b *testing.B) { func BenchmarkGetOnePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 900, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, - "Peers - S": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 7, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 35, MaxMsPerOpCICD: 80}, - "Peers - L": {MinMsPerOpLocal: 120, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Groups - L": {MinMsPerOpLocal: 500, MaxMsPerOpLocal: 750, MinMsPerOpCICD: 900, MaxMsPerOpCICD: 6500}, - "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, - "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, - "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1500}, + "Peers - XS": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70}, + "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, + "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, + "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, } log.SetOutput(io.Discard) @@ -111,14 +111,14 @@ func BenchmarkGetOnePeer(b *testing.B) { func BenchmarkGetAllPeers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 900, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 6000}, - "Peers - S": {MinMsPerOpLocal: 4, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 7, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 90}, - "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 350}, - "Groups - L": {MinMsPerOpLocal: 5000, MaxMsPerOpLocal: 5500, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, - "Users - L": {MinMsPerOpLocal: 250, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, - "Setup Keys - L": {MinMsPerOpLocal: 250, MaxMsPerOpLocal: 350, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, - "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 2200}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150}, + "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70}, + "Peers - L": {MinMsPerOpLocal: 110, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Groups - L": {MinMsPerOpLocal: 150, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 500}, + "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, + "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, + "Peers - XL": {MinMsPerOpLocal: 450, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 500, MaxMsPerOpCICD: 1500}, } log.SetOutput(io.Discard) @@ -145,14 +145,14 @@ func BenchmarkGetAllPeers(b *testing.B) { func BenchmarkDeletePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, - "Peers - S": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 230}, - "Peers - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, - "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 5500}, - "Users - L": {MinMsPerOpLocal: 170, MaxMsPerOpLocal: 210, MinMsPerOpCICD: 290, MaxMsPerOpCICD: 1700}, - "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 125, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 280}, - "Peers - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 250}, + "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index bbdb4250b1a..ed643f75e7b 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ func BenchmarkCreateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, } log.SetOutput(io.Discard) @@ -81,14 +81,14 @@ func BenchmarkCreateSetupKey(b *testing.B) { func BenchmarkUpdateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, } log.SetOutput(io.Discard) @@ -128,14 +128,14 @@ func BenchmarkUpdateSetupKey(b *testing.B) { func BenchmarkGetOneSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, } log.SetOutput(io.Discard) @@ -162,8 +162,8 @@ func BenchmarkGetOneSetupKey(b *testing.B) { func BenchmarkGetAllSetupKeys(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15}, + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 12}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 15}, "Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40}, "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, "Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, @@ -196,14 +196,14 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { func BenchmarkDeleteSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index b623419952a..549a51c0e5d 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdateUser(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 7000}, - "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 6, MaxMsPerOpCICD: 40}, - "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, - "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 700}, - "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2000}, + "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000}, + "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50}, + "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250}, + "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700}, + "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400}, "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000}, - "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 1000}, - "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 3500}, + "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000}, + "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, } log.SetOutput(io.Discard) @@ -119,11 +119,11 @@ func BenchmarkGetOneUser(b *testing.B) { func BenchmarkGetAllUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ "Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180}, - "Users - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200}, "Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90}, } @@ -152,13 +152,13 @@ func BenchmarkGetAllUsers(b *testing.B) { func BenchmarkDeleteUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 10000}, + "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000}, "Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, "Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230}, "Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190}, "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800}, "Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500}, - "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600}, "Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400}, } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 62e9213f700..b9827f457cf 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -76,8 +76,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { - return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + var err error + var groups []*types.Group + var peers []*nbpeer.Peer + var settings *types.Settings + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return err + } + + peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID) + return err + }) + if err != nil { + return nil, err + } + + settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) } type MocIntegratedValidator struct { @@ -94,7 +117,8 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 22b8026aa24..ff179e3c0dc 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 8147afa44b9..0df2462f43f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -249,7 +249,7 @@ func Test_SyncProtocol(t *testing.T) { t.Fatal("expecting SyncResponse to have non-nil NetworkMap") } - if len(networkMap.GetRemotePeers()) != 3 { + if len(networkMap.GetRemotePeers()) != 4 { t.Fatalf("expecting SyncResponse to have NetworkMap with 3 remote peers, got %d", len(networkMap.GetRemotePeers())) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 042137b1b02..c8e42d20a66 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,6 +47,7 @@ type MockAccountManager struct { DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) @@ -90,7 +91,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -134,7 +135,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + account, err := am.GetAccountFunc(ctx, accountID) + if err != nil { + return nil, err + } + approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -225,7 +231,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } @@ -686,9 +692,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(ctx, sync, account) + return am.SyncPeerFunc(ctx, sync, accountID) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } @@ -835,3 +841,11 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") } + +// GetPeerGroups mocks GetPeerGroups of the AccountManager interface +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { + if am.GetPeerGroupsFunc != nil { + return am.GetPeerGroupsFunc(ctx, accountID, peerID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index 57b38ce8130..5b0f1289958 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -13,8 +13,9 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" @@ -57,43 +58,55 @@ type PeerLogin struct { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + if user.IsRegularUser() && settings.RegularUsersViewBlocked { + return []*nbpeer.Peer{}, nil + } + + accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } + peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) - regularUser := !user.HasAdminPower() && !user.IsServiceUser - - if regularUser && account.Settings.RegularUsersViewBlocked { - return peers, nil - } - - for _, peer := range account.Peers { - if regularUser && user.Id != peer.UserID { + for _, peer := range accountPeers { + if user.IsRegularUser() && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin continue } - p := peer.Copy() - peers = append(peers, p) - peersMap[peer.ID] = p + peers = append(peers, peer) + peersMap[peer.ID] = peer } - if !regularUser { + if user.IsAdminOrServiceUser() { return peers, nil } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, err + } + + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, err + } + // fetch all the peers that have access to the user's peers for _, peer := range peers { aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) @@ -102,53 +115,59 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } } - peers = make([]*nbpeer.Peer, 0, len(peersMap)) - for _, peer := range peersMap { - peers = append(peers, peer) - } - - return peers, nil + return maps.Values(peersMap), nil } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *types.Account) error { +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { start := time.Now() defer func() { log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start)) }() - peer, err := account.FindPeerByPubKey(peerPubKey) - if err != nil { - return fmt.Errorf("failed to find peer by pub key: %w", err) - } + var peer *nbpeer.Peer + var settings *types.Settings + var expired bool + var err error - expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) + if err != nil { + return err + } + + expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID) + return err + }) if err != nil { - return fmt.Errorf("failed to update peer status and location: %w", err) + return err } - log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected) - if peer.AddedWithSSOLogin() { - if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return err + } + + if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.UpdateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, accountID) } return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *types.Account) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -159,8 +178,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context } peer.Status = newStatus - if am.geo != nil && realIP != nil { - location, err := am.geo.Lookup(realIP) + if geo != nil && realIP != nil { + location, err := geo.Lookup(realIP) if err != nil { log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) } else { @@ -168,20 +187,18 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = am.Store.SavePeerLocation(account.Id, peer) + err = transaction.SavePeerLocation(ctx, store.LockingStrengthUpdate, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } - account.UpdatePeer(peer) - log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) - err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := transaction.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *newStatus) if err != nil { - return false, fmt.Errorf("failed to save peer status: %w", err) + return false, err } return oldStatus.LoginExpired, nil @@ -192,174 +209,183 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } - peer := account.GetPeer(update.ID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } + var peer *nbpeer.Peer + var settings *types.Settings + var peerGroupList []string var requiresPeerUpdates bool - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) - if err != nil { - return nil, err - } + var peerLabelChanged bool + var sshChanged bool + var loginExpirationChanged bool + var inactivityExpirationChanged bool - sshEnabledUpdated := peer.SSHEnabled != update.SSHEnabled - if sshEnabledUpdated { - peer.SSHEnabled = update.SSHEnabled - event := activity.PeerSSHEnabled - if !update.SSHEnabled { - event = activity.PeerSSHDisabled + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) + if err != nil { + return err } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - } - - peerLabelUpdated := peer.Name != update.Name - if peerLabelUpdated { - peer.Name = update.Name + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return err + } - existingLabels := account.GetPeerDNSLabels() + peerGroupList, err = getPeerGroupIDs(ctx, transaction, accountID, update.ID) + if err != nil { + return err + } - newLabel, err := types.GetPeerHostLabel(peer.Name, existingLabels) + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) if err != nil { - return nil, err + return err } - peer.DNSLabel = newLabel + if peer.Name != update.Name { + existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID) + if err != nil { + return err + } - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) - } + newLabel, err := types.GetPeerHostLabel(update.Name, existingLabels) + if err != nil { + return err + } - if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { + peer.Name = update.Name + peer.DNSLabel = newLabel + peerLabelChanged = true + } - if !peer.AddedWithSSOLogin() { - return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + if peer.SSHEnabled != update.SSHEnabled { + peer.SSHEnabled = update.SSHEnabled + sshChanged = true } - peer.LoginExpirationEnabled = update.LoginExpirationEnabled + if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { + if !peer.AddedWithSSOLogin() { + return status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + peer.LoginExpirationEnabled = update.LoginExpirationEnabled + loginExpirationChanged = true + } - event := activity.PeerLoginExpirationEnabled - if !update.LoginExpirationEnabled { - event = activity.PeerLoginExpirationDisabled + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + if !peer.AddedWithSSOLogin() { + return status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated") + } + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + inactivityExpirationChanged = true } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + return transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) + }) + if err != nil { + return nil, err + } + + if sshChanged { + event := activity.PeerSSHEnabled + if !peer.SSHEnabled { + event = activity.PeerSSHDisabled } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + if peerLabelChanged { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + } - if !peer.AddedWithSSOLogin() { - return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + if loginExpirationChanged { + event := activity.PeerLoginExpirationEnabled + if !peer.LoginExpirationEnabled { + event = activity.PeerLoginExpirationDisabled } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) + } + } + if inactivityExpirationChanged { event := activity.PeerInactivityExpirationEnabled - if !update.InactivityExpirationEnabled { + if !peer.InactivityExpirationEnabled { event = activity.PeerInactivityExpirationDisabled } am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - - if peerLabelUpdated || requiresPeerUpdates { + if peerLabelChanged || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) - } else if sshEnabledUpdated { - am.UpdateAccountPeer(ctx, account, peer) + } else if sshChanged { + am.UpdateAccountPeer(ctx, accountID, peer.ID) } return peer, nil } -// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *types.Account, peerIDs []string, userID string) error { - - // the first loop is needed to ensure all peers present under the account before modifying, otherwise - // we might have some inconsistencies - peers := make([]*nbpeer.Peer, 0, len(peerIDs)) - for _, peerID := range peerIDs { - - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - peers = append(peers, peer) - } - - // the 2nd loop performs the actual modification - for _, peer := range peers { +// DeletePeer removes peer from the account by its IP +func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() - err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID) + if userID != activity.SystemInitiator { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } - account.DeletePeer(peer.ID) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, - &UpdateMessage{ - Update: &proto.SyncResponse{ - // fill those field for backward compatibility - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - // new field - NetworkMap: &proto.NetworkMap{ - Serial: account.Network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - }, - }, - NetworkMap: &types.NetworkMap{}, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) - am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } } - return nil -} - -// DeletePeer removes peer from the account by its IP -func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) if err != nil { return err } - updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) - if err != nil { - return err + if peerAccountID != accountID { + return status.NewPeerNotPartOfAccountError() } - err = am.deletePeers(ctx, account, []string{peerID}, userID) - if err != nil { - return err - } + var peer *nbpeer.Peer + var updateAccountPeers bool + var eventsToStore []func() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return err + } + + updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) return err + }) + + for _, storeEvent := range eventsToStore { + storeEvent() } if updateAccountPeers { @@ -386,7 +412,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin groups[groupID] = group.Peers } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -425,7 +451,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser := false if len(userID) > 0 { addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(userID) + accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) } else { accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } @@ -456,12 +482,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer - var groupsToAdd []string + var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var setupKeyID string var setupKeyName string var ephemeral bool + var groupsToAdd []string if addedByUser { user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) if err != nil { @@ -503,7 +530,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to get free DNS label: %w", err) } - freeIP, err := am.getFreeIP(ctx, transaction, accountID) + freeIP, err := getFreeIP(ctx, transaction, accountID) if err != nil { return fmt.Errorf("failed to get free IP: %w", err) } @@ -521,7 +548,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, SSHEnabled: false, SSHKey: peer.SSHKey, - LastLogin: util.ToPtr(registrationTime), + LastLogin: ®istrationTime, CreatedAt: registrationTime, LoginExpirationEnabled: addedByUser, Ephemeral: ephemeral, @@ -551,21 +578,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) - err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) if err != nil { return fmt.Errorf("failed adding peer to All group: %w", err) } if len(groupsToAdd) > 0 { for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) + err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) if err != nil { return err } } } - err = transaction.AddPeerToAccount(ctx, newPeer) + err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) if err != nil { return fmt.Errorf("failed to add peer to account: %w", err) } @@ -587,6 +614,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } + updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID) + if err != nil { + return err + } + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) return nil }) @@ -604,48 +636,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s unlock() unlock = nil - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, status.NewGetAccountError(err) - } - - allGroup, err := account.GetGroupAll() - if err != nil { - return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) - } - groupsToAdd = append(groupsToAdd, allGroup.ID) - - newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) - if err != nil { - return nil, nil, nil, err - } - - if newGroupsAffectsPeers { + if updateAccountPeers { am.UpdateAccountPeers(ctx, accountID) } - approvedPeersMap, err := am.GetValidatedPeers(account) - if err != nil { - return nil, nil, nil, err - } - - postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID) - if err != nil { - return nil, nil, nil, err - } - - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) - return newPeer, networkMap, postureChecks, nil + return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } -func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, accountID string) (net.IP, error) { - takenIps, err := s.GetTakenIPs(ctx, store.LockingStrengthUpdate, accountID) +func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) { + takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, fmt.Errorf("failed to get taken IPs: %w", err) } - network, err := s.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed getting network: %w", err) } @@ -659,72 +663,79 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, a } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { start := time.Now() defer func() { log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start)) }() - peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) + var peer *nbpeer.Peer + var peerNotValid bool + var isStatusChanged bool + var updated bool + var err error + var postureChecks []*posture.Checks + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { - return nil, nil, nil, status.NewPeerNotRegisteredError() + return nil, nil, nil, err } - if peer.UserID != "" { - user, err := account.FindUser(peer.UserID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, sync.WireGuardPubKey) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) + return status.NewPeerNotRegisteredError() } - err = checkIfPeerOwnerIsBlocked(peer, user) - if err != nil { - return nil, nil, nil, err + if peer.UserID != "" { + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) + if err != nil { + return err + } + + if err = checkIfPeerOwnerIsBlocked(peer, user); err != nil { + return err + } } - } - if peerLoginExpired(ctx, peer, account.Settings) { - return nil, nil, nil, status.NewPeerLoginExpiredError() - } + if peerLoginExpired(ctx, peer, settings) { + return status.NewPeerLoginExpiredError() + } - updated := peer.UpdateMetaIfNew(sync.Meta) - if updated { - am.metrics.AccountManagerMetrics().CountPeerMetUpdate() - account.Peers[peer.ID] = peer - log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) - err = am.Store.SavePeer(ctx, account.Id, peer) + peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID) if err != nil { - return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) + return err } - } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err) - } + peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + if err != nil { + return err + } - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + updated = peer.UpdateMetaIfNew(sync.Meta) + if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() + log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) + if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { + return err + } + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) + if err != nil { + return err + } + } + return nil + }) if err != nil { return nil, nil, nil, err } if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { - am.UpdateAccountPeers(ctx, account.Id) - } - - if peerNotValid { - emptyMap := &types.NetworkMap{ - Network: account.Network.Copy(), - } - return peer, emptyMap, []*posture.Checks{}, nil - } - - validPeersMap, err := am.GetValidatedPeers(account) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) + am.UpdateAccountPeers(ctx, accountID) } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil + return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -772,92 +783,150 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) - if err != nil { - return nil, nil, nil, err - } + var peer *nbpeer.Peer + var updateRemotePeers bool + var isRequiresApproval bool + var isStatusChanged bool + var isPeerUpdated bool + var postureChecks []*posture.Checks settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } - // this flag prevents unnecessary calls to the persistent store. - shouldStorePeer := false - updateRemotePeers := false + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) + if err != nil { + return err + } + + // this flag prevents unnecessary calls to the persistent store. + shouldStorePeer := false + + if login.UserID != "" { + if peer.UserID != login.UserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + return status.Errorf(status.Unauthenticated, "invalid user") + } + + changed, err := am.handleUserPeer(ctx, transaction, peer, settings) + if err != nil { + return err + } - if login.UserID != "" { - if peer.UserID != login.UserID { - log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) - return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") + if changed { + shouldStorePeer = true + updateRemotePeers = true + } } - changed, err := am.handleUserPeer(ctx, peer, settings) + peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID) if err != nil { - return nil, nil, nil, err + return err + } + + isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) + if err != nil { + return err } - if changed { + + isPeerUpdated = peer.UpdateMetaIfNew(login.Meta) + if isPeerUpdated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true - updateRemotePeers = true + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) + if err != nil { + return err + } } - } - groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, nil, nil, err - } + if peer.SSHKey != login.SSHKey { + peer.SSHKey = login.SSHKey + shouldStorePeer = true + } - var grps []string - for _, group := range groups { - for _, id := range group.Peers { - if id == peer.ID { - grps = append(grps, group.ID) - break + if shouldStorePeer { + if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { + return err } } - } - isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra) + return nil + }) if err != nil { return nil, nil, nil, err } - updated := peer.UpdateMetaIfNew(login.Meta) - if updated { - am.metrics.AccountManagerMetrics().CountPeerMetUpdate() - shouldStorePeer = true + unlockPeer() + unlockPeer = nil + + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + am.UpdateAccountPeers(ctx, accountID) } - if peer.SSHKey != login.SSHKey { - peer.SSHKey = login.SSHKey - shouldStorePeer = true + return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) +} + +// getPeerPostureChecks returns the posture checks for the peer. +func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err } - if shouldStorePeer { - err = am.Store.SavePeer(ctx, accountID, peer) - if err != nil { - return nil, nil, nil, err - } + if len(policies) == 0 { + return nil, nil } - unlockPeer() - unlockPeer = nil + var peerPostureChecksIDs []string - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err + for _, policy := range policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID) + if err != nil { + return nil, err + } + + peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) } - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) if err != nil { - return nil, nil, nil, err + return nil, err } - if updateRemotePeers || isStatusChanged || (updated && len(postureChecks) > 0) { - am.UpdateAccountPeers(ctx, accountID) - } + return maps.Values(peerPostureChecks), nil +} + +// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks. +func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) + sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) + if err != nil { + return nil, err + } + + for _, sourceGroup := range rule.Sources { + group, ok := sourceGroups[sourceGroup] + if !ok { + return nil, fmt.Errorf("failed to check peer in policy source group") + } + + if slices.Contains(group.Peers, peerID) { + return policy.SourcePostureChecks, nil + } + } + } + return nil, nil } // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO @@ -889,22 +958,35 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *types.Account, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - var postureChecks []*posture.Checks +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start)) + }() if isRequiresApproval { + network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err + } + emptyMap := &types.NetworkMap{ - Network: account.Network.Copy(), + Network: network.Copy(), } return peer, emptyMap, nil, nil } - approvedPeersMap, err := am.GetValidatedPeers(account) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(account, peer.ID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, nil, nil, err + } + + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } @@ -913,7 +995,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *types.User, peer *nbpeer.Peer) error { +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { return err @@ -921,12 +1003,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *ty // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. peer = peer.UpdateLastLogin() - err = am.Store.SavePeer(ctx, peer.AccountID, peer) + err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, peer.AccountID, peer) if err != nil { return err } - err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) + err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) if err != nil { return err } @@ -968,41 +1050,47 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + if user.IsRegularUser() && settings.RegularUsersViewBlocked { return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) } - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) + if err != nil { + return nil, err } // if admin or user owns this peer, return peer - if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { + if user.IsAdminOrServiceUser() || peer.UserID == userID { return peer, nil } // it is also possible that user doesn't own the peer but some of his peers have access to it, // this is a valid case, show the peer as well. - userPeers, err := account.FindUserPeers(userID) + userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) + if err != nil { + return nil, err + } + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -1024,7 +1112,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) return } @@ -1035,11 +1123,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } }() - peers := account.GetPeers() - - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) return } @@ -1051,7 +1137,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - for _, peer := range peers { + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue @@ -1065,7 +1151,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account postureChecks, err := am.getPeerPostureChecks(account, p.ID) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) return } @@ -1080,15 +1166,27 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. -func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, account *types.Account, peer *nbpeer.Peer) { - if !am.peersUpdateManager.HasChannel(peer.ID) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) +func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { + if !am.peersUpdateManager.HasChannel(peerId) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) return } - approvedPeersMap, err := am.GetValidatedPeers(account) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peer.ID, err) + log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err) + return + } + + peer := account.GetPeer(peerId) + if peer == nil { + log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) + return + } + + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) return } @@ -1097,33 +1195,239 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + postureChecks, err := am.getPeerPostureChecks(account, peerId) if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peer.ID, err) + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} +// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) + return peerSchedulerRetryInterval, true } - return labelMap + + if len(peersWithExpiry) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return peerSchedulerRetryInterval, true + } + + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) + return peerSchedulerRetryInterval, true + } + + if len(peersWithInactivity) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return peerSchedulerRetryInterval, true + } + + var nextExpiry *time.Duration + for _, peer := range peersWithInactivity { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// getExpiredPeers returns peers that have been expired. +func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, peer := range peersWithExpiry { + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// getInactivePeers returns peers that have been expired by inactivity +func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, inactivePeer := range peersWithInactivity { + inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + + return peers, nil +} + +// GetPeerGroups returns groups that the peer is part of. +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { + return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) +} + +// getPeerGroupIDs returns the IDs of the groups that the peer is part of. +func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { + groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) + if err != nil { + return nil, err + } + + groupIDs := make([]string, 0, len(groups)) + for _, group := range groups { + groupIDs = append(groupIDs, group.ID) + } + + return groupIDs, err +} + +func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { + dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + existingLabels := make(types.LookupMap) + for _, label := range dnsLabels { + existingLabels[label] = struct{}{} + } + return existingLabels, nil } // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *types.Account, peerID string) (bool, error) { - peerGroupIDs := make([]string, 0) - for _, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - peerGroupIDs = append(peerGroupIDs, group.ID) +func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { + peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) + if err != nil { + return false, err + } + return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction +} + +// deletePeers deletes all specified peers and sends updates to the remote peers. +// Returns a slice of functions to save events after successful peer deletion. +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { + var peerDeletedEvents []func() + + for _, peer := range peers { + if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { + return nil, err } + + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + if err = transaction.DeletePeer(ctx, store.LockingStrengthUpdate, accountID, peer.ID); err != nil { + return nil, err + } + + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + }, + }, + NetworkMap: &types.NetworkMap{}, + }) + am.peersUpdateManager.CloseChannel(ctx, peer.ID) + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + }) } - return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) + + return peerDeletedEvents, nil +} + +func ConvertSliceToMap(existingLabels []string) map[string]struct{} { + labelMap := make(map[string]struct{}, len(existingLabels)) + for _, label := range existingLabels { + labelMap[label] = struct{}{} + } + return labelMap } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 355d78ce027..199c7c89ddb 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -46,7 +46,7 @@ type Peer struct { // CreatedAt records the time the peer was created CreatedAt time.Time // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `gorm:"index"` // Geo location based on connection IP Location Location `gorm:"embedded;embeddedPrefix:location_"` } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2f5d0e04701..bf712f38a70 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -938,7 +938,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, - {"Extra Large", 2000, 2000, 1300, 2400, 3800, 6400}, + {"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400}, } log.SetOutput(io.Discard) diff --git a/management/server/status/error.go b/management/server/status/error.go index d9cab02315c..7e384922dc9 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -86,6 +86,11 @@ func NewAccountNotFoundError(accountKey string) error { return Errorf(NotFound, "account not found: %s", accountKey) } +// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account +func NewPeerNotPartOfAccountError() error { + return Errorf(PermissionDenied, "peer is not part of this account") +} + // NewUserNotFoundError creates a new Error with NotFound type for a missing user func NewUserNotFoundError(userKey string) error { return Errorf(NotFound, "user not found: %s", userKey) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 7b1a634113d..900d813221e 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -313,12 +313,12 @@ func (s *SqlStore) GetInstallationID() string { return installation.InstallationIDValue } -func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { +func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -332,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error) } return nil @@ -358,7 +358,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -368,7 +368,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID return nil } -func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { +func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -376,12 +376,12 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe "peer_status_last_seen", "peer_status_connected", "peer_status_login_expired", "peer_status_required_approval", } - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Select(fieldsToUpdate). Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -391,22 +391,22 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe return nil } -func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { +func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, // updating the struct ensures the correct data format is inserted into the database. peerCopy.Location = peerWithLocation.Location - result := s.db.Model(&nbpeer.Peer{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Updates(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error) } - if result.RowsAffected == 0 && s.storeEngine != MysqlStoreEngine { + if result.RowsAffected == 0 { return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID) } @@ -773,9 +773,10 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) return accountID, nil } -func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { +func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { var accountID string - result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.User{}). + Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -786,6 +787,20 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { return accountID, nil } +func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) { + var accountID string + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Select("account_id").Where(idQueryCondition, peerID).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "peer %s account not found", peerID) + } + return "", status.NewGetAccountFromStoreError(result.Error) + } + + return accountID, nil +} + func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID) @@ -865,7 +880,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "peer not found") + return nil, status.NewPeerNotFoundError(peerKey) } return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } @@ -1096,9 +1111,10 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction -func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { var group types.Group - result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&group, "account_id = ? AND name = ?", accountID, "All") if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") @@ -1114,7 +1130,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } @@ -1122,9 +1138,10 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer } // AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction -func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { +func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error { var group types.Group - result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID). + First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewGroupNotFoundError(groupID) @@ -1141,7 +1158,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1201,13 +1218,52 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string return nil } +// GetPeerGroups retrieves all groups assigned to a specific peer in a given account. +func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { + var groups []*types.Group + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) + + if query.Error != nil { + return nil, query.Error + } + + return groups, nil +} + +// GetAccountPeers retrieves peers for an account. +func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil +} + // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { - return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID) + var peers []*nbpeer.Peer + + // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. + if userID == "" { + return peers, nil + } + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil } -func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.Create(peer).Error; err != nil { +func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1221,7 +1277,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength First(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "peer not found") + return nil, status.NewPeerNotFoundError(peerID) } log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get peer from store") @@ -1247,6 +1303,68 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng return peersMap, nil } +// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store") + } + + return peers, nil +} + +// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store") + } + + return peers, nil +} + +// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. +func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { + var allEphemeralPeers, batchPeers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("ephemeral = ?", true). + FindInBatches(&batchPeers, 1000, func(tx *gorm.DB, batch int) error { + allEphemeralPeers = append(allEphemeralPeers, batchPeers...) + return nil + }) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error) + return nil, fmt.Errorf("failed to retrieve ephemeral peers") + } + + return allEphemeralPeers, nil +} + +// DeletePeer removes a peer from the store. +func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete peer from store") + } + + if result.RowsAffected == 0 { + return status.NewPeerNotFoundError(peerID) + } + + return nil +} + func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) @@ -1638,7 +1756,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { var nsGroups []*nbdns.NameServerGroup - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) return nil, status.Errorf(status.Internal, "failed to get name server groups from store") @@ -1650,7 +1768,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength // GetNameServerGroupByID retrieves a name server group by its ID and account ID. func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { var nsGroup *nbdns.NameServerGroup - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1665,7 +1783,7 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock // SaveNameServerGroup saves a name server group to the database. func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) return status.Errorf(status.Internal, "failed to save name server group to store") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 5928b45baa7..cb51dab5179 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -422,12 +423,7 @@ func TestSqlite_GetAccount(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_SavePeer(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) +func TestSqlStore_SavePeer(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -437,15 +433,16 @@ func TestSqlite_SavePeer(t *testing.T) { // save status of non-existing peer peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + CreatedAt: time.Now().UTC(), } ctx := context.Background() - err = store.SavePeer(ctx, account.Id, peer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -461,23 +458,21 @@ func TestSqlite_SavePeer(t *testing.T) { updatedPeer.Status.Connected = false updatedPeer.Meta.Hostname = "updatedpeer" - err = store.SavePeer(ctx, account.Id, updatedPeer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers[peer.ID] - assert.Equal(t, updatedPeer.Status, actual.Status) assert.Equal(t, updatedPeer.Meta, actual.Meta) + assert.Equal(t, updatedPeer.Status.Connected, actual.Status.Connected) + assert.Equal(t, updatedPeer.Status.LoginExpired, actual.Status.LoginExpired) + assert.Equal(t, updatedPeer.Status.RequiresApproval, actual.Status.RequiresApproval) + assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerStatus(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) +func TestSqlStore_SavePeerStatus(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -487,7 +482,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -506,33 +501,34 @@ func TestSqlite_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) + assert.Equal(t, newStatus.Connected, actual.Connected) + assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) + assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") newStatus.Connected = true - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual = account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) + assert.Equal(t, newStatus.Connected, actual.Connected) + assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) + assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerLocation(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) +func TestSqlStore_SavePeerLocation(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -549,10 +545,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { CityName: "City", GeoNameID: 1, }, - Meta: nbpeer.PeerSystemMeta{}, + CreatedAt: time.Now().UTC(), + Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) account.Peers[peer.ID] = peer @@ -564,7 +561,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { peer.Location.CityName = "Berlin" peer.Location.GeoNameID = 2950159 - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID]) assert.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -574,7 +571,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { assert.Equal(t, peer.Location, actual) peer.ID = "non-existing-peer" - err = store.SavePeerLocation(account.Id, peer) + err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -925,47 +922,6 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } -func TestPostgresql_SavePeerStatus(t *testing.T) { - if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { - t.Skip("skip CI tests on darwin and windows") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) - t.Cleanup(cleanUp) - assert.NoError(t, err) - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) - require.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus.Connected, actual.Connected) -} - func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") @@ -1043,7 +999,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1056,7 +1012,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{2, 2, 2, 2}, } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1088,7 +1044,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1100,7 +1056,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer2.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -2561,3 +2517,329 @@ func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { require.NoError(t, err) require.NotContains(t, group.Resources, *res) } + +func TestSqlStore_AddPeerToGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "cfefqs706sqkneg59g4g" + groupID := "cfefqs706sqkneg59g4h" + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 0, "group should have 0 peers") + + err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID) + require.NoError(t, err, "failed to add peer to group") + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 1, "group should have 1 peers") + require.Contains(t, group.Peers, peerID) +} + +func TestSqlStore_AddPeerToAllGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + groupID := "cfefqs706sqkneg59g3g" + + peer := &nbpeer.Peer{ + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.domain.test", + } + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 2, "group should have 2 peers") + require.NotContains(t, group.Peers, peer.ID) + + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + require.NoError(t, err, "failed to add peer to account") + + err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID) + require.NoError(t, err, "failed to add peer to all group") + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 3, "group should have peers") + require.Contains(t, group.Peers, peer.ID) +} + +func TestSqlStore_AddPeerToAccount(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peer := &nbpeer.Peer{ + ID: "peer1", + AccountID: accountID, + Key: "key", + IP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "hostname", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + Name: "peer.test", + DNSLabel: "peer", + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + SSHKey: "ssh-key", + SSHEnabled: false, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + LastLogin: util.ToPtr(time.Now().UTC()), + CreatedAt: time.Now().UTC(), + Ephemeral: true, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + require.NoError(t, err, "failed to add peer to account") + + storedPeer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peer.ID) + require.NoError(t, err, "failed to get peer") + + assert.Equal(t, peer.ID, storedPeer.ID) + assert.Equal(t, peer.AccountID, storedPeer.AccountID) + assert.Equal(t, peer.Key, storedPeer.Key) + assert.Equal(t, peer.IP.String(), storedPeer.IP.String()) + assert.Equal(t, peer.Meta, storedPeer.Meta) + assert.Equal(t, peer.Name, storedPeer.Name) + assert.Equal(t, peer.DNSLabel, storedPeer.DNSLabel) + assert.Equal(t, peer.SSHKey, storedPeer.SSHKey) + assert.Equal(t, peer.SSHEnabled, storedPeer.SSHEnabled) + assert.Equal(t, peer.LoginExpirationEnabled, storedPeer.LoginExpirationEnabled) + assert.Equal(t, peer.InactivityExpirationEnabled, storedPeer.InactivityExpirationEnabled) + assert.WithinDurationf(t, peer.GetLastLogin(), storedPeer.GetLastLogin().UTC(), time.Millisecond, "LastLogin should be equal") + assert.WithinDurationf(t, peer.CreatedAt, storedPeer.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") + assert.Equal(t, peer.Ephemeral, storedPeer.Ephemeral) + assert.Equal(t, peer.Status.Connected, storedPeer.Status.Connected) + assert.Equal(t, peer.Status.LoginExpired, storedPeer.Status.LoginExpired) + assert.Equal(t, peer.Status.RequiresApproval, storedPeer.Status.RequiresApproval) + assert.WithinDurationf(t, peer.Status.LastSeen, storedPeer.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") +} + +func TestSqlStore_GetPeerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "cfefqs706sqkneg59g4g" + + groups, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + require.NoError(t, err) + assert.Len(t, groups, 1) + assert.Equal(t, groups[0].Name, "All") + + err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h") + require.NoError(t, err) + + groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + require.NoError(t, err) + assert.Len(t, groups, 2) +} + +func TestSqlStore_GetAccountPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 4, + }, + { + name: "should return no peers for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers for an empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers with expiration for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "should return no peers with expiration for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers with expiration for a empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers with inactivity for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "should return no peers with inactivity for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers with inactivity for an empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/storev1.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare) + require.NoError(t, err) + require.Len(t, peers, 1) + require.True(t, peers[0].Ephemeral) +} + +func TestSqlStore_GetUserPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + userID string + expectedCount int + }{ + { + name: "should retrieve peers for existing account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 1, + }, + { + name: "should return no peers for non-existing account ID with existing user ID", + accountID: "nonexistent", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 0, + }, + { + name: "should return no peers for non-existing user ID with existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "nonexistent_user", + expectedCount: 0, + }, + { + name: "should retrieve peers for another valid account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "edafee4e-63fb-11ec-90d6-0242ac120003", + expectedCount: 2, + }, + { + name: "should return no peers for existing account ID with empty user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetUserPeers(context.Background(), LockingStrengthShare, tt.accountID, tt.userID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_DeletePeer(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "csrnkiq7qv9d8aitqd50" + + err = store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) + require.NoError(t, err) + + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) + require.Error(t, err) + require.Nil(t, peer) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 91ae93c7c34..245df1c3e2d 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -50,8 +50,9 @@ type Store interface { GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(userID string) (string, error) + GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) + GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) @@ -97,18 +98,24 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error + GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) + SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error + DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql index 9c961e3896f..a8841179558 100644 --- a/management/server/testdata/store_policy_migrate.sql +++ b/management/server/testdata/store_policy_migrate.sql @@ -32,4 +32,5 @@ INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4h','bf1c8084-ba50-4ce7-9439-34653001fc3b','groupA','api','',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 518c484d7c4..5990a0625b1 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -27,9 +27,10 @@ CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'[]',0,0); -INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql index 69194d62391..cda333d4f90 100644 --- a/management/server/testdata/storev1.sql +++ b/management/server/testdata/storev1.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -31,9 +31,9 @@ INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); -INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); -INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); -INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); -INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',1,'""','','',0); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user.go b/management/server/user.go index fcf3d34ff03..17770a4235e 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -287,6 +287,10 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } delete(account.Users, targetUserID) + if updateAccountPeers { + account.Network.IncSerial() + } + err = am.Store.SaveAccount(ctx, account) if err != nil { return err @@ -311,12 +315,20 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU return false, nil } - peerIDs := make([]string, 0, len(peers)) + eventsToStore, err := deletePeers(ctx, am, am.Store, account.Id, initiatorUserID, peers) + if err != nil { + return false, err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + for _, peer := range peers { - peerIDs = append(peerIDs, peer.ID) + account.DeletePeer(peer.ID) } - return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) + return hadPeers, nil } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. @@ -628,7 +640,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + if err := am.expireAndUpdatePeers(ctx, account.Id, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } @@ -955,7 +967,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *types.Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -966,16 +978,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou } peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - account.UpdatePeer(peer) - if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { - return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err) - } - - log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID) + if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + return err + } am.StoreEvent( ctx, - peer.UserID, peer.ID, account.Id, + peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) } @@ -983,7 +992,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.UpdateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, accountID) } return nil } @@ -1085,6 +1094,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account deletedUsersMeta[targetUserID] = meta } + if updateAccountPeers { + account.Network.IncSerial() + } err = am.Store.SaveAccount(ctx, account) if err != nil { return fmt.Errorf("failed to delete users: %w", err)