This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 735
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4a96830
commit 8096a98
Showing
4 changed files
with
182 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
from typing import Any, Dict, List, Optional | ||
|
||
import numpy as np | ||
from llama_index import Document, VectorStoreIndex | ||
from llama_index.bridge.pydantic import BaseModel, Field | ||
from llama_index.llama_pack.base import BaseLlamaPack | ||
from llama_index.llms import OpenAI | ||
from llama_index.llms.llm import LLM | ||
from llama_index.llms.types import ChatMessage, MessageRole | ||
from llama_index.prompts import PromptTemplate | ||
from sklearn.mixture import GaussianMixture | ||
from umap import UMAP | ||
|
||
SUMMARY_PROMPT_STR = """ | ||
<system> You are a Summarizing Text Portal | ||
<user> Write a summary of the following, including as many key details as | ||
possible: {context}: | ||
<system> | ||
""" | ||
SUMMARY_PROMPT = PromptTemplate(SUMMARY_PROMPT_STR) | ||
|
||
|
||
class ClusteringModule(BaseModel): | ||
n_components: int = Field( | ||
default=64, description="dimensionality of the reduced dimension space" | ||
) | ||
metric: str = Field( | ||
default="cosine", description="Controls how distance is computed in the ambient space of the input data" | ||
) | ||
n_neighbors_list : List[int] = Field( | ||
default=[10, 25, 50], description="Controls how UMAP balances local versus global structure in the data." | ||
) | ||
|
||
def __init__( | ||
self, n_components: int, metric: str, n_neighbors_list: List[int], max_ctx: int | ||
) -> None: | ||
self.umap_list = [ | ||
UMAP(n_neighbors=n_neighbors, n_components=n_components, metric=metric) | ||
for n_neighbors in n_neighbors_list | ||
] | ||
self.max_ctx = max_ctx | ||
|
||
def _get_optimal_clusters(embeddings: List[float]) -> List[int]: | ||
best_bic = np.inf | ||
best_clusters = None | ||
|
||
for n_clusters in range(1, 7): | ||
gmm = GaussianMixture(n_components=n_clusters) | ||
clusters = gmm.fit_predict(embeddings) | ||
bic = gmm.bic(embeddings) | ||
if bic < best_bic: | ||
best_bic = bic | ||
best_clusters = clusters | ||
|
||
return best_clusters | ||
|
||
def __call__(self, documents: List[str], embeddings: List[float]) -> List[str]: | ||
# TODO recursive clustering on clusters with documents exceeding the max_ctx | ||
summary_documents = [] | ||
for umap in self.umap_list: | ||
reduced_embeddings = umap.fit_transform(embeddings) | ||
clusters = self._get_optimal_clusters(reduced_embeddings) | ||
cluster_to_documents = {} | ||
for cluster, document in zip(clusters, documents): | ||
cluster_to_documents.get(cluster, []).append(document) | ||
|
||
summary_documents.append(cluster_to_documents.values()) | ||
|
||
return summary_documents | ||
|
||
|
||
class SummaryModule(BaseModel): | ||
llm: LLM = Field(..., description="LLM") | ||
summary_prompt: PromptTemplate = Field(default=SUMMARY_PROMPT, description="Summary prompt template") | ||
|
||
def __init__(self, llm: LLM, summary_prompt: PromptTemplate) -> None: | ||
llm = llm or OpenAI(model="gpt-3.5-turbo") | ||
self.summary_prompt = summary_prompt | ||
|
||
def __call__(self, documents_per_cluster: List[str]) -> List[str]: | ||
summary = [] | ||
for documents in documents_per_cluster: | ||
messages = [ | ||
ChatMessage( | ||
role=MessageRole.SYSTEM, content="You are a Summarizing Text Portal" | ||
), | ||
ChatMessage( | ||
role=MessageRole.USER, | ||
content=self.summary_prompt.as_query_component( | ||
partial={"context": "\n".join(documents)} | ||
), | ||
), | ||
] | ||
|
||
summary.append(self.llm.chat(messages).message.content) | ||
return summary | ||
|
||
|
||
class RaptorIndexingEngine(BaseModel): | ||
"""Raptor indexing engine.""" | ||
|
||
llm: LLM = Field(..., description="llm") | ||
index: VectorStoreIndex = Field(..., description="index") | ||
tree_depth: int = Field(default=3, description="Summary tree depth") | ||
summary_module: SummaryModule = Field(..., description="Summary module") | ||
clustering_module: ClusteringModule = Field(..., description="Clustering Module") | ||
verbose: bool = Field(default=True, description="Verbose.") | ||
|
||
def __init__( | ||
self, | ||
index: VectorStoreIndex, | ||
llm: LLM, | ||
tree_depth: int, | ||
summary_module: Optional[SummaryModule] = None, | ||
clustering_module: Optional[ClusteringModule] = None, | ||
verbose: bool = False, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Init params.""" | ||
super().__init__(verbose=verbose, **kwargs) | ||
self.index = index | ||
self.llm = llm or OpenAI(model="gpt-3.5-turbo") | ||
self.tree_depth = tree_depth | ||
self.summary_module = summary_module or SummaryModule(llm=llm) | ||
self.clustering_module = clustering_module or ClusteringModule() | ||
|
||
def _get_embeddings_per_level(self, level: int = 0) -> List[float]: | ||
vector_store = self.index.vector_store | ||
node_dict = self.index.index_struct.nodes_dict | ||
docstore = self.index.docstore | ||
target_vectors = node_dict.keys() | ||
|
||
if level > 0: | ||
target_vectors = [ | ||
vector_id | ||
for vector_id, document_id in node_dict.items() | ||
if docstore.get_document(document_id).metadata["level"] == level | ||
] | ||
return [vector_store.get(vector_id) for vector_id in target_vectors] | ||
|
||
def insert_abstractions(self, documents: List[Document]) -> None: | ||
for level in range(self.tree_depth): | ||
embeddings = self._get_embeddings_per_level(level=level) | ||
documents_per_cluster = self.clustering_module(documents, embeddings) | ||
summaries_per_cluster = self.summary_module(documents_per_cluster) | ||
summaries_per_cluster = [ | ||
Document(text=summary, metadata={"level": level}) | ||
for summary in summaries_per_cluster | ||
] | ||
self.index.insert_nodes(summaries_per_cluster) | ||
documents = summaries_per_cluster | ||
|
||
|
||
class RaptorPack(BaseLlamaPack): | ||
"""Raptor pack.""" | ||
|
||
def __init__( | ||
self, | ||
llm: Optional[LLM] = None, | ||
verbose: bool = False, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Init params.""" | ||
|
||
self.indexing_engine = RaptorIndexingEngine( | ||
llm=llm, | ||
verbose=verbose, | ||
**kwargs, | ||
) | ||
|
||
def get_modules(self) -> Dict[str, Any]: | ||
"""Get modules.""" | ||
return { | ||
"indexing_engine": self.indexing_engine, | ||
"llm": self.indexing_engine.llm, | ||
} | ||
|
||
def run(self, *args: Any, **kwargs: Any) -> Any: | ||
"""Run the pipeline.""" | ||
return self.indexing_engine.insert_abstractions(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
umap-learn | ||
scikit-learn |