embeddingUtils.js
import { env, pipeline, AutoTokenizer } from '@huggingface/transformers';
import { LRUCache } from 'lru-cache';

let tokenizer;
let generateEmbedding;

// In-memory LRU cache of embeddings, keyed by the input text.
const embeddingCache = new LRUCache({
  max: 500,                 // at most 500 cached embeddings
  maxSize: 50_000_000,      // total size budget of ~50 MB
  sizeCalculation: (value, key) => {
    // Approximate entry size: 4 bytes per vector element plus the key (text) length.
    return (value.length * 4) + key.length;
  },
  ttl: 1000 * 60 * 60,      // entries expire after one hour
});
// -----------------------------------------------
// -- Initialize embedding model and tokenizer --
// -----------------------------------------------
export async function initializeEmbeddingUtils(
  onnxEmbeddingModel,
  dtype = 'fp32',
  localModelPath = null,
  modelCacheDir = null
) {
  // Configure the Transformers.js environment.
  env.allowRemoteModels = true;
  if (localModelPath) env.localModelPath = localModelPath;
  if (modelCacheDir) env.cacheDir = modelCacheDir;

  // Load the tokenizer and the feature-extraction pipeline for the ONNX model.
  tokenizer = await AutoTokenizer.from_pretrained(onnxEmbeddingModel);
  generateEmbedding = await pipeline('feature-extraction', onnxEmbeddingModel, {
    dtype: dtype,
  });

  // Cached embeddings from a previously loaded model are no longer valid.
  embeddingCache.clear();

  return {
    modelName: onnxEmbeddingModel,
    dtype: dtype,
  };
}
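
// Example usage (a minimal sketch; 'Xenova/all-MiniLM-L6-v2' is just one
// ONNX embedding model known to work with Transformers.js, so substitute
// whichever model this project actually uses):
//
//   const { modelName, dtype } = await initializeEmbeddingUtils('Xenova/all-MiniLM-L6-v2', 'fp32');
//   console.log(`Loaded ${modelName} with dtype ${dtype}`);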
// -------------------------------------
// -- Function to generate embeddings --
// -------------------------------------
export async function createEmbedding(text) {
  // Serve repeated texts from the cache.
  const cached = embeddingCache.get(text);
  if (cached) {
    return cached;
  }

  // Mean-pool the token embeddings and L2-normalize the resulting vector.
  const embeddings = await generateEmbedding(text, { pooling: 'mean', normalize: true });
  embeddingCache.set(text, embeddings.data);
  return embeddings.data;
}

// Live binding: this export is populated once initializeEmbeddingUtils() has run.
export { tokenizer };
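
// Example usage of createEmbedding (a minimal sketch; assumes
// initializeEmbeddingUtils() has already been awaited, and the cosine
// similarity comparison below is illustrative, not part of this module):
//
//   const a = await createEmbedding('How do I reset my password?');
//   const b = await createEmbedding('password reset instructions');
//   // The vectors are L2-normalized, so their dot product is the cosine similarity.
//   const similarity = a.reduce((sum, v, i) => sum + v * b[i], 0);
//   console.log(similarity.toFixed(3));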