Skip to content

Commit

Permalink
sync with main
Browse files Browse the repository at this point in the history
  • Loading branch information
Alleria1809 committed Jul 2, 2024
1 parent 3d02475 commit 2d818ee
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 218 deletions.
6 changes: 3 additions & 3 deletions lightrag/lightrag/components/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .react import DEFAULT_REACT_AGENT_SYSTEM_PROMPT, ReActAgent
from .react import DEFAULT_REACT_AGENT_SYSTEM_PROMPT, ReactAgent
from lightrag.utils.registry import EntityMapping

__all__ = [
"ReActAgent",
"ReactAgent",
"DEFAULT_REACT_AGENT_SYSTEM_PROMPT",
]

for name in __all__:
EntityMapping.register(name, globals()[name])
EntityMapping.register(name, globals()[name])
6 changes: 3 additions & 3 deletions lightrag/lightrag/components/agent/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
func=self._finish,
answer="final answer: 'answer'",
)
output_parser = JsonOutputParser(data_class=ouput_data_class, examples=example)
output_parser = JsonOutputParser(data_class=ouput_data_class, example=example)
prompt_kwargs = {
"tools": self.tool_manager.yaml_definitions,
"output_format_str": output_parser.format_instructions(),
Expand Down Expand Up @@ -320,7 +320,7 @@ def _extra_repr(self) -> str:


if __name__ == "__main__":
from lightrag.components.model_client import GroqAPIClient
from components.model_client import GroqAPIClient
from lightrag.core.types import ModelClientType
from lightrag.utils import setup_env # noqa

Expand Down Expand Up @@ -424,4 +424,4 @@ def search(query: str) -> str:
answer_no_agent = generator(prompt_kwargs={"input_str": query})
print(f"Answer with agent: {answer}")
print(f"Answer without agent: {answer_no_agent}")
print(f"Average time: {average_time / len(queries)}")
print(f"Average time: {average_time / len(queries)}")
199 changes: 101 additions & 98 deletions lightrag/tests/test_transformer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,101 +16,104 @@ def setUp(self) -> None:
"The red panda (Ailurus fulgens), also called the lesser panda, the red bear-cat, and the red cat-bear, is a mammal native to the eastern Himalayas and southwestern China.",
]

def test_transformer_embedder(self):
transformer_embedder_model = "thenlper/gte-base"
transformer_embedder_model_component = TransformerEmbedder(
model_name=transformer_embedder_model
)
print(
f"Testing transformer embedder with model {transformer_embedder_model_component}"
)
print("Testing transformer embedder")
output = transformer_embedder_model_component(
model=transformer_embedder_model, input="Hello world"
)
print(output)

def test_transformer_client(self):
transformer_client = TransformersClient()
print("Testing transformer client")
# run the model
kwargs = {
"model": "thenlper/gte-base",
# "mock": False,
}
api_kwargs = transformer_client.convert_inputs_to_api_kwargs(
input="Hello world",
model_kwargs=kwargs,
model_type=ModelType.EMBEDDER,
)
# print(api_kwargs)
output = transformer_client.call(
api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER
)

# print(transformer_client)
# print(output)

def test_transformer_reranker(self):
transformer_reranker_model = "BAAI/bge-reranker-base"
transformer_reranker_model_component = TransformerReranker()
# print(
# f"Testing transformer reranker with model {transformer_reranker_model_component}"
# )

model_kwargs = {
"model": transformer_reranker_model,
"documents": self.documents,
"query": self.query,
"top_k": 2,
}

output = transformer_reranker_model_component(
**model_kwargs,
)
# assert output is a list of float with length 2
self.assertEqual(len(output), 2)
self.assertEqual(type(output[0]), float)

def test_transformer_reranker_client(self):
transformer_reranker_client = TransformersClient(
model_name="BAAI/bge-reranker-base"
)
print("Testing transformer reranker client")
# run the model
kwargs = {
"model": "BAAI/bge-reranker-base",
"documents": self.documents,
"top_k": 2,
}
api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs(
input=self.query,
model_kwargs=kwargs,
model_type=ModelType.RERANKER,
)
print(api_kwargs)
self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base")
output = transformer_reranker_client.call(
api_kwargs=api_kwargs, model_type=ModelType.RERANKER
)
self.assertEqual(type(output), tuple)


