update the model client
Alleria1809 committed Jun 29, 2024
1 parent 074751a commit 6ed5ca3
Showing 3 changed files with 158 additions and 97 deletions.
5 changes: 5 additions & 0 deletions lightrag/components/model_client/__init__.py
@@ -15,6 +15,10 @@
"lightrag.components.model_client.transformers_client.TransformerEmbedder",
OptionalPackages.TRANSFORMERS,
)
TransformerLLM = LazyImport(
"lightrag.components.model_client.transformers_client.TransformerLLM",
OptionalPackages.TRANSFORMERS,
)
TransformersClient = LazyImport(
"lightrag.components.model_client.transformers_client.TransformersClient",
OptionalPackages.TRANSFORMERS,
@@ -49,6 +53,7 @@
"CohereAPIClient",
"TransformerReranker",
"TransformerEmbedder",
"TransformerLLM",
"TransformersClient",
"AnthropicAPIClient",
"GroqAPIClient",
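For context, a minimal usage sketch of the newly exported TransformerLLM (not part of this commit; the prompt is illustrative and assumes the optional transformers dependency is installed):

from lightrag.components.model_client import TransformerLLM

# The LazyImport above resolves the real class on first use; constructing it downloads the model weights.
llm = TransformerLLM(model_name="HuggingFaceH4/zephyr-7b-beta")
response = llm(input_text="Briefly introduce the red panda.")
print(response)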
68 changes: 49 additions & 19 deletions lightrag/components/model_client/transformers_client.py
@@ -238,31 +238,61 @@ def init_model(self, model_name: str):
self.model = AutoModelForCausalLM.from_pretrained(model_name)
# register the model
self.models[model_name] = self.model
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)  # keep the model on the same device as the tokenized inputs
log.info(f"Done loading model {model_name}")

# Set pad token if it's not already set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token # common fallback
self.model.config.pad_token_id = self.tokenizer.eos_token_id # ensure consistency in the model config
except Exception as e:
log.error(f"Error loading model {model_name}: {e}")
raise e

def parse_chat_completion(self, input_text: str, response: str):
# Strip the echoed prompt from the generated text; if nothing is left, fall back to the full response
parsed_response = response.replace(input_text, "").strip()

return parsed_response if parsed_response else response

def call(self, input: str, skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = False ):
model = self.models.get("HuggingFaceH4/zephyr-7b-beta", None)
if model is None:
# initialize the model
self.init_model("HuggingFaceH4/zephyr-7b-beta")
prompt = input
inputs = self.tokenizer(prompt, return_tensors="pt")
generate_ids = self.model.generate(inputs.input_ids)
response = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces)[0]
return response
def call(self, input_text: str, skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = False, max_length: int = 150):
if not self.model:
log.error("Model is not initialized.")
raise ValueError("Model is not initialized.")

# Ensure tokenizer has pad token; set it if not
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.config.pad_token_id = self.tokenizer.eos_token_id # Sync model config pad token id

# Process inputs with attention mask and padding
inputs = self.tokenizer(input_text, return_tensors="pt", padding=True).to(self.device)
# inputs = self.tokenizer(input_text, return_tensors="pt", padding="longest", truncation=True).to(self.device)

with torch.no_grad(): # Inference only; skip gradient tracking to save memory and compute
generate_ids = self.model.generate(
inputs['input_ids'],
attention_mask=inputs['attention_mask'],
max_length=max_length # Control the output length more precisely
)
response = self.tokenizer.decode(generate_ids[0], skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces)
parsed_response = self.parse_chat_completion(input_text, response)
return parsed_response

def __call__(self, **kwargs):
if "model" not in kwargs:
raise ValueError("model is required")
model_name = kwargs["model"]
if model_name == "HuggingFaceH4/zephyr-7b-beta":
return self.call(kwargs["input"])
else:
raise ValueError(f"model {model_name} is not supported")
def __call__(self, input_text: str, skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = False, max_length: int = 150):
return self.call(input_text, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, max_length=max_length)


# def call(self, input_text: str, skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = False):
# if not self.model:
# log.error("Model is not initialized.")
# raise ValueError("Model is not initialized.")

# inputs = self.tokenizer(input_text, return_tensors="pt")
# generate_ids = self.model.generate(inputs.input_ids, max_length=30)
# response = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces)[0]
# return response

# def __call__(self, input_text: str, skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = False):
# return self.call(input_text, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces)


class TransformersClient(ModelClient):
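For readers unfamiliar with the pad-token and attention-mask handling introduced above, here is a standalone transformers sketch of the same pattern (the model name and prompt are illustrative; this is not part of the diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "HuggingFaceH4/zephyr-7b-beta"  # illustrative; any causal LM applies
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token           # fall back to EOS as the pad token
    model.config.pad_token_id = tokenizer.eos_token_id  # keep the model config consistent

inputs = tokenizer("Hello, what's the weather today?", return_tensors="pt", padding=True)
with torch.no_grad():
    ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=150)
print(tokenizer.decode(ids[0], skip_special_tokens=True))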
182 changes: 104 additions & 78 deletions lightrag/tests/test_transformer_client.py
@@ -4,6 +4,7 @@
from lightrag.components.model_client import (
TransformersClient,
TransformerReranker,
TransformerLLM,
TransformerEmbedder,
)
from lightrag.core.types import ModelType
@@ -22,81 +23,106 @@ def setUp(self) -> None:
"The red panda (Ailurus fulgens), also called the lesser panda, the red bear-cat, and the red cat-bear, is a mammal native to the eastern Himalayas and southwestern China.",
]

def test_transformer_embedder(self):
transformer_embedder_model = "thenlper/gte-base"
transformer_embedder_model_component = TransformerEmbedder(
model_name=transformer_embedder_model
)
print(
f"Testing transformer embedder with model {transformer_embedder_model_component}"
)
print("Testing transformer embedder")
output = transformer_embedder_model_component(
model=transformer_embedder_model, input="Hello world"
)
print(output)

def test_transformer_client(self):
transformer_client = TransformersClient()
print("Testing transformer client")
# run the model
kwargs = {
"model": "thenlper/gte-base",
# "mock": False,
}
api_kwargs = transformer_client.convert_inputs_to_api_kwargs(
input="Hello world",
model_kwargs=kwargs,
model_type=ModelType.EMBEDDER,
)
# print(api_kwargs)
output = transformer_client.call(
api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER
)

# print(transformer_client)
# print(output)

def test_transformer_reranker(self):
transformer_reranker_model = "BAAI/bge-reranker-base"
transformer_reranker_model_component = TransformerReranker()
# print(
# f"Testing transformer reranker with model {transformer_reranker_model_component}"
# )

model_kwargs = {
"model": transformer_reranker_model,
"documents": self.documents,
"query": self.query,
"top_k": 2,
}

output = transformer_reranker_model_component(
**model_kwargs,
)
# assert output is a list of float with length 2
self.assertEqual(len(output), 2)
self.assertEqual(type(output[0]), float)

def test_transformer_reranker_client(self):
transformer_reranker_client = TransformersClient(
model_name="BAAI/bge-reranker-base"
)
print("Testing transformer reranker client")
# run the model
kwargs = {
"model": "BAAI/bge-reranker-base",
"documents": self.documents,
"top_k": 2,
}
api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs(
input=self.query,
model_kwargs=kwargs,
model_type=ModelType.RERANKER,
)
print(api_kwargs)
self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base")
output = transformer_reranker_client.call(
api_kwargs=api_kwargs, model_type=ModelType.RERANKER
)
self.assertEqual(type(output), tuple)
# def test_transformer_embedder(self):
# transformer_embedder_model = "thenlper/gte-base"
# transformer_embedder_model_component = TransformerEmbedder(
# model_name=transformer_embedder_model
# )
# print(
# f"Testing transformer embedder with model {transformer_embedder_model_component}"
# )
# print("Testing transformer embedder")
# output = transformer_embedder_model_component(
# model=transformer_embedder_model, input="Hello world"
# )
# print(output)

# def test_transformer_client(self):
# transformer_client = TransformersClient()
# print("Testing transformer client")
# # run the model
# kwargs = {
# "model": "thenlper/gte-base",
# # "mock": False,
# }
# api_kwargs = transformer_client.convert_inputs_to_api_kwargs(
# input="Hello world",
# model_kwargs=kwargs,
# model_type=ModelType.EMBEDDER,
# )
# # print(api_kwargs)
# output = transformer_client.call(
# api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER
# )

# # print(transformer_client)
# # print(output)

# def test_transformer_reranker(self):
# transformer_reranker_model = "BAAI/bge-reranker-base"
# transformer_reranker_model_component = TransformerReranker()
# # print(
# # f"Testing transformer reranker with model {transformer_reranker_model_component}"
# # )

# model_kwargs = {
# "model": transformer_reranker_model,
# "documents": self.documents,
# "query": self.query,
# "top_k": 2,
# }

# output = transformer_reranker_model_component(
# **model_kwargs,
# )
# # assert output is a list of float with length 2
# self.assertEqual(len(output), 2)
# self.assertEqual(type(output[0]), float)

# def test_transformer_reranker_client(self):
# transformer_reranker_client = TransformersClient(
# model_name="BAAI/bge-reranker-base"
# )
# print("Testing transformer reranker client")
# # run the model
# kwargs = {
# "model": "BAAI/bge-reranker-base",
# "documents": self.documents,
# "top_k": 2,
# }
# api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs(
# input=self.query,
# model_kwargs=kwargs,
# model_type=ModelType.RERANKER,
# )
# print(api_kwargs)
# self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base")
# output = transformer_reranker_client.call(
# api_kwargs=api_kwargs, model_type=ModelType.RERANKER
# )
# self.assertEqual(type(output), tuple)


def test_transformer_llm_response(self):
"""Test the TransformerLLM model with zephyr-7b-beta for generating a response."""
transformer_llm_model = "HuggingFaceH4/zephyr-7b-beta"
transformer_llm_model_component = TransformerLLM(model_name=transformer_llm_model)

# Define a sample input
input_text = "Hello, what's the weather today?"

# Generate a response; the updated __call__ takes input_text directly and no longer needs a 'model' keyword
# response = transformer_llm_model_component(input=input_text, model=transformer_llm_model)
response = transformer_llm_model_component(input_text=input_text)


# Check if the response is valid
self.assertIsInstance(response, str, "The response should be a string.")
self.assertTrue(len(response) > 0, "The response should not be empty.")

# Optionally, print the response for visual verification during testing
print(f"Generated response: {response}")


if __name__ == '__main__':
unittest.main()
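To run only the new test, a suggested invocation (assuming the repository root as the working directory and that lightrag.tests is importable as a package):

python -m unittest -k test_transformer_llm_response lightrag.tests.test_transformer_client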
