From 186ab738e2f9c7c3613330d45e44848186958815 Mon Sep 17 00:00:00 2001 From: "Eliott C." Date: Sat, 5 Oct 2024 13:15:18 +0200 Subject: [PATCH] checkRepoAccess added (#947) More helpers for https://github.com/huggingface/huggingface.js/issues/945 Introduces `checkRepoAccess` to check if user has read access to repo --- packages/hub/README.md | 23 +++++++------ .../hub/src/lib/check-repo-access.spec.ts | 34 +++++++++++++++++++ packages/hub/src/lib/check-repo-access.ts | 32 +++++++++++++++++ packages/hub/src/lib/dataset-info.ts | 6 +--- packages/hub/src/lib/index.ts | 1 + packages/hub/src/lib/model-info.ts | 6 +--- packages/hub/src/lib/space-info.ts | 6 +--- 7 files changed, 82 insertions(+), 26 deletions(-) create mode 100644 packages/hub/src/lib/check-repo-access.spec.ts create mode 100644 packages/hub/src/lib/check-repo-access.ts diff --git a/packages/hub/README.md b/packages/hub/README.md index 6056dd638..c9a650a5c 100644 --- a/packages/hub/README.md +++ b/packages/hub/README.md @@ -30,22 +30,23 @@ For some of the calls, you need to create an account and generate an [access tok Learn how to find free models using the hub package in this [interactive tutorial](https://scrimba.com/scrim/c7BbVPcd?pl=pkVnrP7uP). ```ts -import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI, modelInfo, listModels } from "@huggingface/hub"; +import * as hub from "@huggingface/hub"; import type { RepoDesignation } from "@huggingface/hub"; const repo: RepoDesignation = { type: "model", name: "myname/some-model" }; -const {name: username} = await whoAmI({accessToken: "hf_..."}); +const {name: username} = await hub.whoAmI({accessToken: "hf_..."}); -for await (const model of listModels({search: {owner: username}, accessToken: "hf_..."})) { +for await (const model of hub.listModels({search: {owner: username}, accessToken: "hf_..."})) { console.log("My model:", model); } -const specificModel = await modelInfo({name: "openai-community/gpt2"}); +const specificModel = await hub.modelInfo({name: "openai-community/gpt2"}); +await hub.checkRepoAccess({repo, accessToken: "hf_..."}); -await createRepo({ repo, accessToken: "hf_...", license: "mit" }); +await hub.createRepo({ repo, accessToken: "hf_...", license: "mit" }); -await uploadFiles({ +await hub.uploadFiles({ repo, accessToken: "hf_...", files: [ @@ -69,7 +70,7 @@ await uploadFiles({ // or -for await (const progressEvent of await uploadFilesWithProgress({ +for await (const progressEvent of await hub.uploadFilesWithProgress({ repo, accessToken: "hf_...", files: [ @@ -79,15 +80,15 @@ for await (const progressEvent of await uploadFilesWithProgress({ console.log(progressEvent); } -await deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"}); +await hub.deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"}); -await (await downloadFile({ repo, path: "README.md" })).text(); +await (await hub.downloadFile({ repo, path: "README.md" })).text(); -for await (const fileInfo of listFiles({repo})) { +for await (const fileInfo of hub.listFiles({repo})) { console.log(fileInfo); } -await deleteRepo({ repo, accessToken: "hf_..." }); +await hub.deleteRepo({ repo, accessToken: "hf_..." }); ``` ## OAuth Login diff --git a/packages/hub/src/lib/check-repo-access.spec.ts b/packages/hub/src/lib/check-repo-access.spec.ts new file mode 100644 index 000000000..12ad5cd92 --- /dev/null +++ b/packages/hub/src/lib/check-repo-access.spec.ts @@ -0,0 +1,34 @@ +import { assert, describe, expect, it } from "vitest"; +import { checkRepoAccess } from "./check-repo-access"; +import { HubApiError } from "../error"; +import { TEST_ACCESS_TOKEN, TEST_HUB_URL } from "../test/consts"; + +describe("checkRepoAccess", () => { + it("should throw 401 when accessing unexisting repo unauthenticated", async () => { + try { + await checkRepoAccess({ repo: { name: "i--d/dont", type: "model" } }); + assert(false, "should have thrown"); + } catch (err) { + expect(err).toBeInstanceOf(HubApiError); + expect((err as HubApiError).statusCode).toBe(401); + } + }); + + it("should throw 404 when accessing unexisting repo authenticated", async () => { + try { + await checkRepoAccess({ + repo: { name: "i--d/dont", type: "model" }, + hubUrl: TEST_HUB_URL, + accessToken: TEST_ACCESS_TOKEN, + }); + assert(false, "should have thrown"); + } catch (err) { + expect(err).toBeInstanceOf(HubApiError); + expect((err as HubApiError).statusCode).toBe(404); + } + }); + + it("should not throw when accessing public repo", async () => { + await checkRepoAccess({ repo: { name: "openai-community/gpt2", type: "model" } }); + }); +}); diff --git a/packages/hub/src/lib/check-repo-access.ts b/packages/hub/src/lib/check-repo-access.ts new file mode 100644 index 000000000..3107c9bd7 --- /dev/null +++ b/packages/hub/src/lib/check-repo-access.ts @@ -0,0 +1,32 @@ +import { HUB_URL } from "../consts"; +// eslint-disable-next-line @typescript-eslint/no-unused-vars +import { createApiError, type HubApiError } from "../error"; +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; + +/** + * Check if we have read access to a repository. + * + * Throw a {@link HubApiError} error if we don't have access. HubApiError.statusCode will be 401, 403 or 404. + */ +export async function checkRepoAccess( + params: { + repo: RepoDesignation; + hubUrl?: string; + fetch?: typeof fetch; + } & Partial +): Promise { + const accessToken = params && checkCredentials(params); + const repoId = toRepoId(params.repo); + + const response = await (params.fetch || fetch)(`${params?.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}`, { + headers: { + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}), + }, + }); + + if (!response.ok) { + throw await createApiError(response); + } +} diff --git a/packages/hub/src/lib/dataset-info.ts b/packages/hub/src/lib/dataset-info.ts index a18022b10..42f479d04 100644 --- a/packages/hub/src/lib/dataset-info.ts +++ b/packages/hub/src/lib/dataset-info.ts @@ -13,10 +13,6 @@ export async function datasetInfo< name: string; hubUrl?: string; additionalFields?: T[]; - /** - * Set to limit the number of models returned. - */ - limit?: number; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ @@ -41,7 +37,7 @@ export async function datasetInfo< ); if (!response.ok) { - createApiError(response); + throw await createApiError(response); } const data = await response.json(); diff --git a/packages/hub/src/lib/index.ts b/packages/hub/src/lib/index.ts index 070667ef3..603674ea0 100644 --- a/packages/hub/src/lib/index.ts +++ b/packages/hub/src/lib/index.ts @@ -1,4 +1,5 @@ export * from "./cache-management"; +export * from "./check-repo-access"; export * from "./commit"; export * from "./count-commits"; export * from "./create-repo"; diff --git a/packages/hub/src/lib/model-info.ts b/packages/hub/src/lib/model-info.ts index e551d9389..2744905c0 100644 --- a/packages/hub/src/lib/model-info.ts +++ b/packages/hub/src/lib/model-info.ts @@ -13,10 +13,6 @@ export async function modelInfo< name: string; hubUrl?: string; additionalFields?: T[]; - /** - * Set to limit the number of models returned. - */ - limit?: number; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ @@ -41,7 +37,7 @@ export async function modelInfo< ); if (!response.ok) { - createApiError(response); + throw await createApiError(response); } const data = await response.json(); diff --git a/packages/hub/src/lib/space-info.ts b/packages/hub/src/lib/space-info.ts index d6dcefc18..a1b5516a6 100644 --- a/packages/hub/src/lib/space-info.ts +++ b/packages/hub/src/lib/space-info.ts @@ -14,10 +14,6 @@ export async function spaceInfo< name: string; hubUrl?: string; additionalFields?: T[]; - /** - * Set to limit the number of models returned. - */ - limit?: number; /** * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. */ @@ -42,7 +38,7 @@ export async function spaceInfo< ); if (!response.ok) { - createApiError(response); + throw await createApiError(response); } const data = await response.json();