def test_transformer_llm_response(self):
"""Test the TransformerLLM model with zephyr-7b-beta for generating a response."""
transformer_llm_model = "HuggingFaceH4/zephyr-7b-beta"
transformer_llm_model_component = TransformerLLM(model_name=transformer_llm_model)

# Define a sample input
input_text = "Hello, what's the weather today?"

response = transformer_llm_model_component(input_text=input_text)

# Check if the response is valid
self.assertIsInstance(response, str, "The response should be a string.")
self.assertTrue(len(response) > 0, "The response should not be empty.")

# Optionally, print the response for visual verification during testing
print(f"Generated response: {response}")


# def test_transformer_embedder(self):
# transformer_embedder_model = "thenlper/gte-base"
# transformer_embedder_model_component = TransformerEmbedder(
# model_name=transformer_embedder_model
# )
# print(
# f"Testing transformer embedder with model {transformer_embedder_model_component}"
# )
# print("Testing transformer embedder")
# output = transformer_embedder_model_component(
# model=transformer_embedder_model, input="Hello world"
# )
# print(output)

# def test_transformer_client(self):
# transformer_client = TransformersClient()
# print("Testing transformer client")
# # run the model
# kwargs = {
# "model": "thenlper/gte-base",
# # "mock": False,
# }
# api_kwargs = transformer_client.convert_inputs_to_api_kwargs(
# input="Hello world",
# model_kwargs=kwargs,
# model_type=ModelType.EMBEDDER,
# )
# # print(api_kwargs)
# output = transformer_client.call(
# api_kwargs=api_kwargs, model_type=ModelType.EMBEDDER
# )

# # print(transformer_client)
# # print(output)

# def test_transformer_reranker(self):
# transformer_reranker_model = "BAAI/bge-reranker-base"
# transformer_reranker_model_component = TransformerReranker()
# # print(
# # f"Testing transformer reranker with model {transformer_reranker_model_component}"
# # )

# model_kwargs = {
# "model": transformer_reranker_model,
# "documents": self.documents,
# "query": self.query,
# "top_k": 2,
# }

# output = transformer_reranker_model_component(
# **model_kwargs,
# )
# # assert output is a list of float with length 2
# self.assertEqual(len(output), 2)
# self.assertEqual(type(output[0]), float)

# def test_transformer_reranker_client(self):
# transformer_reranker_client = TransformersClient(
# model_name="BAAI/bge-reranker-base"
# )
# print("Testing transformer reranker client")
# # run the model
# kwargs = {
# "model": "BAAI/bge-reranker-base",
# "documents": self.documents,
# "top_k": 2,
# }
# api_kwargs = transformer_reranker_client.convert_inputs_to_api_kwargs(
# input=self.query,
# model_kwargs=kwargs,
# model_type=ModelType.RERANKER,
# )
# print(api_kwargs)
# self.assertEqual(api_kwargs["model"], "BAAI/bge-reranker-base")
# output = transformer_reranker_client.call(
# api_kwargs=api_kwargs, model_type=ModelType.RERANKER
# )
# self.assertEqual(type(output), tuple)

# def test_transformer_llm_response(self):
# """Test the TransformerLLM model with zephyr-7b-beta for generating a response."""
# transformer_llm_model = "HuggingFaceH4/zephyr-7b-beta"
# transformer_llm_model_component = TransformerLLM(model_name=transformer_llm_model)

# # Define a sample input
# input_text = "Hello, what's the weather today?"

# # Test generating a response, providing the 'model' keyword
# # response = transformer_llm_model_component(input=input_text, model=transformer_llm_model)
# response = transformer_llm_model_component(input_text=input_text)

# # Check if the response is valid
# self.assertIsInstance(response, str, "The response should be a string.")
# self.assertTrue(len(response) > 0, "The response should not be empty.")

# # Optionally, print the response for visual verification during testing
# print(f"Generated response: {response}")


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 2d818ee

Please sign in to comment.