diff --git a/client/client.go b/client/client.go index 4b909390..40228c9a 100644 --- a/client/client.go +++ b/client/client.go @@ -7,10 +7,10 @@ import ( "context" "strings" "sync" - "time" + "github.com/jwijenbergh/eduoauth-go" "github.com/eduvpn/eduvpn-common/i18nerr" - "github.com/eduvpn/eduvpn-common/internal/config" + v2 "github.com/eduvpn/eduvpn-common/internal/config/v2" "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/failover" "github.com/eduvpn/eduvpn-common/internal/fsm" @@ -98,31 +98,31 @@ func (c *Client) hasDiscovery() bool { // Client is the main struct for the VPN client. type Client struct { // The name of the client - Name string `json:"-"` + Name string - // The chosen server - Servers server.List `json:"servers"` + // The servers + Servers server.Servers // The list of servers and organizations from disco - Discovery discovery.Discovery `json:"discovery"` + Discovery discovery.Discovery // The fsm - FSM fsm.FSM `json:"-"` - - // The config - Config config.Config `json:"-"` + FSM fsm.FSM // Whether or not this client supports WireGuard - SupportsWireguard bool `json:"-"` + SupportsWireguard bool // Whether to enable debugging - Debug bool `json:"-"` + Debug bool // TokenSetter sets the tokens in the client - TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"` + TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) // TokenGetter gets the tokens from the client - TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"` + TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens + + // Starting up bool + StartingUP bool mu sync.Mutex } @@ -131,24 +131,28 @@ func (c *Client) updateTokens(srv server.Server) error { if c.TokenGetter == nil { return errors.New("no token getter defined") } - pSrv, err := c.pubCurrentServer(srv) - if err != nil { - return err - } - // shouldn't happen - if pSrv == nil { - return errors.New("public server is nil when getting tokens") - } - tokens := c.TokenGetter(*pSrv) - if tokens == nil { - return errors.New("client returned nil for tokens") - } - - server.UpdateTokens(srv, oauth.Token{ - Access: tokens.Access, - Refresh: tokens.Refresh, - ExpiredTimestamp: time.Unix(tokens.Expires, 0), - }) + // TODO: TODO: + //pSrv, err := c.pubCurrentServer(srv) + //if err != nil { + // return err + //} + //// shouldn't happen + //if pSrv == nil { + // return errors.New("public server is nil when getting tokens") + //} + //tokens := c.TokenGetter(*pSrv) + //if tokens == nil { + // return errors.New("client returned nil for tokens") + //} + + //err := server.UpdateTokens(srv, eduoauth.Token{ + // Access: tokens.Access, + // Refresh: tokens.Refresh, + // ExpiredTimestamp: time.Unix(tokens.Expires, 0), + //}) + //if err != nil { + // return err + //} return nil } @@ -157,19 +161,19 @@ func (c *Client) forwardTokens(srv server.Server) error { if c.TokenSetter == nil { return errors.New("no token setter defined") } - pSrv, err := c.pubCurrentServer(srv) - if err != nil { - return err - } - if pSrv == nil { - return errors.New("public server is nil when updating tokens") - } - o := srv.OAuth() - if o == nil { - return errors.New("oauth was nil when forwarding tokens") - } - t := o.Token() - c.TokenSetter(*pSrv, t.Public()) + // TODO: TODO: + //pSrv, err := c.pubCurrentServer(srv) + //if err != nil { + // return err + //} + //if pSrv == nil { + // return errors.New("public server is nil when updating tokens") + //} + //t, err := srv.Tokens() + //if err != nil { + // return err + //} + //c.TokenSetter(*pSrv, *t) return nil } @@ -228,16 +232,46 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Debug only if given c.Debug = debug - // Initialize the Config - c.Config.Init(directory, "state") + // set the servers + c.Servers = server.NewServers(c.Name, c.triggerAuth, c.authDone, c.SupportsWireguard, v2.V2{}) + return c, nil +} - // Try to load the previous configuration - if c.Config.Load(&c) != nil { - // This error can be safely ignored, as when the config does not load, the struct will not be filled - log.Logger.Infof("Previous configuration not found") +func (c *Client) triggerAuth(ctx context.Context, url string, wait bool) (string, error) { + // Go to chosen server if possible + // TODO: Debug log? + c.FSM.GoTransition(StateChosenServer) + // Get a reply from the client + if wait { + ck := cookie.NewWithContext(ctx) + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{ + C: &ck, + Data: url, + }) + if err != nil { + errChan <- err + } + }() + g, err := ck.Receive(errChan) + if err != nil { + return "", err + } + return g, nil } + // Otherwise do normal authorization (desktop clients) + err := c.FSM.GoTransitionRequired(StateOAuthStarted, url) + if err != nil { + return "", err + } + return "", nil +} - return c, nil +func (c *Client) authDone() { + // TODO: Should this log anything if it fails? + // unhandled transition? + c.FSM.GoTransition(StateAuthorized) } // Registering means updating the FSM to get to the initial state correctly @@ -252,23 +286,11 @@ func (c *Client) Register() error { return nil } -// SaveState saves the internal state to the config -func (c *Client) SaveState() { - log.Logger.Debugf("saving state configuration....") - // Save the config - if err := c.Config.Save(&c); err != nil { - log.Logger.Infof("failed saving state configuration: '%v'", err) - } -} - // Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct. func (c *Client) Deregister() { // First of all let's transition the state machine _ = c.goTransition(StateDeregistered) - // SaveState saves the configuration - c.SaveState() - // Close the log file _ = log.Logger.Close() @@ -287,9 +309,10 @@ func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organiz } // Mark organizations as expired if we have not set an organization yet - if !c.Servers.HasSecureInternet() { - c.Discovery.MarkOrganizationsExpired() - } + // TODO: Do this with config + //if !c.Servers.HasSecureInternet() { + // c.Discovery.MarkOrganizationsExpired() + //} orgs, err = c.Discovery.Organizations(ck.Context()) if err != nil { @@ -321,29 +344,13 @@ func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err e // - The list of times where notifications should be shown // These times are reset when the VPN gets disconnected func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { - // Get current expiry time - srv, err := c.Servers.Current() - if err != nil { - return nil, i18nerr.Wrap(err, "The current server could not be found when getting it for expiry") - } - b, err := srv.Base() - if err != nil { - return nil, err - } - - if b.StartTime.IsZero() { - return nil, i18nerr.New("No start time is defined for this server") - } - - bT := b.RenewButtonTime() - cT := b.CountdownTime() - nT := b.NotificationTimes() + // TODO: Implement return &srvtypes.Expiry{ - StartTime: b.StartTime.Unix(), - EndTime: b.EndTime.Unix(), - ButtonTime: bT, - CountdownTime: cT, - NotificationTimes: nT, + StartTime: 0, + EndTime: 0, + ButtonTime: 0, + CountdownTime: 0, + NotificationTimes: []int64{0}, }, nil } @@ -374,130 +381,6 @@ func (c *Client) locationCallback(ck *cookie.Cookie) error { return nil } -func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error { - // get a custom redirect - cr := CustomRedirect(c.Name) - url, err := server.OAuthURL(srv, c.Name, cr) - if err != nil { - return err - } - authCodeURI := "" - if cr != "" { - errChan := make(chan error) - go func() { - err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{ - C: ck, - Data: url, - }) - if err != nil { - errChan <- err - } - }() - g, err := ck.Receive(errChan) - if err != nil { - return err - } - authCodeURI = g - } else { - err = c.FSM.GoTransitionRequired(StateOAuthStarted, url) - if err != nil { - return err - } - } - err = server.OAuthExchange(ck.Context(), srv, authCodeURI) - if err != nil { - return err - } - return nil -} - -func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool, startup bool) error { - // location - if srv.NeedsLocation() { - if startup { - return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no secure internet location is found. Please manually connect again", server.Name(srv)) - } - err := c.locationCallback(ck) - if err != nil { - return i18nerr.Wrap(err, "The secure internet location could not be set") - } - } - - err := c.goTransition(StateChosenServer) - if err != nil { - log.Logger.Debugf("optional chosen server transition not possible: %v", err) - } - // oauth - // TODO: This should be ck.Context() - // But needsrelogin needs a rewrite to support this properly - - // first make sure we get the most up to date tokens from the client - err = c.updateTokens(srv) - if err != nil { - log.Logger.Debugf("failed to get tokens from client: %v", err) - } - if server.NeedsRelogin(context.Background(), srv) || forceauth { - if startup { - return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but you need to authorizate again. Please manually connect again", server.Name(srv)) - } - // mark organizations as expired if the server is a secure internet server - b, berr := srv.Base() - if berr == nil && b.Type == srvtypes.TypeSecureInternet { - c.Discovery.MarkOrganizationsExpired() - } - err := c.loginCallback(ck, srv) - if err != nil { - return i18nerr.Wrap(err, "The authorization procedure failed to complete") - } - } - err = c.goTransition(StateAuthorized) - if err != nil { - return err - } - - return nil -} - -func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server, startup bool) error { - vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard) - if err != nil { - log.Logger.Warningf("failed to determine whether the current protocol is valid with error: %v", err) - return err - } - if !vp { - if startup { - return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no valid profiles were found. Please manually connect again", server.Name(srv)) - } - vps, err := server.ValidProfiles(srv, c.SupportsWireguard) - if err != nil { - return i18nerr.Wrapf(err, "No suitable profiles could be found") - } - errChan := make(chan error) - go func() { - err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ - C: ck, - Data: vps.Public(), - }) - if err != nil { - errChan <- err - } - }() - pID, err := ck.Receive(errChan) - if err != nil { - return i18nerr.Wrapf(err, "Profile with ID: '%s' could not be set", pID) - } - err = server.Profile(srv, pID) - if err != nil { - return i18nerr.Wrapf(err, "Profile with ID: '%s' could not be obtained from the server", pID) - } - } - err = c.goTransition(StateChosenProfile) - if err != nil { - return err - } - return nil -} - // AddServer adds a server with identifier and type func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) { c.mu.Lock() @@ -506,11 +389,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. // We add the server because we can then obtain it in other callback functions previousState := c.FSM.Current defer func() { - if err != nil { - _ = c.RemoveServer(identifier, _type) //nolint:errcheck - } else { - c.SaveState() - } + // TODO + //if err != nil { + // //_ = c.RemoveServer(identifier, _type) //nolint:errcheck + //} // If we must run callbacks, go to the previous state if we're not in it if !ni && !c.FSM.InState(previousState) { c.FSM.GoTransition(previousState) //nolint:errcheck @@ -524,380 +406,97 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. return err } } - - if _type != srvtypes.TypeSecureInternet { - identifier, err = http.EnsureValidURL(identifier, true) - if err != nil { - return i18nerr.Wrap(err, "The identifier that was passed to the library is incorrect") - } - } - var srv server.Server switch _type { case srvtypes.TypeInstituteAccess: - dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access") - if err != nil { - return i18nerr.Wrapf(err, "Could not retrieve institute access server with URL: '%s' from discovery", identifier) - } - srv, err = c.Servers.AddInstituteAccess(ck.Context(), c.Name ,dSrv) + srv, err = c.Servers.AddInstitute(ck.Context(), &c.Discovery, identifier, ni) if err != nil { return i18nerr.Wrapf(err, "The institute access server with URL: '%s' could not be added", identifier) } case srvtypes.TypeSecureInternet: - dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) - if err != nil { - // We mark the organizations as expired because we got an error - // Note that in the docs it states that it only should happen when the Org ID doesn't exist - // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found - c.Discovery.MarkOrganizationsExpired() - return i18nerr.Wrapf(err, "The secure internet server with organisation ID: '%s' could not be retrieved from discovery", identifier) - } - srv, err = c.Servers.AddSecureInternet(ck.Context(), c.Name, dOrg, dSrv) + srv, err = c.Servers.AddSecure(ck.Context(), &c.Discovery, identifier, ni) if err != nil { return i18nerr.Wrapf(err, "The secure internet server with organisation ID: '%s' could not be added", identifier) } case srvtypes.TypeCustom: - srv, err = c.Servers.AddCustom(ck.Context(), c.Name, identifier) + srv, err = c.Servers.AddCustom(ck.Context(), identifier, ni) if err != nil { return i18nerr.Wrapf(err, "The custom server with URL: '%s' could not be added", identifier) } default: return i18nerr.NewInternalf("Server type: '%v' is not valid to be added", _type) } - - // if we are non interactive, we run no callbacks - if ni { - return nil - } - - // callbacks - err = c.callbacks(ck, srv, false, false) - // error is already UI wrapped - if err != nil { - return err - } - terr := c.forwardTokens(srv) - if terr != nil { - log.Logger.Debugf("failed to forward tokens after adding: %v", terr) + // OAuth is filled if we are interactive mode, make sure the client knows it + if !ni { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after adding: %v", terr) + } } return nil } -func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool, startup bool) (cfg *srvtypes.Configuration, err error) { - // do the callbacks to ensure valid profile, location and authorization - err = c.callbacks(ck, srv, forceAuth, startup) - if err != nil { - return nil, err - } - - err = c.goTransition(StateRequestConfig) - if err != nil { - return nil, err - } - - err = c.profileCallback(ck, srv, startup) - if err != nil { - return nil, err - } - - cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP) - if err != nil { - return nil, i18nerr.Wrap(err, "The VPN configuration could not be obtained") - } - p, err := server.CurrentProfile(srv) - if err != nil { - return nil, i18nerr.Wrap(err, "The current profile could not be found") - } - pcfg := cfgS.Public(p.DefaultGateway) - return &pcfg, nil -} - -func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) { - switch _type { - case srvtypes.TypeInstituteAccess: - srv, err = c.Servers.InstituteAccess(identifier) - setter = c.Servers.SetInstituteAccess - case srvtypes.TypeSecureInternet: - srv, err = c.Servers.SecureInternet(identifier) - setter = c.Servers.SetSecureInternet - case srvtypes.TypeCustom: - srv, err = c.Servers.CustomServer(identifier) - setter = c.Servers.SetCustom - default: - return nil, nil, i18nerr.NewInternalf("Not a valid server type: %v", _type) - } - return srv, setter, err -} - // GetConfig gets a VPN configuration -func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (cfg *srvtypes.Configuration, err error) { +func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (*srvtypes.Configuration, error) { c.mu.Lock() defer c.mu.Unlock() previousState := c.FSM.Current + var err error + defer func() { if err == nil { c.FSM.GoTransition(StateGotConfig) //nolint:errcheck - c.SaveState() } else if !c.FSM.InState(previousState) { // go back to the previous state if an error occurred c.FSM.GoTransition(previousState) //nolint:errcheck } }() - if _type != srvtypes.TypeSecureInternet { - identifier, err = http.EnsureValidURL(identifier, true) - if err != nil { - return nil, i18nerr.Wrapf(err, "Identifier: '%s' for server with type: '%d' is not valid", identifier, _type) - } - } - err = c.goTransition(StateLoadingServer) - if err != nil { - return nil, err - } - srv, set, err := c.server(identifier, _type) - if err != nil { - return nil, err - } - // refresh the server endpoints - err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) - - // If we get a canceled error, return that, otherwise just log the error - if err != nil { - if errors.Is(err, context.Canceled) { - return nil, i18nerr.Wrap(err, "The operation for getting a VPN configuration was canceled") - } - - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - // get a config and retry with authorization if expired - cfg, err = c.config(ck, srv, pTCP, false, startup) - tErr := &oauth.TokensInvalidError{} - if err != nil && errors.As(err, &tErr) { - log.Logger.Debugf("the tokens were invalid, trying again...") - cfg, err = c.config(ck, srv, pTCP, true, startup) - } - - // tokens might be updated, forward them - defer func() { - terr := c.forwardTokens(srv) - if terr != nil { - log.Logger.Debugf("failed to forward tokens after get config: %v", terr) - } - }() - // still an error, return nil with the error + err = c.goTransition(StateLoadingServer) + // this is already wrapped in an UI error if err != nil { return nil, err } - // set the current server - if err = set(srv); err != nil { - return nil, i18nerr.Wrapf(err, "Failed to set the server with identifier: '%s' as the current", identifier) - } - - return cfg, nil -} - -func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { - if _type != srvtypes.TypeSecureInternet { - identifier, err = http.EnsureValidURL(identifier, true) - if err != nil { - return i18nerr.Wrapf(err, "Identifier: '%s' for server with type: '%d' is not valid for removal", identifier, _type) - } - } - // miscellaneous error - var mErr error + var srv server.Server switch _type { - case srvtypes.TypeInstituteAccess: - mErr = c.Servers.RemoveInstituteAccess(identifier) - case srvtypes.TypeSecureInternet: - mErr = c.Servers.RemoveSecureInternet(identifier) case srvtypes.TypeCustom: - mErr = c.Servers.RemoveCustom(identifier) + // TODO: Token passing + srv, err = c.Servers.GetCustom(ck.Context(), identifier, eduoauth.Token{}) default: - return i18nerr.NewInternalf("Not a valid server type: %v", _type) - } - if mErr != nil { - log.Logger.Debugf("failed to remove server with identifier: '%s' and type: '%d', error: %v", identifier, _type, mErr) + panic("unreachable") } - c.SaveState() - return nil -} - -func (c *Client) CurrentServer() (*srvtypes.Current, error) { - srv, err := c.Servers.Current() if err != nil { + // TODO: wrap error return nil, err } - return c.pubCurrentServer(srv) -} -func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) { - b, err := srv.Base() + cfg, err := c.Servers.Connect(ck.Context(), srv, pTCP) if err != nil { return nil, err } - pub, err := srv.Public() - if err != nil { - return nil, err - } - switch t := pub.(type) { - case *srvtypes.Server: - if b.Type == srvtypes.TypeInstituteAccess { - return &srvtypes.Current{ - Institute: &srvtypes.Institute{ - Server: *t, - SupportContacts: b.SupportContact, - // TODO: delisted - Delisted: false, - }, - Type: srvtypes.TypeInstituteAccess, - }, nil - } - return &srvtypes.Current{ - Custom: t, - Type: srvtypes.TypeCustom, - }, nil - case *srvtypes.SecureInternet: - t.SupportContacts = b.SupportContact - t.Locations = c.Discovery.SecureLocationList() - return &srvtypes.Current{ - SecureInternet: t, - Type: srvtypes.TypeSecureInternet, - }, nil - default: - panic("unknown type") - } + return cfg, nil } -// TODO: This should not rely on interface{} -func (c *Client) pubServer(srv server.Server) (interface{}, error) { - pub, err := srv.Public() - if err != nil { - return nil, err - } - b, err := srv.Base() - if err != nil { - return nil, err - } - switch t := pub.(type) { - case *srvtypes.Server: - if b.Type == srvtypes.TypeInstituteAccess { - return &srvtypes.Institute{ - Server: *t, - SupportContacts: b.SupportContact, - // TODO: delisted - Delisted: false, - }, nil - } - return t, nil - case *srvtypes.SecureInternet: - t.SupportContacts = b.SupportContact - t.Locations = c.Discovery.SecureLocationList() - return t, nil - default: - panic("unknown type") - } +func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { + panic("TODO") + return nil } -func (c *Client) ServerList() (*srvtypes.List, error) { - if c.FSM.InState(StateDeregistered) { - return nil, i18nerr.NewInternal("Client is not registered") - } - var customServers []srvtypes.Server - for _, v := range c.Servers.CustomServers.Map { - if v == nil { - continue - } - p, err := c.pubServer(v) - if err != nil { - continue - } - c, ok := p.(*srvtypes.Server) - if !ok { - continue - } - customServers = append(customServers, *c) - } - var instituteServers []srvtypes.Institute - for _, v := range c.Servers.InstituteServers.Map { - if v == nil { - continue - } - p, err := c.pubServer(v) - if err != nil { - continue - } - i, ok := p.(*srvtypes.Institute) - if !ok { - continue - } - instituteServers = append(instituteServers, *i) - } - var secureInternet *srvtypes.SecureInternet - if c.Servers.HasSecureInternet() { - srv := &c.Servers.SecureInternetHomeServer - p, err := c.pubServer(srv) - if err == nil { - s, ok := p.(*srvtypes.SecureInternet) - if ok { - secureInternet = s - } - } - } - return &srvtypes.List{ - Institutes: instituteServers, - SecureInternet: secureInternet, - Custom: customServers, - }, nil +func (c *Client) CurrentServer() (*srvtypes.Current, error) { + panic("TODO") + return nil, errors.New("unreachable") } func (c *Client) SetProfileID(pID string) (err error) { - srv, err := c.Servers.Current() - if err != nil { - return err - } - err = server.Profile(srv, pID) - if err == nil { - c.SaveState() - } - return err + panic("TODO") + return errors.New("unreachable") } func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { - // get the current server - srv, err := c.Servers.Current() - if err != nil { - return i18nerr.Wrap(err, "Failed to get the current server to cleanup the connection") - } - - err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) - - // If we get a canceled error, return that, otherwise just log the error - if err != nil { - if errors.Is(err, context.Canceled) { - return i18nerr.Wrap(err, "The cleanup process was canceled") - } - - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - - defer c.SaveState() - err = c.updateTokens(srv) - if err != nil { - log.Logger.Debugf("failed to update tokens for disconnect: %v", err) - } - err = server.Disconnect(ck.Context(), srv) - if err != nil { - return i18nerr.Wrap(err, "Failed to cleanup the VPN connection for the current server") - } - err = c.forwardTokens(srv) - if err != nil { - log.Logger.Debugf("failed to forward tokens after disconnect: %v", err) - } - return nil + panic("TODO") + return errors.New("unreachable") } func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) { @@ -906,58 +505,16 @@ func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err e return i18nerr.NewInternal("Setting a secure internet location with this client ID is not supported") } - if !c.Servers.HasSecureInternet() { - return i18nerr.Newf("No secure internet server available to set a location for") - } - - dSrv, err := c.Discovery.ServerByCountryCode(countryCode) - if err != nil { - return err - } - - err = c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv) - if err == nil { - c.SaveState() - } - return err + //if !c.Servers.HasSecureInternet() { + // return i18nerr.Newf("No secure internet server available to set a location for") + //} + panic("TODO") + return errors.New("unreachable") } func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { - c.mu.Lock() - defer c.mu.Unlock() - srv, err := c.Servers.Current() - if err != nil { - return i18nerr.Wrap(err, "Failed to get current server for renewing the session") - } - // The server has not been chosen yet, this means that we want to manually renew - // TODO: is this needed? - if !c.FSM.InState(StateLoadingServer) { - c.FSM.GoTransition(StateLoadingServer) //nolint:errcheck - } - err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) - - // If we get a canceled error, return that, otherwise just log the error - if err != nil { - if errors.Is(err, context.Canceled) { - return i18nerr.Wrap(err, "The renewing process was canceled") - } - - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - - // update tokens in the end - defer func() { - terr := c.forwardTokens(srv) - if terr != nil { - log.Logger.Debugf("failed to forward tokens after renew: %v", terr) - } - }() - defer c.SaveState() - // TODO: Maybe this can be deleted because we force auth now - server.MarkTokensForRenew(srv) - // run the callbacks by forcing auth - return c.callbacks(ck, srv, true, false) + panic("TODO") + return errors.New("unreachable") } func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readRxBytes func() (int64, error)) (bool, error) { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 7f726064..9234b328 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -176,6 +176,7 @@ func printConfig(url string, srvType srvtypes.Type) { fmt.Fprintf(os.Stderr, "failed getting a config: %v\n", err) return } + fmt.Println(cfg.Protocol) fmt.Println("Obtained config:", cfg.VPNConfig) } diff --git a/exports/exports.go b/exports/exports.go index 919b9c8f..b3c8af06 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -457,19 +457,21 @@ func CurrentServer() (*C.char, *C.char) { // //export ServerList func ServerList() (*C.char, *C.char) { - state, stateErr := getVPNState() + _, stateErr := getVPNState() if stateErr != nil { return nil, getCError(stateErr) } - list, err := state.ServerList() - if err != nil { - return nil, getCError(err) - } - ret, err := getReturnData(list) - if err != nil { - return nil, getCError(err) - } - return C.CString(ret), nil + panic("TODO") + return nil, nil + //list, err := state.ServerList() + //if err != nil { + // return nil, getCError(err) + //} + //ret, err := getReturnData(list) + //if err != nil { + // return nil, getCError(err) + //} + //return C.CString(ret), nil } // GetConfig gets a configuration for the server @@ -839,7 +841,8 @@ func SetSupportWireguard(support C.int) *C.char { if stateErr != nil { return getCError(stateErr) } - state.SupportsWireguard = support != 0 + // TODO: Do not do any nested struct member here + state.Servers.WGSupport = support != 0 return nil } diff --git a/internal/api/api.go b/internal/api/api.go index c9564b61..9c55ab77 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -22,7 +22,7 @@ import ( type API struct { // autht is the function to retrigger authorization - autht func(string) (string, error) + autht func(context.Context, string, bool) (string, error) // authd is the function that triggers when authorization is done authd func() // oauth is the oauth object @@ -39,9 +39,25 @@ func (a *API) BaseURL() string { return a.baseWKURL } +func (a *API) UpdateTokens(tok eduoauth.Token) error { + if a.oauth == nil { + return errors.New("no oauth object defined") + } + a.oauth.UpdateTokens(tok) + return nil +} + +func (a *API) Tokens() (*eduoauth.Token, error) { + if a.oauth == nil { + return nil, errors.New("no oauth object defined") + } + t := a.oauth.Token() + return &t, nil +} + // NewAPI creates a new API object by creating an OAuth object // TODO: make this mess shorter -func NewAPI(ctx context.Context, clientID string, baseWKURL string, baseWKAuthURL string, autht func(string) (string, error), authd func(), tokens *eduoauth.Token) (*API, error) { +func NewAPI(ctx context.Context, clientID string, baseWKURL string, baseWKAuthURL string, autht func(context.Context, string, bool) (string, error), authd func(), tokens *eduoauth.Token) (*API, error) { ep, epauth, err := refreshEndpoints(ctx, baseWKURL, baseWKAuthURL) if err != nil { return nil, err @@ -93,7 +109,8 @@ func (a *API) authorize(ctx context.Context) (err error) { if err != nil { return err } - uri, err := a.autht(url) + // We expect an uri if custom redirect is non empty + uri, err := a.autht(ctx, url, a.oauth.CustomRedirect != "") if err != nil { return err } @@ -106,10 +123,8 @@ func (a *API) authorize(ctx context.Context) (err error) { } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - u, err := httpw.JoinURLPath(a.apiURL, endpoint) - if err != nil { - return nil, nil, err - } + u := a.apiURL + endpoint + fmt.Println("GOT API URL", u) // TODO: Cache HTTP client? httpC := httpw.NewClient(a.oauth.NewHTTPClient()) @@ -204,10 +219,11 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto if pTCP && len(protos) > 1 { continue } - wgKey, err := wireguard.GenerateKey() + gk, err := wireguard.GenerateKey() if err != nil { return nil, err } + wgKey = &gk // Set the public key pubkey := wgKey.PublicKey() uv.Set("public_key", pubkey.String()) diff --git a/internal/config/v2/v2.go b/internal/config/v2/v2.go index db5b5dc9..77f480bf 100644 --- a/internal/config/v2/v2.go +++ b/internal/config/v2/v2.go @@ -2,6 +2,7 @@ package v2 import ( "encoding/json" + "fmt" "errors" "time" @@ -54,18 +55,20 @@ func (v2 *V2) AddCustom(url string) *Server { // Otherwise add to the list and return v := Server{} cst.List[url] = v + cst.LastChosenID = url + v2.Servers.Custom = cst return &v } -func (v2 *V2) GetCustom(url string) *Server { +func (v2 *V2) GetCustom(url string) (*Server, error) { cst := v2.Servers.Custom if len(cst.List) == 0 { - cst.List = make(map[string]Server) + return nil, fmt.Errorf("server list is empty, no custom server with url: '%s'", url) } - // Otherwise add to the list and return - v := Server{} - cst.List[url] = v - return &v + if v, ok := cst.List[url]; ok { + return &v, nil + } + return nil, errors.New("custom server with url: '%s' does not exist") } type Joined struct { diff --git a/internal/server/base.go b/internal/server/base.go index 888c823a..ef455ad1 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -1,25 +1,40 @@ package server -import ( - "os" +import ( "context" "errors" + "os" + "github.com/eduvpn/eduvpn-common/internal/api" "github.com/eduvpn/eduvpn-common/internal/api/profiles" "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" ) type Base struct { apiw *api.API CachedInfo *profiles.Info - Profile *profiles.Profile + prID string } -func (b *Base) ProfileID() string { - if b.Profile == nil { - return "" +func (b *Base) UpdateTokens(tok eduoauth.Token) error { + if b.apiw == nil { + return errors.New("no API object defined") } - return b.Profile.ID + return b.apiw.UpdateTokens(tok) +} + +func (b *Base) Tokens() (*srvtypes.Tokens, error) { + tok, err := b.apiw.Tokens() + if err != nil { + return nil, err + } + return &srvtypes.Tokens{ + Access: tok.Access, + Refresh: tok.Refresh, + Expires: tok.ExpiredTimestamp.Unix(), + }, nil } var InvalidProfileErr = errors.New("invalid profile") @@ -60,8 +75,6 @@ func (b *Base) findProfile(ctx context.Context, wgSupport bool) (*profiles.Profi return nil, err } - prID := b.ProfileID() - // No profiles available if prfs.Len() == 0 { return nil, errors.New("the server has no available profiles for your account") @@ -85,7 +98,7 @@ func (b *Base) findProfile(ctx context.Context, wgSupport bool) (*profiles.Profi chosenP = prfs.MustIndex(0) default: // Profile doesn't exist - v := prfs.Get(prID) + v := prfs.Get(b.prID) if v == nil { return nil, InvalidProfileErr } @@ -94,7 +107,7 @@ func (b *Base) findProfile(ctx context.Context, wgSupport bool) (*profiles.Profi return &chosenP, nil } -func (b *Base) Connect(ctx context.Context, wgSupport bool, pTCP bool) (*api.ConnectData, error) { +func (b *Base) Connect(ctx context.Context, wgSupport bool, pTCP bool) (*srvtypes.Configuration, error) { a, err := b.API() if err != nil { return nil, err @@ -105,7 +118,7 @@ func (b *Base) Connect(ctx context.Context, wgSupport bool, pTCP bool) (*api.Con if err != nil { return nil, err } - b.Profile = chosenP + b.prID = chosenP.ID protos := []protocol.Protocol{protocol.OpenVPN} if wgSupport { @@ -120,7 +133,16 @@ func (b *Base) Connect(ctx context.Context, wgSupport bool, pTCP bool) (*api.Con } } // SAFETY: chosenP is guaranteed to be non-nil - return a.Connect(ctx, *chosenP, protos, pTCP) + apicfg, err := a.Connect(ctx, *chosenP, protos, pTCP) + if err != nil { + return nil, err + } + // TODO: Save connection + return &srvtypes.Configuration{ + VPNConfig: apicfg.Configuration, + Protocol: apicfg.Protocol, + DefaultGateway: chosenP.DefaultGateway, + }, nil } func (b *Base) Disconnect(ctx context.Context) error { diff --git a/internal/server/custom.go b/internal/server/custom.go index 77342167..bca3c08b 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -2,9 +2,11 @@ package server import ( "context" - "github.com/eduvpn/eduvpn-common/types/server" + "github.com/eduvpn/eduvpn-common/internal/api" httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" ) type CustomServer struct { @@ -35,10 +37,38 @@ func (s *Servers) AddCustom(ctx context.Context, url string, na bool) (Server, e } } + cust := CustomServer{ + Base: Base{ + apiw: a, + }, + } + + s.config.AddCustom(id) + // Return the server with the API - return &CustomServer{ + return &cust, nil +} + +func (s *Servers) GetCustom(ctx context.Context, url string, tok eduoauth.Token) (Server, error) { + // Convert to an identifier + id, err := httpw.EnsureValidURL(url, true) + if err != nil { + return nil, err + } + // Get the server from the config + srv, err := s.config.GetCustom(id) + if err != nil { + return nil, err + } + a, err := api.NewAPI(ctx, s.clientID, id, id, s.authTrigger, s.authDone, &tok) + if err != nil { + return nil, err + } + cust := &CustomServer{ Base: Base{ apiw: a, + prID: srv.ProfileID, }, - }, nil + } + return cust, nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index f995e277..e40da502 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -25,6 +25,10 @@ func (si *SecureInternet) Type() server.Type { func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, na bool) (*SecureInternet, error) { dOrg, dsrv, err := disco.SecureHomeArgs(orgID) if err != nil { + // We mark the organizations as expired because we got an error + // Note that in the docs it states that it only should happen when the Org ID doesn't exist + // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found + disco.MarkOrganizationsExpired() return nil, err } diff --git a/internal/server/server.go b/internal/server/server.go index 460d8349..a04faf01 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -3,32 +3,37 @@ package server import ( "context" - "github.com/eduvpn/eduvpn-common/internal/api" + v2 "github.com/eduvpn/eduvpn-common/internal/config/v2" + "github.com/jwijenbergh/eduoauth-go" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" ) type Servers struct { clientID string authDone func() - authTrigger func(string) (string, error) - wgSupport bool + authTrigger func(context.Context, string, bool) (string, error) + WGSupport bool + config v2.V2 } type Server interface { GetBase() Base + UpdateTokens(tok eduoauth.Token) error } -func NewServers(name string, autht func(string) (string, error), authd func(), wgSupport bool) *Servers { - return &Servers{ +func NewServers(name string, autht func(context.Context, string, bool) (string, error), authd func(), wgSupport bool, cfg v2.V2) Servers { + return Servers{ clientID: name, authDone: authd, authTrigger: autht, - wgSupport: wgSupport, + WGSupport: wgSupport, + config: cfg, } } -func (s *Servers) Connect(ctx context.Context, srv Server, pTCP bool) (*api.ConnectData, error) { +func (s *Servers) Connect(ctx context.Context, srv Server, pTCP bool) (*srvtypes.Configuration, error) { b := srv.GetBase() - return b.Connect(ctx, s.wgSupport, pTCP) + return b.Connect(ctx, s.WGSupport, pTCP) } func (s *Servers) Disconnect(ctx context.Context, srv Server) error { diff --git a/types/server/server.go b/types/server/server.go index 4198e86c..b0aeb793 100644 --- a/types/server/server.go +++ b/types/server/server.go @@ -106,7 +106,7 @@ type Tokens struct { // Refresh is the refresh token Refresh string `json:"refresh_token"` // Expires is the Unix timestamp when the token expires - Expires int64 `json:"expires_in"` + Expires int64 `json:"expires"` } // Server is the basic type for a server. This is the base for secure internet and institute access. Custom servers are equal to this type