From c0a43bc361398bd551bfdc1cf7f64076754e7f35 Mon Sep 17 00:00:00 2001 From: Mishig Date: Wed, 17 Apr 2024 10:50:10 +0200 Subject: [PATCH] [gguf & st] parse shard filenames in typed function (#631) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit follow up to https://github.com/huggingface/huggingface.js/pull/627#pullrequestreview-2000893065 > why not a typed function fn(filename) that returns all the data in a typed manner instead of the regex 🙈 . > Currently in the next version of HF.js we can change the regex to something else, removing or renaming the groups, and on moon's side we'd have no clue. No compiler warning or anything. --- packages/gguf/src/gguf.spec.ts | 11 +++++------ packages/gguf/src/gguf.ts | 18 ++++++++++++++++++ .../lib/parse-safetensors-metadata.spec.ts | 13 ++++++------- .../hub/src/lib/parse-safetensors-metadata.ts | 19 +++++++++++++++++++ 4 files changed, 48 insertions(+), 13 deletions(-) diff --git a/packages/gguf/src/gguf.spec.ts b/packages/gguf/src/gguf.spec.ts index bd5b76155..87c6edce3 100644 --- a/packages/gguf/src/gguf.spec.ts +++ b/packages/gguf/src/gguf.spec.ts @@ -1,5 +1,5 @@ import { describe, expect, it } from "vitest"; -import { GGMLQuantizationType, RE_GGUF_SHARD_FILE, gguf } from "./gguf"; +import { GGMLQuantizationType, gguf, parseGgufShardFile } from "./gguf"; const URL_LLAMA = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/191239b/llama-2-7b-chat.Q2_K.gguf"; const URL_MISTRAL_7B = @@ -223,11 +223,10 @@ describe("gguf", () => { it("should detect sharded gguf filename", async () => { const ggufPath = "grok-1/grok-1-q4_0-00003-of-00009.gguf"; // https://huggingface.co/ggml-org/models/blob/fcf344adb9686474c70e74dd5e55465e9e6176ef/grok-1/grok-1-q4_0-00003-of-00009.gguf - const match = ggufPath.match(RE_GGUF_SHARD_FILE); + const ggufShardFileInfo = parseGgufShardFile(ggufPath); - expect(RE_GGUF_SHARD_FILE.test(ggufPath)).toEqual(true); - expect(match?.groups?.prefix).toEqual("grok-1/grok-1-q4_0"); - expect(match?.groups?.shard).toEqual("00003"); - expect(match?.groups?.total).toEqual("00009"); + expect(ggufShardFileInfo?.prefix).toEqual("grok-1/grok-1-q4_0"); + expect(ggufShardFileInfo?.shard).toEqual("00003"); + expect(ggufShardFileInfo?.total).toEqual("00009"); }); }); diff --git a/packages/gguf/src/gguf.ts b/packages/gguf/src/gguf.ts index 1e3bebf4c..cb3e0b5ee 100644 --- a/packages/gguf/src/gguf.ts +++ b/packages/gguf/src/gguf.ts @@ -8,6 +8,24 @@ export { GGUF_QUANT_DESCRIPTIONS } from "./quant-descriptions"; export const RE_GGUF_FILE = /\.gguf$/; export const RE_GGUF_SHARD_FILE = /^(?.*?)-(?\d{5})-of-(?\d{5})\.gguf$/; +export interface GgufShardFileInfo { + prefix: string; + shard: string; + total: string; +} + +export function parseGgufShardFile(filename: string): GgufShardFileInfo | null { + const match = RE_GGUF_SHARD_FILE.exec(filename); + if (match && match.groups) { + return { + prefix: match.groups["prefix"], + shard: match.groups["shard"], + total: match.groups["total"], + }; + } + return null; +} + const isVersion = (version: number): version is Version => version === 1 || version === 2 || version === 3; /** diff --git a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts index 2d8a49648..b5fad6c39 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.spec.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.spec.ts @@ -1,5 +1,5 @@ import { assert, it, describe } from "vitest"; -import { RE_SAFETENSORS_SHARD_FILE, parseSafetensorsMetadata } from "./parse-safetensors-metadata"; +import { parseSafetensorsMetadata, parseSafetensorsShardFile } from "./parse-safetensors-metadata"; import { sum } from "../utils/sum"; describe("parseSafetensorsMetadata", () => { @@ -112,12 +112,11 @@ describe("parseSafetensorsMetadata", () => { it("should detect sharded safetensors filename", async () => { const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors - const match = safetensorsFilename.match(RE_SAFETENSORS_SHARD_FILE); + const safetensorsShardFileInfo = parseSafetensorsShardFile(safetensorsFilename); - assert.strictEqual(RE_SAFETENSORS_SHARD_FILE.test(safetensorsFilename), true); - assert.strictEqual(match?.groups?.prefix, "model_"); - assert.strictEqual(match?.groups?.basePrefix, "model"); - assert.strictEqual(match?.groups?.shard, "00005"); - assert.strictEqual(match?.groups?.total, "00072"); + assert.strictEqual(safetensorsShardFileInfo?.prefix, "model_"); + assert.strictEqual(safetensorsShardFileInfo?.basePrefix, "model"); + assert.strictEqual(safetensorsShardFileInfo?.shard, "00005"); + assert.strictEqual(safetensorsShardFileInfo?.total, "00072"); }); }); diff --git a/packages/hub/src/lib/parse-safetensors-metadata.ts b/packages/hub/src/lib/parse-safetensors-metadata.ts index c8fa05b01..e1d6786c7 100644 --- a/packages/hub/src/lib/parse-safetensors-metadata.ts +++ b/packages/hub/src/lib/parse-safetensors-metadata.ts @@ -16,6 +16,25 @@ export const RE_SAFETENSORS_FILE = /\.safetensors$/; export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/; export const RE_SAFETENSORS_SHARD_FILE = /^(?(?.*?)[_-])(?\d{5})-of-(?\d{5})\.safetensors$/; +export interface SafetensorsShardFileInfo { + prefix: string; + basePrefix: string; + shard: string; + total: string; +} +export function parseSafetensorsShardFile(filename: string): SafetensorsShardFileInfo | null { + const match = RE_SAFETENSORS_SHARD_FILE.exec(filename); + if (match && match.groups) { + return { + prefix: match.groups["prefix"], + basePrefix: match.groups["basePrefix"], + shard: match.groups["shard"], + total: match.groups["total"], + }; + } + return null; +} + const PARALLEL_DOWNLOADS = 20; const MAX_HEADER_LENGTH = 25_000_000;