diff --git a/packages/hub/src/lib/index.ts b/packages/hub/src/lib/index.ts index 603674ea0..b79385dc5 100644 --- a/packages/hub/src/lib/index.ts +++ b/packages/hub/src/lib/index.ts @@ -19,6 +19,7 @@ export * from "./model-info"; export * from "./oauth-handle-redirect"; export * from "./oauth-login-url"; export * from "./parse-safetensors-metadata"; +export * from "./paths-info"; export * from "./space-info"; export * from "./upload-file"; export * from "./upload-files"; diff --git a/packages/hub/src/lib/paths-info.spec.ts b/packages/hub/src/lib/paths-info.spec.ts new file mode 100644 index 000000000..994d623ae --- /dev/null +++ b/packages/hub/src/lib/paths-info.spec.ts @@ -0,0 +1,75 @@ +import { expect, it, describe } from "vitest"; +import type { CommitInfo, PathInfo, SecurityFileStatus } from "./paths-info"; +import { pathsInfo } from "./paths-info"; + +describe("pathsInfo", () => { + it("should fetch LFS path info", async () => { + const result: PathInfo[] = await pathsInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + paths: ["tf_model.h5"], + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + expect(result).toHaveLength(1); + + const modelPathInfo = result[0]; + expect(modelPathInfo.path).toBe('tf_model.h5'); + expect(modelPathInfo.type).toBe('file'); + // lfs pointer, therefore lfs should be defined + expect(modelPathInfo?.lfs).toBeDefined(); + expect(modelPathInfo?.lfs?.oid).toBe("a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"); + expect(modelPathInfo?.lfs?.size).toBe(536063208); + expect(modelPathInfo?.lfs?.pointerSize).toBe(134); + + // should not include expand info + expect(modelPathInfo.lastCommit).toBeUndefined(); + expect(modelPathInfo.securityFileStatus).toBeUndefined(); + }); + + it("expand parmas should fetch lastCommit and securityFileStatus", async () => { + const result: (PathInfo & { + lastCommit: CommitInfo, + securityFileStatus: SecurityFileStatus, + })[] = await pathsInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + paths: ["tf_model.h5"], + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + expand: true, // include + }); + + expect(result).toHaveLength(1); + + const modelPathInfo = result[0]; + + // should include expand info + expect(modelPathInfo.lastCommit).toBeDefined(); + expect(modelPathInfo.securityFileStatus).toBeDefined(); + + expect(modelPathInfo.lastCommit.id).toBe("dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7"); + expect(modelPathInfo.lastCommit.title).toBe("Update tf_model.h5"); + expect(modelPathInfo.lastCommit.date.getTime()).toBe(1569268124000); // 2019-09-23T19:48:44.000Z + }); + + it("non-LFS pointer should have lfs undefined", async () => { + const result: (PathInfo)[] = await pathsInfo({ + repo: { + name: "bert-base-uncased", + type: "model", + }, + paths: ["config.json"], + revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7", + }); + + expect(result).toHaveLength(1); + + const modelPathInfo = result[0]; + expect(modelPathInfo.path).toBe("config.json"); + expect(modelPathInfo.lfs).toBeUndefined(); + }); +}); diff --git a/packages/hub/src/lib/paths-info.ts b/packages/hub/src/lib/paths-info.ts new file mode 100644 index 000000000..4c9a1de20 --- /dev/null +++ b/packages/hub/src/lib/paths-info.ts @@ -0,0 +1,120 @@ +import type { CredentialsParams, RepoDesignation } from "../types/public"; +import { checkCredentials } from "../utils/checkCredentials"; +import { toRepoId } from "../utils/toRepoId"; +import { HUB_URL } from "../consts"; +import { createApiError } from "../error"; + +export interface LfsPathInfo { + "oid": string, + "size": number, + "pointerSize": number +} + +export interface CommitInfo { + "id": string, + "title": string, + "date": Date, +} + +export interface SecurityFileStatus { + "status": string, +} + +export interface PathInfo { + path: string, + type: string, + oid: string, + size: number, + /** + * Only defined when path is LFS pointer + */ + lfs?: LfsPathInfo, + lastCommit?: CommitInfo, + securityFileStatus?: SecurityFileStatus +} + +// Define the overloaded signatures +export function pathsInfo( + params: { + repo: RepoDesignation; + paths: string[]; + expand: true; // if expand true + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise<(PathInfo & {lastCommit: CommitInfo, securityFileStatus: SecurityFileStatus })[]>; +export function pathsInfo( + params: { + repo: RepoDesignation; + paths: string[]; + expand?: boolean; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise<(PathInfo)[]>; + +export async function pathsInfo( + params: { + repo: RepoDesignation; + paths: string[]; + expand?: boolean; + revision?: string; + hubUrl?: string; + /** + * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers. + */ + fetch?: typeof fetch; + } & Partial +): Promise { + const accessToken = checkCredentials(params); + const repoId = toRepoId(params.repo); + + const hubUrl = params.hubUrl ?? HUB_URL; + + const url = `${hubUrl}/api/${repoId.type}s/${repoId.name}/paths-info/${encodeURIComponent(params.revision ?? "main")}`; + + const resp = await (params.fetch ?? fetch)(url, { + method: "POST", + headers: { + ...(params.credentials && { + Authorization: `Bearer ${accessToken}`, + }), + 'Accept': 'application/json', + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + paths: params.paths, + expand: params.expand, + }), + }); + + if (!resp.ok) { + throw await createApiError(resp); + } + + const json: unknown = await resp.json(); + if(!Array.isArray(json)) throw new Error('malformed response: expected array'); + + return json.map((item: PathInfo) => ({ + path: item.path, + lfs: item.lfs, + type: item.type, + oid: item.oid, + size: item.size, + // expand fields + securityFileStatus: item.securityFileStatus, + lastCommit: item.lastCommit ? { + date: new Date(item.lastCommit.date), + title: item.lastCommit.title, + id: item.lastCommit.id, + }: undefined, + })); +}