forked from zilliztech/akcio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
99 lines (82 loc) · 3.29 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import sys
from typing import Optional, List
from .vector_store.milvus import VectorStore, Embeddings
from .memory_store.pg import MemoryStore
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from config import USE_SCALAR # pylint: disable=C0413
if USE_SCALAR:
from .scalar_store.es import ScalarStore
class DocStore:
    '''Integrate a vector store and an optional scalar (keyword) store.

    Documents are always written to and searched in the vector store. When
    scalar search is enabled they are also indexed in the scalar store, and
    search results from both stores are merged and de-duplicated.
    '''

    def __init__(
        self,
        table_name: str,
        embedding_func: Optional[Embeddings] = None,
        use_scalar: bool = USE_SCALAR
    ) -> None:
        '''Create the backing stores for ``table_name``.

        Args:
            table_name: Vector-store table name; also used as the scalar-store
                index name.
            embedding_func: Embedding function handed to the vector store.
            use_scalar: Whether to also create a scalar store. Defaults to the
                ``USE_SCALAR`` config flag.
        '''
        self.table_name = table_name
        self.use_scalar = use_scalar
        self.embedding_func = embedding_func
        self.vector_db = VectorStore(
            table_name=table_name, embedding_func=self.embedding_func)
        if self.use_scalar:
            # Import here instead of relying on the module-level conditional
            # import: that one only runs when USE_SCALAR is true, so passing
            # use_scalar=True with the flag off used to raise NameError.
            from .scalar_store.es import ScalarStore  # pylint: disable=import-outside-toplevel,redefined-outer-name
            self.scalar_db = ScalarStore(index_name=table_name)
        else:
            self.scalar_db = None

    def search(self, query: str):
        '''Search every configured store and return de-duplicated documents.

        Vector-store hits come first, in the order the store returns them,
        followed by scalar-store hits whose ``page_content`` was not already
        returned.
        '''
        results = []
        seen = set()  # set membership keeps de-duplication O(1) per hit
        stores = [self.vector_db]
        if self.scalar_db is not None:
            stores.append(self.scalar_db)
        for store in stores:
            for doc in store.search(query):
                if doc.page_content not in seen:
                    seen.add(doc.page_content)
                    results.append(doc)
        return results

    def insert(self, data: List[str], metadatas: Optional[List[dict]] = None):
        '''Insert texts into the vector store and, if enabled, the scalar store.

        Args:
            data: Texts to insert into the vector store.
            metadatas: Optional per-text metadata passed to the vector store.
                When an entry carries a ``'doc'`` key, that value is indexed in
                the scalar store instead of the corresponding text.

        Returns:
            Whatever count the vector store's ``insert`` returns.

        Raises:
            AssertionError: If both stores report a count and the counts differ.
        '''
        vec_count = self.vector_db.insert(data=data, metadatas=metadatas)
        if self.scalar_db is not None:
            scalar_count = self.scalar_db.insert(
                data=self._scalar_docs(data, metadatas))
            self._check_counts(vec_count, scalar_count)
        return vec_count

    def insert_embeddings(self, data: List[float], metadatas: List[dict]):
        '''Insert precomputed embeddings, and their texts into the scalar store.

        Args:
            data: Embeddings forwarded to the vector store's
                ``insert_embeddings`` (presumably one vector per ``metadatas``
                entry — confirm against VectorStore).
            metadatas: Per-embedding metadata. Every entry must contain
                ``'text'``; if it also contains ``'doc'``, that value is
                indexed in the scalar store in preference to ``'text'``.

        Returns:
            Whatever count the vector store's ``insert_embeddings`` returns.

        Raises:
            AssertionError: If an entry lacks ``'text'``, or if both stores
                report a count and the counts differ.
        '''
        # Validate everything before touching either store.
        for meta in metadatas:
            assert 'text' in meta, 'Embedding insert must have corresponding text in metadatas.'
        vec_count = self.vector_db.insert_embeddings(
            data=data, metadatas=metadatas)
        if self.scalar_db is not None:
            docs = [meta['doc'] if 'doc' in meta else meta['text'] for meta in metadatas]
            scalar_count = self.scalar_db.insert(data=docs)
            self._check_counts(vec_count, scalar_count)
        return vec_count

    @staticmethod
    def _scalar_docs(data, metadatas):
        '''Return the texts to index in the scalar store for ``insert``.'''
        if not metadatas or not any('doc' in meta for meta in metadatas):
            return data
        # Fall back to the raw text per item (as insert_embeddings does)
        # instead of raising KeyError when only some entries carry 'doc'.
        return [meta['doc'] if 'doc' in meta else text
                for text, meta in zip(data, metadatas)]

    @staticmethod
    def _check_counts(vec_count, scalar_count):
        '''Assert that both stores inserted the same number of records.'''
        # Compare whenever both stores reported a count; a plain truthiness
        # test would let a 0-vs-N mismatch slip through unnoticed.
        if vec_count is not None and scalar_count is not None:
            assert vec_count == scalar_count, f'Data count does not match: {vec_count} in vector db VS {scalar_count} in scalar db.'

    @staticmethod
    def drop(project):
        '''Drop the vector table (and scalar index, if enabled) for ``project``.

        Raises:
            AssertionError: If the project does not exist beforehand, or still
                exists after dropping.
        '''
        # Call has_project outside the assert so it still runs under `python -O`.
        status = DocStore.has_project(project)
        assert status, f'No table found for project: {project}'
        VectorStore.drop(project)
        if USE_SCALAR:
            ScalarStore.drop(project)
        status = DocStore.has_project(project)
        assert not status, f'Failed to drop table for project: {project}'

    @staticmethod
    def has_project(project):
        '''Return whether the vector store has a table for ``project``.

        When ``USE_SCALAR`` is on, also asserts that the scalar store agrees.
        '''
        status = VectorStore.has_project(project)
        if USE_SCALAR:
            scalar_status = ScalarStore.has_project(project)
            assert scalar_status == status, (
                f'Store state mismatch for project {project}: '
                f'vector db has_project={status}, scalar db has_project={scalar_status}')
        return status