Skip to content
This repository has been archived by the owner on Nov 29, 2024. It is now read-only.

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jwijenbergh committed Nov 10, 2023
1 parent eea9f50 commit cd2b42f
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 122 deletions.
128 changes: 76 additions & 52 deletions internal/api/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package api

import (
"errors"
"fmt"
"context"
"encoding/json"
"net/http"
"net/url"
"time"
Expand All @@ -11,13 +15,14 @@ import (
"github.com/eduvpn/eduvpn-common/internal/api/endpoints"
"github.com/eduvpn/eduvpn-common/internal/api/profiles"
"github.com/eduvpn/eduvpn-common/internal/log"
httpw "github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/wireguard"
"github.com/eduvpn/eduvpn-common/types/protocol"
)

type API struct {
// authf is the function to retrigger authorization
authf func(string) error
// autht is the function to retrigger authorization
autht func(string) (string, error)
// authd is the function that triggers when authorization is done
authd func()
// oauth is the oauth object
Expand All @@ -30,70 +35,89 @@ type API struct {
baseWKAuthURL string
}

func (a *API) BaseURL() string {
return a.baseWKURL
}

// NewAPI creates a new API object by creating an OAuth object
func NewAPI(clientID string, baseWKURL string, baseWKAuthURL string, authf func(string) error, authd func(), tokens *eudoauth.Token) (*API, error) {
ep, epauth, err := refreshEndpoints(baseWKURL, baseWKAuthURL)
// TODO: make this mess shorter
func NewAPI(ctx context.Context, clientID string, baseWKURL string, baseWKAuthURL string, autht func(string) (string, error), authd func(), tokens *eduoauth.Token) (*API, error) {

Check failure on line 44 in internal/api/api.go

View workflow job for this annotation

GitHub Actions / Lint go

undefined: eduoauth (typecheck)
ep, epauth, err := refreshEndpoints(ctx, baseWKURL, baseWKAuthURL)
if err != nil {
return nil, err
}

cr := customRedirect(clientID)

Check failure on line 50 in internal/api/api.go

View workflow job for this annotation

GitHub Actions / Lint go

undefined: customRedirect (typecheck)
// Construct OAuth
// TODO: Support mobile redirect
o := eduoauth.OAuth{

Check failure on line 53 in internal/api/api.go

View workflow job for this annotation

GitHub Actions / Lint go

undefined: eduoauth (typecheck)
ClientID: clientID,
BaseAuthorizationURL: epauth.Authorization,
TokenURL: epauth.Token,
CustomRedirect: cr,
RedirectPath: "/callback",
}

if t != nil {
o.UpdateTokens(t)
if tokens != nil {
o.UpdateTokens(*tokens)
}

api := &API{
authf: authf,
autht: autht,
authd: authd,
oauthC: o,
URL: ep.API,
oauth: &o,
apiURL: ep.API,
baseWKURL: baseWKURL,
baseWKAuthURL: baseWKAuthURL,
}
err := api.authorize(ctx)
err = api.authorize(ctx)
if err != nil {
return nil, err
}
return api, nil
}

func (a *API) authorize(ctx context.Context) error {
func (a *API) authorize(ctx context.Context) (err error) {
defer func() {
if err == nil {
a.authd()
}
}()
_, err = a.oauth.AccessToken(ctx)
// already authorized
if err == nil {
return nil
}
scope := "config"
url, err := a.oauth.AuthURL(scope)
if err != nil {
return nil, err
return err
}
err := a.authf(url)
uri, err := a.autht(url)
if err != nil {
return nil, err
return err
}
err := o.Exchange(ctx)
// The uri is only given here if a custom redirect is done
err = a.oauth.Exchange(ctx, uri)
if err != nil {
return nil, err
return err
}
a.authd()
return nil
}

func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
u := a.URL
u.Path = path.Join(u.Path, endpoint)
u, err := httpw.JoinURLPath(a.apiURL, endpoint)
if err != nil {
return nil, nil, err
}

// TODO: Cache HTTP client?
httpC := httpw.NewClient(a.oauth.NewHTTPClient())
return httpC.Do(ctx, method, u.String(), opts)
return httpC.Do(ctx, method, u, opts)
}

func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
h, body, err := authorized(ctx, b, auth, method, endpoint, opts)
h, body, err := a.authorized(ctx, method, endpoint, opts)
if err == nil {
return h, body, nil
}
Expand All @@ -108,8 +132,8 @@ func (a *API) authorizedRetry(ctx context.Context, method string, endpoint strin
h, body, err = a.authorized(ctx, method, endpoint, opts)
}
// Tokens is invalid we need to renew and authorize again
tErr := &oauth.TokensInvalidError{}
if err != nil && errors.As(err, &terr) {
tErr := &eduoauth.TokensInvalidError{}
if err != nil && errors.As(err, &tErr) {
// Mark the token as invalid and retry, so we trigger the authorization flow
a.oauth.SetTokenRenew()
log.Logger.Debugf("the tokens were invalid, trying again...")
Expand All @@ -130,18 +154,18 @@ func (a *API) Disconnect(ctx context.Context) error {
func (a *API) Info(ctx context.Context) (*profiles.Info, error) {
_, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed API /info: %v", err)
}
p := profiles.Info{}
if err = json.Unmarshal(body, &profiles); err != nil {
return errors.New("failed API /info: %v", err)
if err = json.Unmarshal(body, &p); err != nil {
return nil, fmt.Errorf("failed API /info: %v", err)
}
return &p, nil
}

type ConnectData struct {
Configuration string
Protocol server.Protocol
Protocol protocol.Protocol
Expires time.Time
}

Expand All @@ -163,7 +187,6 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
acpt := []string{}
uv := url.Values{
"profile_id": {prof.ID},
"public_key": {pubkey},
}

if len(protos) == 0 {
Expand All @@ -181,7 +204,7 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
if pTCP && len(protos) > 1 {
continue
}
wgKey, err = wireguard.GenerateKey()
wgKey, err := wireguard.GenerateKey()
if err != nil {
return nil, err
}
Expand All @@ -191,7 +214,7 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
acpt = append(acpt, "application/x-wireguard-profile")
case protocol.OpenVPN:
// set prefer TCP
uv.Set("prefer_tcp", boolToYesNo(preferTCP))
uv.Set("prefer_tcp", boolToYesNo(pTCP))
acpt = append(acpt, "application/x-openvpn-profile")
default:
return nil, errors.New("Unknown protocol supplied")
Expand All @@ -202,14 +225,14 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
params := &httpw.OptionalParams{Headers: hdrs, Body: uv}
h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params)
if err != nil {
return nil, errors.New("failed API /connect call: %v", err)
return nil, fmt.Errorf("failed API /connect call: %v", err)
}

// Parse expiry
expH := h.Get("expires")
expT, err := http.ParseTime(expH)
if err != nil {
return nil, errors.New("failed parsing expiry time: %v", err)
return nil, fmt.Errorf("failed parsing expiry time: %v", err)
}

vpnCfg := string(body)
Expand All @@ -221,66 +244,67 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
return nil, errors.New("The server sent us a WireGuard profile but the client does not accept WireGuard")
}
content = protocol.WireGuard
vpnCfg = ConfigAddKey(vpnCfg, *wgKey)
vpnCfg = wireguard.ConfigAddKey(vpnCfg, *wgKey)
}

return &ConfigData{
Configuration: VpnCfg,
return &ConnectData{
Configuration: vpnCfg,
Protocol: content,
Expires: expT,
}, nil
}

func endpoints(url string) (*endpoints.Endpoints, error) {
uStr, err := httpw.JoinURLPath(a.URL, "/.well-known/vpn-user-portal")
func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) {
uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal")
if err != nil {
return err
return nil, err
}
httpC = httpw.NewClient()
httpC := httpw.NewClient(nil)
_, body, err := httpC.Get(ctx, uStr)
if err != nil {
return errors.New("failed getting server endpoints: %v", err)
return nil, fmt.Errorf("failed getting server endpoints: %v", err)
}

ep := endpoints.Endpoints{}
if err = json.Unmarshal(body, &ep); err != nil {
return errors.New("failed getting server endpoints: %v", err)
return nil, fmt.Errorf("failed getting server endpoints: %v", err)
}
err = ep.Validate()
if err != nil {
return err
return nil, err
}
return &ep, nil
}

func refreshEndpoints(baseWKURL, baseWKAuthURL string) (*endpoints.List, *endpoints.List, error) {
func refreshEndpoints(ctx context.Context, baseWKURL, baseWKAuthURL string) (*endpoints.List, *endpoints.List, error) {
// Get the endpoints
ep, err := endpoints(baseWKURL)
ep, err := getEndpoints(ctx, baseWKURL)
if err != nil {
return nil, nil, err
}

// This is a mess but we essentially have to instantiate different endpoints if the authorization base URL is different from the base portal URL
// This happens with secure internet when the location is not equal to the home location
var epauth *endpoints.Endpoints
if authurl != apiurl {
oep, err := endpoints(basekWKAuthURL)
if baseWKAuthURL != baseWKURL {
oep, err := getEndpoints(ctx, baseWKAuthURL)
if err != nil {
return nil, nil, err
}
epauth = oep
} else {
epauth = ep
}
return ep, epauth, err
return &ep.API.V3, &epauth.API.V3, err
}

func (api *api) RefreshEndpoints() error {
ep, epauth, err := refreshEndpoints(api.baseWKURL, api.baseWKAuthURL)
func (a *API) RefreshEndpoints(ctx context.Context) error {
ep, epauth, err := refreshEndpoints(ctx, a.baseWKURL, a.baseWKAuthURL)
if err != nil {
return err
}
api.oauth.BaseAuthorizationURL = epauth.Authorization
api.oauth.TokenURL = epauth.Token
api.apiURL = ep.API
a.oauth.BaseAuthorizationURL = epauth.Authorization
a.oauth.TokenURL = epauth.Token
a.apiURL = ep.API
return nil
}
14 changes: 8 additions & 6 deletions internal/api/profiles/profiles.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package profiles

import "github.com/eduvpn/eduvpn-common/types/protocol"
import (
"github.com/eduvpn/eduvpn-common/types/protocol"
)

type Profile struct {
ID string `json:"profile_id"`
Expand Down Expand Up @@ -43,18 +45,18 @@ func hasProtocol(protos []string, proto protocol.Protocol) bool {
return false
}

func HasOpenVPN(protos []string) bool {
return hasProtocol(protos, protocol.OpenVPN)
func (p *Profile) HasOpenVPN() bool {
return hasProtocol(p.VPNProtoList, protocol.OpenVPN)
}

func HasWireGuard(protos []string) bool {
return hasProtocol(protos, protocol.OpenVPN)
func (p *Profile) HasWireGuard() bool {
return hasProtocol(p.VPNProtoList, protocol.WireGuard)
}

func (i Info) FilterWireGuard() Info {
var ret []Profile
for _, p := range i.Info.ProfileList {
if !HasOpenVPN(p.VPNProtoList) {
if !p.HasOpenVPN() {
continue
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var DiscoURL = "https://disco.eduvpn.org/v2/"
func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, structure interface{}) error {
// No HTTP client present, create one
if discovery.httpClient == nil {
discovery.httpClient = http.NewClient()
discovery.httpClient = http.NewClient(nil)
}

// Get json data
Expand Down
7 changes: 5 additions & 2 deletions internal/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,11 @@ type Client struct {
}

// Returns a HTTP client with some default settings
func NewClient() *Client {
c := &http.Client{}
func NewClient(client *http.Client) *Client {
c := client
if c == nil {
c = &http.Client{}
}
// ReadLimit denotes the maximum amount of bytes that are read in HTTP responses
// This is used to prevent servers from sending huge amounts of data
// A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems
Expand Down
Loading

0 comments on commit cd2b42f

Please sign in to comment.