diff --git a/internal/api/api.go b/internal/api/api.go index 57de27c3..c9564b61 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -1,6 +1,10 @@ package api import ( + "errors" + "fmt" + "context" + "encoding/json" "net/http" "net/url" "time" @@ -11,13 +15,14 @@ import ( "github.com/eduvpn/eduvpn-common/internal/api/endpoints" "github.com/eduvpn/eduvpn-common/internal/api/profiles" "github.com/eduvpn/eduvpn-common/internal/log" + httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/wireguard" "github.com/eduvpn/eduvpn-common/types/protocol" ) type API struct { - // authf is the function to retrigger authorization - authf func(string) error + // autht is the function to retrigger authorization + autht func(string) (string, error) // authd is the function that triggers when authorization is done authd func() // oauth is the oauth object @@ -30,70 +35,89 @@ type API struct { baseWKAuthURL string } +func (a *API) BaseURL() string { + return a.baseWKURL +} + // NewAPI creates a new API object by creating an OAuth object -func NewAPI(clientID string, baseWKURL string, baseWKAuthURL string, authf func(string) error, authd func(), tokens *eudoauth.Token) (*API, error) { - ep, epauth, err := refreshEndpoints(baseWKURL, baseWKAuthURL) +// TODO: make this mess shorter +func NewAPI(ctx context.Context, clientID string, baseWKURL string, baseWKAuthURL string, autht func(string) (string, error), authd func(), tokens *eduoauth.Token) (*API, error) { + ep, epauth, err := refreshEndpoints(ctx, baseWKURL, baseWKAuthURL) if err != nil { return nil, err } + cr := customRedirect(clientID) // Construct OAuth // TODO: Support mobile redirect o := eduoauth.OAuth{ ClientID: clientID, BaseAuthorizationURL: epauth.Authorization, TokenURL: epauth.Token, + CustomRedirect: cr, RedirectPath: "/callback", } - if t != nil { - o.UpdateTokens(t) + if tokens != nil { + o.UpdateTokens(*tokens) } api := &API{ - authf: authf, + autht: autht, authd: authd, - oauthC: o, - URL: ep.API, + oauth: &o, + apiURL: ep.API, baseWKURL: baseWKURL, baseWKAuthURL: baseWKAuthURL, } - err := api.authorize(ctx) + err = api.authorize(ctx) if err != nil { return nil, err } return api, nil } -func (a *API) authorize(ctx context.Context) error { +func (a *API) authorize(ctx context.Context) (err error) { + defer func() { + if err == nil { + a.authd() + } + }() + _, err = a.oauth.AccessToken(ctx) + // already authorized + if err == nil { + return nil + } scope := "config" url, err := a.oauth.AuthURL(scope) if err != nil { - return nil, err + return err } - err := a.authf(url) + uri, err := a.autht(url) if err != nil { - return nil, err + return err } - err := o.Exchange(ctx) + // The uri is only given here if a custom redirect is done + err = a.oauth.Exchange(ctx, uri) if err != nil { - return nil, err + return err } - a.authd() return nil } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - u := a.URL - u.Path = path.Join(u.Path, endpoint) + u, err := httpw.JoinURLPath(a.apiURL, endpoint) + if err != nil { + return nil, nil, err + } // TODO: Cache HTTP client? httpC := httpw.NewClient(a.oauth.NewHTTPClient()) - return httpC.Do(ctx, method, u.String(), opts) + return httpC.Do(ctx, method, u, opts) } func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - h, body, err := authorized(ctx, b, auth, method, endpoint, opts) + h, body, err := a.authorized(ctx, method, endpoint, opts) if err == nil { return h, body, nil } @@ -108,8 +132,8 @@ func (a *API) authorizedRetry(ctx context.Context, method string, endpoint strin h, body, err = a.authorized(ctx, method, endpoint, opts) } // Tokens is invalid we need to renew and authorize again - tErr := &oauth.TokensInvalidError{} - if err != nil && errors.As(err, &terr) { + tErr := &eduoauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { // Mark the token as invalid and retry, so we trigger the authorization flow a.oauth.SetTokenRenew() log.Logger.Debugf("the tokens were invalid, trying again...") @@ -130,18 +154,18 @@ func (a *API) Disconnect(ctx context.Context) error { func (a *API) Info(ctx context.Context) (*profiles.Info, error) { _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) if err != nil { - return nil, err + return nil, fmt.Errorf("failed API /info: %v", err) } p := profiles.Info{} - if err = json.Unmarshal(body, &profiles); err != nil { - return errors.New("failed API /info: %v", err) + if err = json.Unmarshal(body, &p); err != nil { + return nil, fmt.Errorf("failed API /info: %v", err) } return &p, nil } type ConnectData struct { Configuration string - Protocol server.Protocol + Protocol protocol.Protocol Expires time.Time } @@ -163,7 +187,6 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto acpt := []string{} uv := url.Values{ "profile_id": {prof.ID}, - "public_key": {pubkey}, } if len(protos) == 0 { @@ -181,7 +204,7 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto if pTCP && len(protos) > 1 { continue } - wgKey, err = wireguard.GenerateKey() + wgKey, err := wireguard.GenerateKey() if err != nil { return nil, err } @@ -191,7 +214,7 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto acpt = append(acpt, "application/x-wireguard-profile") case protocol.OpenVPN: // set prefer TCP - uv.Set("prefer_tcp", boolToYesNo(preferTCP)) + uv.Set("prefer_tcp", boolToYesNo(pTCP)) acpt = append(acpt, "application/x-openvpn-profile") default: return nil, errors.New("Unknown protocol supplied") @@ -202,14 +225,14 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto params := &httpw.OptionalParams{Headers: hdrs, Body: uv} h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) if err != nil { - return nil, errors.New("failed API /connect call: %v", err) + return nil, fmt.Errorf("failed API /connect call: %v", err) } // Parse expiry expH := h.Get("expires") expT, err := http.ParseTime(expH) if err != nil { - return nil, errors.New("failed parsing expiry time: %v", err) + return nil, fmt.Errorf("failed parsing expiry time: %v", err) } vpnCfg := string(body) @@ -221,41 +244,41 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto return nil, errors.New("The server sent us a WireGuard profile but the client does not accept WireGuard") } content = protocol.WireGuard - vpnCfg = ConfigAddKey(vpnCfg, *wgKey) + vpnCfg = wireguard.ConfigAddKey(vpnCfg, *wgKey) } - return &ConfigData{ - Configuration: VpnCfg, + return &ConnectData{ + Configuration: vpnCfg, Protocol: content, Expires: expT, }, nil } -func endpoints(url string) (*endpoints.Endpoints, error) { - uStr, err := httpw.JoinURLPath(a.URL, "/.well-known/vpn-user-portal") +func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) { + uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal") if err != nil { - return err + return nil, err } - httpC = httpw.NewClient() + httpC := httpw.NewClient(nil) _, body, err := httpC.Get(ctx, uStr) if err != nil { - return errors.New("failed getting server endpoints: %v", err) + return nil, fmt.Errorf("failed getting server endpoints: %v", err) } ep := endpoints.Endpoints{} if err = json.Unmarshal(body, &ep); err != nil { - return errors.New("failed getting server endpoints: %v", err) + return nil, fmt.Errorf("failed getting server endpoints: %v", err) } err = ep.Validate() if err != nil { - return err + return nil, err } return &ep, nil } -func refreshEndpoints(baseWKURL, baseWKAuthURL string) (*endpoints.List, *endpoints.List, error) { +func refreshEndpoints(ctx context.Context, baseWKURL, baseWKAuthURL string) (*endpoints.List, *endpoints.List, error) { // Get the endpoints - ep, err := endpoints(baseWKURL) + ep, err := getEndpoints(ctx, baseWKURL) if err != nil { return nil, nil, err } @@ -263,8 +286,8 @@ func refreshEndpoints(baseWKURL, baseWKAuthURL string) (*endpoints.List, *endpoi // This is a mess but we essentially have to instantiate different endpoints if the authorization base URL is different from the base portal URL // This happens with secure internet when the location is not equal to the home location var epauth *endpoints.Endpoints - if authurl != apiurl { - oep, err := endpoints(basekWKAuthURL) + if baseWKAuthURL != baseWKURL { + oep, err := getEndpoints(ctx, baseWKAuthURL) if err != nil { return nil, nil, err } @@ -272,15 +295,16 @@ func refreshEndpoints(baseWKURL, baseWKAuthURL string) (*endpoints.List, *endpoi } else { epauth = ep } - return ep, epauth, err + return &ep.API.V3, &epauth.API.V3, err } -func (api *api) RefreshEndpoints() error { - ep, epauth, err := refreshEndpoints(api.baseWKURL, api.baseWKAuthURL) +func (a *API) RefreshEndpoints(ctx context.Context) error { + ep, epauth, err := refreshEndpoints(ctx, a.baseWKURL, a.baseWKAuthURL) if err != nil { return err } - api.oauth.BaseAuthorizationURL = epauth.Authorization - api.oauth.TokenURL = epauth.Token - api.apiURL = ep.API + a.oauth.BaseAuthorizationURL = epauth.Authorization + a.oauth.TokenURL = epauth.Token + a.apiURL = ep.API + return nil } diff --git a/internal/api/profiles/profiles.go b/internal/api/profiles/profiles.go index b141037b..9db6c2b4 100644 --- a/internal/api/profiles/profiles.go +++ b/internal/api/profiles/profiles.go @@ -1,6 +1,8 @@ package profiles -import "github.com/eduvpn/eduvpn-common/types/protocol" +import ( + "github.com/eduvpn/eduvpn-common/types/protocol" +) type Profile struct { ID string `json:"profile_id"` @@ -43,18 +45,18 @@ func hasProtocol(protos []string, proto protocol.Protocol) bool { return false } -func HasOpenVPN(protos []string) bool { - return hasProtocol(protos, protocol.OpenVPN) +func (p *Profile) HasOpenVPN() bool { + return hasProtocol(p.VPNProtoList, protocol.OpenVPN) } -func HasWireGuard(protos []string) bool { - return hasProtocol(protos, protocol.OpenVPN) +func (p *Profile) HasWireGuard() bool { + return hasProtocol(p.VPNProtoList, protocol.WireGuard) } func (i Info) FilterWireGuard() Info { var ret []Profile for _, p := range i.Info.ProfileList { - if !HasOpenVPN(p.VPNProtoList) { + if !p.HasOpenVPN() { continue } } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 06548f90..b7e7cad3 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -35,7 +35,7 @@ var DiscoURL = "https://disco.eduvpn.org/v2/" func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, structure interface{}) error { // No HTTP client present, create one if discovery.httpClient == nil { - discovery.httpClient = http.NewClient() + discovery.httpClient = http.NewClient(nil) } // Get json data diff --git a/internal/http/http.go b/internal/http/http.go index 7a769f09..b266fea7 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -146,8 +146,11 @@ type Client struct { } // Returns a HTTP client with some default settings -func NewClient() *Client { - c := &http.Client{} +func NewClient(client *http.Client) *Client { + c := client + if c == nil { + c = &http.Client{} + } // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses // This is used to prevent servers from sending huge amounts of data // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems diff --git a/internal/server/base.go b/internal/server/base.go index 73e29510..888c823a 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -1,10 +1,11 @@ package server -import ( +import ( + "os" + "context" + "errors" "github.com/eduvpn/eduvpn-common/internal/api" "github.com/eduvpn/eduvpn-common/internal/api/profiles" - v2cfg "github.com/eduvpn/eduvpn-common/internal/config/v2" - "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/types/protocol" ) @@ -14,6 +15,13 @@ type Base struct { Profile *profiles.Profile } +func (b *Base) ProfileID() string { + if b.Profile == nil { + return "" + } + return b.Profile.ID +} + var InvalidProfileErr = errors.New("invalid profile") // Profiles gets the profiles for the server @@ -21,26 +29,23 @@ var InvalidProfileErr = errors.New("invalid profile") // force indicates whether or not the profiles should be fetched fresh func (b *Base) Profiles(ctx context.Context, force bool) (*profiles.Info, error) { // If we have a cached copy we only return that if force is false - if s.CachedInfo != nil && !force { - return s.CachedInfo, nil + if b.CachedInfo != nil && !force { + return b.CachedInfo, nil + } + a, err := b.API() + if err != nil { + return nil, err } // Otherwise get fresh profiles and set the cache - prfs, err := s.apiw.Info(ctx) + prfs, err := a.Info(ctx) if err != nil { return nil, err } - s.CachedInfo = prfs + b.CachedInfo = prfs return prfs, nil } -func (b *Base) Profile() (*profiles.Profile, error) { - if s.Profile == nil { - return nil, InvalidProfileErr - } - return s.Profile, nil -} - func (b *Base) API() (*api.API, error) { if b.apiw == nil { return nil, errors.New("no API object found") @@ -48,13 +53,15 @@ func (b *Base) API() (*api.API, error) { return b.apiw, nil } -func (b *Base) findProfile(ctx context.Context, prID string, wgSupport bool) (*profile.Profile, error) { +func (b *Base) findProfile(ctx context.Context, wgSupport bool) (*profiles.Profile, error) { // Get the profiles by ignoring the cache - prfs, err := s.Profiles(ctx, false) + prfs, err := b.Profiles(ctx, false) if err != nil { return nil, err } + prID := b.ProfileID() + // No profiles available if prfs.Len() == 0 { return nil, errors.New("the server has no available profiles for your account") @@ -62,7 +69,8 @@ func (b *Base) findProfile(ctx context.Context, prID string, wgSupport bool) (*p // No WireGuard support, we have to filter the profiles that only have WireGuard if !wgSupport { - prfs = prfs.FilterWireGuard(protos) + gprof := prfs.FilterWireGuard() + prfs = &gprof } var chosenP profiles.Profile @@ -91,17 +99,13 @@ func (b *Base) Connect(ctx context.Context, wgSupport bool, pTCP bool) (*api.Con if err != nil { return nil, err } - prof, err := s.Profile() - if err != nil { - return nil, err - } // find a suitable profile to connect - chosenP , err := s.findProfile(ctx, s.Profile.ID, wgSupport) + chosenP , err := b.findProfile(ctx, wgSupport) if err != nil { return nil, err } - s.Profile = chosenP + b.Profile = chosenP protos := []protocol.Protocol{protocol.OpenVPN} if wgSupport { @@ -110,12 +114,13 @@ func (b *Base) Connect(ctx context.Context, wgSupport bool, pTCP bool) (*api.Con // If the client supports WireGuard and the profile supports both protocols we remove openvpn from client support if EDUVPN_PREFER_WG is set to "1" // This also only happens if prefer TCP is set to false // TODO: remove the prefer TCP check when we have implemented proxyguard - if wgSupport && os.Getenv("EDUVPN_PREFER_WG", "0") == "1" { + if wgSupport && os.Getenv("EDUVPN_PREFER_WG") == "1" { if !pTCP && chosenP.HasWireGuard() && chosenP.HasOpenVPN() { - protos = []protocol.Protocol{Protocol.WireGuard} + protos = []protocol.Protocol{protocol.WireGuard} } } - return a.Connect(ctx, chosenP, protos, pTCP) + // SAFETY: chosenP is guaranteed to be non-nil + return a.Connect(ctx, *chosenP, protos, pTCP) } func (b *Base) Disconnect(ctx context.Context) error { diff --git a/internal/server/custom.go b/internal/server/custom.go index d68084ce..77342167 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -1,12 +1,17 @@ package server -import "github.com/eduvpn/eduvpn-common/types/server" +import ( + "context" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/eduvpn/eduvpn-common/internal/api" + httpw "github.com/eduvpn/eduvpn-common/internal/http" +) type CustomServer struct { Base } -func (cs *CustomServer) Base() Base { +func (cs *CustomServer) GetBase() Base { return cs.Base } @@ -14,9 +19,9 @@ func (cs *CustomServer) Type() server.Type { return server.TypeCustom } -func (s *Servers) AddCustom(ctx context.Context, url string, na bool) (*CustomServer, error) { +func (s *Servers) AddCustom(ctx context.Context, url string, na bool) (Server, error) { // Convert to an identifier - id, err := http.EnsureValidURL(url, true) + id, err := httpw.EnsureValidURL(url, true) if err != nil { return nil, err } @@ -24,7 +29,7 @@ func (s *Servers) AddCustom(ctx context.Context, url string, na bool) (*CustomSe var a *api.API if !na { // Authorize by creating the API object - a, err = api.NewAPI(s.clientID, url, url, s.authTrigger, s.authDone, nil) + a, err = api.NewAPI(ctx, s.clientID, id, id, s.authTrigger, s.authDone, nil) if err != nil { return nil, err } @@ -33,7 +38,7 @@ func (s *Servers) AddCustom(ctx context.Context, url string, na bool) (*CustomSe // Return the server with the API return &CustomServer{ Base: Base{ - apiw: a,, + apiw: a, }, }, nil } diff --git a/internal/server/institute.go b/internal/server/institute.go index 4056b349..25aa3480 100644 --- a/internal/server/institute.go +++ b/internal/server/institute.go @@ -1,7 +1,10 @@ package server import ( + "context" + "github.com/eduvpn/eduvpn-common/internal/api" "github.com/eduvpn/eduvpn-common/internal/discovery" + httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/types/server" ) @@ -9,24 +12,26 @@ type InstituteAccess struct { Base } -func (ia *InstituteAccess) Base() Base { - return cs.Base +func (ia *InstituteAccess) GetBase() Base { + return ia.Base } func (cs *InstituteAccess) Type() server.Type { return server.TypeInstituteAccess } +// TODO: is this needed func (cs *InstituteAccess) BaseURL() (string, error) { - a err := cs.API() + a, err := cs.API() if err != nil { return "", err } + return a.BaseURL(), nil } -func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, url string, na bool) (*Server, error) { +func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, url string, na bool) (Server, error) { // Convert to an identifier - id, err := http.EnsureValidURL(url, true) + id, err := httpw.EnsureValidURL(url, true) if err != nil { return nil, err } @@ -40,7 +45,7 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, var a *api.API if !na { // Authorize by creating the API object - a, err = api.NewAPI(s.clientID, dsrv.BaseURL, dsrv.BaseURL, s.authTrigger, s.authDone, nil) + a, err = api.NewAPI(ctx, s.clientID, dsrv.BaseURL, dsrv.BaseURL, s.authTrigger, s.authDone, nil) if err != nil { return nil, err } @@ -51,8 +56,5 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, Base: Base{ apiw: a, }, - DisplayName: dsrv.SupportContact, - SupportContact: dsrv.SupportContact, - }, nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 48b49ff0..f995e277 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -1,6 +1,7 @@ package server import ( + "context" "github.com/eduvpn/eduvpn-common/internal/api" "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/types/server" @@ -12,7 +13,7 @@ type SecureInternet struct { Location string } -func (si *SecureInternet) Base() Base { +func (si *SecureInternet) GetBase() Base { return si.Base } @@ -30,7 +31,7 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org var a *api.API if !na { // Authorize by creating the API object - a, err = api.NewAPI(s.clientID, dsrv.BaseURL, dsrv.BaseURL, s.authTrigger, s.authDone, nil) + a, err = api.NewAPI(ctx, s.clientID, dsrv.BaseURL, dsrv.BaseURL, s.authTrigger, s.authDone, nil) if err != nil { return nil, err } @@ -38,7 +39,9 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org // Return the server with the API return &SecureInternet{ - apiw: a, + Base: Base{ + apiw: a, + }, HomeORGID: dOrg.OrgID, Location: dsrv.CountryCode, }, nil diff --git a/internal/server/server.go b/internal/server/server.go index f16784d2..460d8349 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,28 +2,22 @@ package server import ( "context" - "os" "github.com/eduvpn/eduvpn-common/internal/api" - "github.com/eduvpn/eduvpn-common/internal/api/profiles" - "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/types/protocol" ) type Servers struct { clientID string authDone func() - authTrigger func(string) error + authTrigger func(string) (string, error) wgSupport bool - - Current *Server } type Server interface { - Base() Base + GetBase() Base } -func NewServers(name string, autht func(string) error, authd func(), wgSupport bool) *Servers { +func NewServers(name string, autht func(string) (string, error), authd func(), wgSupport bool) *Servers { return &Servers{ clientID: name, authDone: authd, @@ -33,9 +27,11 @@ func NewServers(name string, autht func(string) error, authd func(), wgSupport b } func (s *Servers) Connect(ctx context.Context, srv Server, pTCP bool) (*api.ConnectData, error) { - return srv.Base().Connect(ctx, s.wgSupport, pTCP) + b := srv.GetBase() + return b.Connect(ctx, s.wgSupport, pTCP) } func (s *Servers) Disconnect(ctx context.Context, srv Server) error { - return srv.Base().Disconnect(ctx) + b := srv.GetBase() + return b.Disconnect(ctx) } diff --git a/internal/test/server.go b/internal/test/server.go index 841b5651..9baba55d 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -34,12 +34,14 @@ func (srv *Server) Client() (*httpw.Client, error) { certs.AddCert(root) } } - // Override the client such that it only trusts the test server cert - client := httpw.NewClient() - client.Client.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{ - RootCAs: certs, + client := &http.Client{ + Transport: &http.Transport{ + TlsClientConfig: &tls.Config{ + RottCAs: certs, + }, }, } - return client, nil + // Override the client such that it only trusts the test server cert + httpC := httpw.NewClient(client) + return httpC, nil }