Skip to content

Commit

Permalink
[management] Refactor peers to use store methods (#2893)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmmbaga authored Jan 20, 2025
1 parent c619bf5 commit 1ad2cb5
Show file tree
Hide file tree
Showing 30 changed files with 1,587 additions and 830 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/golang-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ jobs:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
Expand Down Expand Up @@ -314,7 +314,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8

- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)

api_integration_test:
needs: [ build-cache ]
Expand Down Expand Up @@ -363,7 +363,7 @@ jobs:
run: git --no-pager diff --exit-code

- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -tags=integration $(go list ./... | grep /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)

test_client_on_docker:
needs: [ build-cache ]
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28=
github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
Expand Down
108 changes: 46 additions & 62 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
const (
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
)
Expand Down Expand Up @@ -85,7 +86,7 @@ type AccountManager interface {
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
Expand All @@ -105,6 +106,7 @@ type AccountManager interface {
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
Expand All @@ -126,8 +128,8 @@ type AccountManager interface {
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
Expand All @@ -138,7 +140,7 @@ type AccountManager interface {
GetIdpManager() idp.Manager
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *types.Account) (map[string]struct{}, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
Expand Down Expand Up @@ -379,14 +381,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}

if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}

updateAccountPeers := false
Expand All @@ -400,7 +402,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
account.Network.Serial++
}

err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -437,13 +439,13 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
return nil
}

func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration

am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
Expand All @@ -452,7 +454,7 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
Expand All @@ -466,33 +468,31 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
expiredPeers, err := am.getExpiredPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextPeerExpiration()
return peerSchedulerRetryInterval, true
}

expiredPeers := account.GetExpiredPeers()
var peerIDs []string
for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID)
}

log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID)

if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextPeerExpiration()
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID)
return peerSchedulerRetryInterval, true
}

return account.GetNextPeerExpiration()
return am.getNextPeerExpiration(ctx, accountID)
}
}

func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *types.Account) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextPeerExpiration(); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
}
}

Expand All @@ -502,34 +502,33 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
inactivePeers, err := am.getInactivePeers(ctx, accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextInactivePeerExpiration()
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
return peerSchedulerRetryInterval, true
}

expiredPeers := account.GetInactivePeers()
var peerIDs []string
for _, peer := range expiredPeers {
for _, peer := range inactivePeers {
peerIDs = append(peerIDs, peer.ID)
}

log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID)

if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextInactivePeerExpiration()
if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return peerSchedulerRetryInterval, true
}

return account.GetNextInactivePeerExpiration()
return am.getNextInactivePeerExpiration(ctx, accountID)
}
}

// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *types.Account) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
}
}

Expand Down Expand Up @@ -665,7 +664,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided")
}

accountID, err := am.Store.GetAccountIDByUserID(userID)
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
Expand Down Expand Up @@ -1450,7 +1449,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}

userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
Expand Down Expand Up @@ -1497,7 +1496,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
}

func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
Expand Down Expand Up @@ -1559,17 +1558,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, status.NewGetAccountError(err)
}

peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
}

err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
Expand All @@ -1583,12 +1577,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return status.NewGetAccountError(err)
}

err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
Expand All @@ -1609,12 +1598,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}

_, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
_, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return mapError(ctx, err)
}
Expand Down Expand Up @@ -1683,8 +1667,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
}

func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
Expand All @@ -1695,7 +1679,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
}

if peerLoginExpired(ctx, peer, settings) {
err = am.handleExpiredPeer(ctx, user, peer)
err = am.handleExpiredPeer(ctx, transaction, user, peer)
if err != nil {
return false, err
}
Expand Down
Loading

0 comments on commit 1ad2cb5

Please sign in to comment.