diff --git a/internal/entities/trakt.go b/internal/entities/trakt.go index afa2735..53dd3a6 100644 --- a/internal/entities/trakt.go +++ b/internal/entities/trakt.go @@ -123,3 +123,12 @@ type TraktList struct { ListItems TraktItems IsWatchlist bool } + +type TraktUserInfo struct { + Username string `json:"username"` + Private bool `json:"private"` + Name string `json:"name"` + Vip bool `json:"vip"` + VipEp bool `json:"vip_ep"` + IDMeta TraktIDMeta `json:"ids"` +} diff --git a/pkg/client/client.go b/pkg/client/client.go index 4bf2e25..f91d761 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -42,6 +42,7 @@ type TraktClientInterface interface { HistoryGet(itemType, itemID string) (entities.TraktItems, error) HistoryAdd(items entities.TraktItems) error HistoryRemove(items entities.TraktItems) error + UserInfoGet() (*entities.TraktUserInfo, error) } type requestFields struct { @@ -77,6 +78,18 @@ func (r reusableReader) Read(p []byte) (int, error) { return n, err } +func selectorExists(body io.ReadCloser, selector string) error { + defer body.Close() + doc, err := goquery.NewDocumentFromReader(body) + if err != nil { + return fmt.Errorf("failure creating goquery document from response: %w", err) + } + if doc.Find(selector).Length() == 0 { + return fmt.Errorf("failure finding selector %s", selector) + } + return nil +} + func selectorAttributeScrape(body io.ReadCloser, selector, attribute string) (*string, error) { defer body.Close() doc, err := goquery.NewDocumentFromReader(body) diff --git a/pkg/client/trakt.go b/pkg/client/trakt.go index 1379081..25385bb 100644 --- a/pkg/client/trakt.go +++ b/pkg/client/trakt.go @@ -46,6 +46,7 @@ const ( traktPathHistoryRemove = "/sync/history/remove" traktPathRatings = "/sync/ratings" traktPathRatingsRemove = "/sync/ratings/remove" + traktPathUserInfo = "/users/me" traktPathUserList = "/users/%s/lists/%s" traktPathUserListItems = "/users/%s/lists/%s/items" traktPathUserListItemsRemove = "/users/%s/lists/%s/items/remove" @@ -115,6 +116,11 @@ func (tc *TraktClient) hydrate() error { return fmt.Errorf("failure exchanging trakt device code for access token: %w", err) } tc.config.accessToken = authTokens.AccessToken + userInfo, err := tc.UserInfoGet() + if err != nil { + return fmt.Errorf("failure getting trakt user info: %w", err) + } + tc.config.username = userInfo.Username return nil } @@ -208,16 +214,7 @@ func (tc *TraktClient) ActivateAuthorize(authenticityToken string) error { if err != nil { return err } - value, err := selectorAttributeScrape(response.Body, "a.visible-xs", "href") - if err != nil { - return err - } - hrefPieces := strings.Split(*value, "/") - if len(hrefPieces) != 3 { - return fmt.Errorf("failure scraping trakt username") - } - tc.config.username = hrefPieces[2] - return nil + return selectorExists(response.Body, "a[href='/logout']") } func (tc *TraktClient) GetAccessToken(deviceCode string) (*entities.TraktAuthTokensResponse, error) { @@ -330,6 +327,20 @@ func (tc *TraktClient) doRequest(requestFields requestFields) (*http.Response, e return nil, fmt.Errorf("reached max retry attempts for %s %s", request.Method, request.URL) } +func (tc *TraktClient) UserInfoGet() (*entities.TraktUserInfo, error) { + response, err := tc.doRequest(requestFields{ + Method: http.MethodGet, + BasePath: traktPathBaseAPI, + Endpoint: traktPathUserInfo, + Body: http.NoBody, + Headers: tc.defaultApiHeaders(), + }) + if err != nil { + return nil, err + } + return decodeReader[*entities.TraktUserInfo](response.Body) +} + func (tc *TraktClient) WatchlistGet() (*entities.TraktList, error) { response, err := tc.doRequest(requestFields{ Method: http.MethodGet, diff --git a/pkg/client/trakt_test.go b/pkg/client/trakt_test.go index e38063f..0fec9eb 100644 --- a/pkg/client/trakt_test.go +++ b/pkg/client/trakt_test.go @@ -1804,7 +1804,7 @@ func TestTraktClient_ActivateAuthorize(t *testing.T) { httpmock.RegisterResponder( http.MethodPost, traktPathBaseBrowser+traktPathActivateAuthorize, - httpmock.NewStringResponder(http.StatusOK, `Profile`), + httpmock.NewStringResponder(http.StatusOK, `Sign Out`), ) }, assertions: func(assertions *assert.Assertions, err error) { @@ -1831,7 +1831,7 @@ func TestTraktClient_ActivateAuthorize(t *testing.T) { }, }, { - name: "failure scraping username", + name: "failure finding logout selector", args: args{ authenticityToken: dummyAuthenticityToken, }, @@ -1839,29 +1839,11 @@ func TestTraktClient_ActivateAuthorize(t *testing.T) { httpmock.RegisterResponder( http.MethodPost, traktPathBaseBrowser+traktPathActivateAuthorize, - httpmock.NewJsonResponderOrPanic(http.StatusOK, nil), - ) - }, - assertions: func(assertions *assert.Assertions, err error) { - assertions.Error(err) - assertions.Contains(err.Error(), "failure scraping") - }, - }, - { - name: "failure parsing scrape result to username", - args: args{ - authenticityToken: dummyAuthenticityToken, - }, - requirements: func() { - httpmock.RegisterResponder( - http.MethodPost, - traktPathBaseBrowser+traktPathActivateAuthorize, - httpmock.NewStringResponder(http.StatusOK, `Profile`), + httpmock.NewStringResponder(http.StatusOK, ""), ) }, assertions: func(assertions *assert.Assertions, err error) { - assertions.Error(err) - assertions.Contains(err.Error(), "failure scraping") + assertions.Contains(err.Error(), "failure finding selector") }, }, } @@ -2027,13 +2009,18 @@ func TestTraktClient_Hydrate(t *testing.T) { httpmock.RegisterResponder( http.MethodPost, traktPathBaseBrowser+traktPathActivateAuthorize, - httpmock.NewStringResponder(http.StatusOK, `Profile`), + httpmock.NewStringResponder(http.StatusOK, `Sign Out`), ) httpmock.RegisterResponder( http.MethodPost, traktPathBaseAPI+traktPathAuthTokens, httpmock.NewStringResponder(http.StatusOK, `{"access_token":"access-token-value"}`), ) + httpmock.RegisterResponder( + http.MethodGet, + traktPathBaseAPI+traktPathUserInfo, + httpmock.NewStringResponder(http.StatusOK, `{"username":"cecobask"}`), + ) }, assertions: func(assertions *assert.Assertions, err error) { assertions.NoError(err) @@ -2229,7 +2216,7 @@ func TestTraktClient_Hydrate(t *testing.T) { httpmock.RegisterResponder( http.MethodPost, traktPathBaseBrowser+traktPathActivateAuthorize, - httpmock.NewStringResponder(http.StatusOK, `Profile`), + httpmock.NewStringResponder(http.StatusOK, `Sign Out`), ) httpmock.RegisterResponder( http.MethodPost, @@ -2242,6 +2229,55 @@ func TestTraktClient_Hydrate(t *testing.T) { assertions.Contains(err.Error(), "failure exchanging trakt device code for access token") }, }, + { + name: "failure getting user info", + requirements: func() { + httpmock.RegisterResponder( + http.MethodPost, + traktPathBaseAPI+traktPathAuthCodes, + httpmock.NewStringResponder(http.StatusOK, `{"device_code":"`+dummyDeviceCode+`","user_code":"`+dummyUserCode+`"}`), + ) + httpmock.RegisterResponder( + http.MethodGet, + traktPathBaseBrowser+traktPathAuthSignIn, + httpmock.NewStringResponder(http.StatusOK, `
`), + ) + httpmock.RegisterResponder( + http.MethodPost, + traktPathBaseBrowser+traktPathAuthSignIn, + httpmock.NewJsonResponderOrPanic(http.StatusOK, nil), + ) + httpmock.RegisterResponder( + http.MethodGet, + traktPathBaseBrowser+traktPathActivate, + httpmock.NewStringResponder(http.StatusOK, `
`), + ) + httpmock.RegisterResponder( + http.MethodPost, + traktPathBaseBrowser+traktPathActivate, + httpmock.NewStringResponder(http.StatusOK, `
`), + ) + httpmock.RegisterResponder( + http.MethodPost, + traktPathBaseBrowser+traktPathActivateAuthorize, + httpmock.NewStringResponder(http.StatusOK, `Sign Out`), + ) + httpmock.RegisterResponder( + http.MethodPost, + traktPathBaseAPI+traktPathAuthTokens, + httpmock.NewStringResponder(http.StatusOK, `{"access_token":"access-token-value"}`), + ) + httpmock.RegisterResponder( + http.MethodGet, + traktPathBaseAPI+traktPathUserInfo, + httpmock.NewJsonResponderOrPanic(http.StatusInternalServerError, nil), + ) + }, + assertions: func(assertions *assert.Assertions, err error) { + assertions.Error(err) + assertions.Contains(err.Error(), "failure getting trakt user info") + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {