Skip to content

Commit

Permalink
add list output parser
Browse files Browse the repository at this point in the history
  • Loading branch information
liyin2015 committed May 21, 2024
1 parent efa8b85 commit 8c3b6ae
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 57 deletions.
74 changes: 41 additions & 33 deletions core/default_prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,45 @@
DEFAULT_LIGHTRAG_SYSTEM_PROMPT = r"""{# task desc #}
{% if task_desc_str %}
{{task_desc_str}}
{% endif %}
{# tools #}
{% if tools_str %}
<TOOLS>
{{tools_str}}
</TOOLS>
{% endif %}
{# example #}
{% if examples_str %}
<EXAMPLES>
{{examples_str}}
</EXAMPLES>
{% endif %}
{# chat history #}
{% if chat_history_str %}
<CHAT_HISTORY>
{{chat_history_str}}
</CHAT_HISTORY>
{% endif %}
{# context #}
{% if context_str %}
<CONTEXT>
{{context_str}}
</CONTEXT>
{% endif %}
{# steps #}
{% if steps_str %}
<STEPS>
{{steps_str}}
</STEPS>
{% endif %}
{% if task_desc_str %}
{{task_desc_str}}
{% else %}
Answer user query.
{% endif %}
{# output format #}
{% if output_format_str %}
<OUTPUT_FORMAT>
{{output_format_str}}
</OUTPUT_FORMAT>
{% endif %}
{# tools #}
{% if tools_str %}
<TOOLS>
{{tools_str}}
</TOOLS>
{% endif %}
{# example #}
{% if examples_str %}
<EXAMPLES>
{{examples_str}}
</EXAMPLES>
{% endif %}
{# chat history #}
{% if chat_history_str %}
<CHAT_HISTORY>
{{chat_history_str}}
</CHAT_HISTORY>
{% endif %}
{# context #}
{% if context_str %}
<CONTEXT>
{{context_str}}
</CONTEXT>
{% endif %}
{# steps #}
{% if steps_str %}
<STEPS>
{{steps_str}}
</STEPS>
{% endif %}
"""
"""This is the default system prompt template used in LightRAG.
Expand Down
10 changes: 7 additions & 3 deletions core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
But you can replace the prompt and set any variables you want and use the preset_prompt_kwargs to fill in the variables.
"""
super().__init__()

self.model_kwargs = model_kwargs
if "model" not in model_kwargs:
raise ValueError(
Expand Down Expand Up @@ -91,6 +92,7 @@ def _extra_repr(self) -> str:
def _post_call(self, completion: Any) -> GeneratorOutputType:
    r"""Parse the raw completion and apply the configured output processors.

    Args:
        completion (Any): The raw completion object returned by the model client.

    Returns:
        GeneratorOutputType: The parsed response, post-processed by
        ``self.output_processors`` when they are configured.
    """
    response = self.model_client.parse_chat_completion(completion)
    # Debug print left commented out, matching the `# print(f"api_kwargs: ...")`
    # convention used elsewhere in this class.
    # print(f"Raw response: \n{response}")
    if self.output_processors:
        response = self.output_processors(response)
    return response
Expand Down Expand Up @@ -123,11 +125,12 @@ def call(
r"""Call the model with the input(user_query) and model_kwargs."""

api_kwargs = self._pre_call(input, prompt_kwargs, model_kwargs)
print(f"api_kwargs: {api_kwargs}")
# print(f"api_kwargs: {api_kwargs}")
completion = self.model_client.call(
api_kwargs=api_kwargs, model_type=self.model_type
)
return self._post_call(completion)
output = self._post_call(completion)
return output

async def acall(
self,
Expand All @@ -142,4 +145,5 @@ async def acall(
completion = await self.model_client.acall(
api_kwargs=api_kwargs, model_type=self.model_type
)
return self._post_call(completion)
output = self._post_call(completion)
return output
10 changes: 5 additions & 5 deletions core/string_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def __init__(self, add_missing_right_bracket: bool = True):
super().__init__()
self.add_missing_right_bracket = add_missing_right_bracket

def __call__(self, text: str) -> List[Any]:
list_str = F.extract_list_str(text, self.add_missing_right_bracket)
def __call__(self, input: str) -> List[Any]:
list_str = F.extract_list_str(input, self.add_missing_right_bracket)
list_obj = F.parse_json_str_to_obj(list_str)
return list_obj

Expand All @@ -45,9 +45,9 @@ def __init__(self, add_missing_right_brace: bool = True):
super().__init__()
self.add_missing_right_brace = add_missing_right_brace

def __call__(self, text: str) -> JASON_PARSER_OUTPUT_TYPE:
text = text.strip()
json_str = F.extract_json_str(text, self.add_missing_right_brace)
def call(self, input: str) -> JASON_PARSER_OUTPUT_TYPE:
input = input.strip()
json_str = F.extract_json_str(input, self.add_missing_right_brace)
json_obj = F.parse_json_str_to_obj(json_str)
return json_obj

Expand Down
76 changes: 68 additions & 8 deletions prompts/outputs.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,67 @@
"""The most commonly used output parsers for the Generator.
Note: Even with OutputParser for output_format_str formatting and the response parsing, it is not 100% guaranteed
as user query can impact the output.
"""

from dataclasses import is_dataclass
from typing import Dict, Any, Type, Optional
from typing import Dict, Any, Type

from core.component import Component
from core.prompt_builder import Prompt
from core.functional import get_data_class_schema
from core.string_parser import YAMLParser
from core.string_parser import YAMLParser, ListParser, JsonParser

# TODO: might be worth to parse a list of yaml or json objects. For instance, a list of jokes.
# setup: Why couldn't the bicycle stand up by itself?
# punchline: Because it was two-tired.
#
# setup: What do you call a fake noodle?
# punchline: An impasta.

JSON_OUTPUT_FORMAT = r""""""
YAML_OUTPUT_FORMAT = r"""The output should be formatted as a standard YAML instance with the following JSON schema:
YAML_OUTPUT_FORMAT = r"""Your output should be formatted as a standard YAML instance with the following JSON schema:
```
{{schema}}
```
-Make sure to always enclose the YAML output in triple backticks (```). Please do not add anything other than valid YAML output!
-Follow the YAML formatting conventions with an indent of 2 spaces.
-Follow the YAML formatting conventions with an indent of 2 spaces.
"""
LIST_OUTPUT_FORMAT = r"""Your output should be formatted as a standard Python list.
-Each element can be of any Python data type such as string, integer, float, list, dictionary, etc.
-You can also have nested lists and dictionaries.
-Please do not add anything other than valid Python list output!
"""
LIST_OUTPUT_FORMAT = r""""""


YAML_OUTPUT_PARSER_OUTPUT_TYPE = Dict[str, Any]


class YAMLOutputParser(Component):
class OutputParser(Component):
    __doc__ = r"""The abstract base class for all output parsers.

    It gives output parsers a consistent interface for use with the Generator,
    although subclassing it is not strictly required. An output parser combines
    two core LightRAG components:

    1. A Prompt to format the output-format instruction.
    2. A string parser component from ``core.string_parser`` for response parsing.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Accept (and ignore) arbitrary arguments so subclasses remain free to
        # define their own constructor signatures.
        super().__init__()

    def format_instructions(self) -> str:
        r"""Return the formatted instructions to use in the prompt for the output format."""
        raise NotImplementedError("This is an abstract method.")

    def call(self, input: str) -> Any:
        r"""Parse the output string to the desired format and return the parsed output."""
        raise NotImplementedError("This is an abstract method.")


class YAMLOutputParser(OutputParser):
__doc__ = r"""YAML output parser using dataclass for schema extraction.
Examples:
Expand Down Expand Up @@ -58,9 +99,15 @@ class YAMLOutputParser(Component):
def __init__(
self,
data_class_for_yaml: Type,
yaml_output_format_template: Optional[str] = YAML_OUTPUT_FORMAT,
output_processors: Optional[Component] = YAMLParser(),
yaml_output_format_template: str = YAML_OUTPUT_FORMAT,
output_processors: Component = YAMLParser(),
):
r"""
Args:
data_class_for_yaml (Type): The dataclass to extract the schema for the YAML output.
yaml_output_format_template (str, optional): The template for the YAML output format. Defaults to YAML_OUTPUT_FORMAT.
output_processors (Component, optional): The output processors to parse the YAML string to JSON object. Defaults to YAMLParser().
"""
super().__init__()
if not is_dataclass(data_class_for_yaml):
raise ValueError(
Expand All @@ -82,3 +129,16 @@ def call(self, input: str) -> YAML_OUTPUT_PARSER_OUTPUT_TYPE:
def _extra_repr(self) -> str:
    r"""Extra information shown in the component repr: the target dataclass."""
    return f"data_class_for_yaml={self.data_class_for_yaml}"


class ListOutputParser(OutputParser):
    __doc__ = r"""List output parser: instructs the LLM to emit a Python list and parses the response.

    Combines the ``LIST_OUTPUT_FORMAT`` prompt template for the output-format
    instruction with a ``ListParser`` for parsing the model response into a list.
    """

    def __init__(self, list_output_format_template: str = LIST_OUTPUT_FORMAT):
        r"""
        Args:
            list_output_format_template (str, optional): The template for the list
                output-format instruction. Defaults to LIST_OUTPUT_FORMAT.
        """
        super().__init__()
        self.list_output_format_prompt = Prompt(template=list_output_format_template)
        self.output_processors = ListParser()

    def format_instructions(self) -> str:
        r"""Return the formatted list-output instructions to use in the prompt."""
        return self.list_output_format_prompt()

    def call(self, input: str) -> list:
        r"""Parse ``input`` into a Python list and return it."""
        return self.output_processors(input)
17 changes: 9 additions & 8 deletions use_cases/yaml_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from core.component import Component
from core.generator import Generator
from components.api_client import GroqAPIClient
from prompts.outputs import YAMLOutputParser
from components.api_client import GroqAPIClient, OpenAIClient
from prompts.outputs import YAMLOutputParser, ListOutputParser

import utils.setup_env

Expand All @@ -21,16 +21,15 @@ def __init__(self):
yaml_parser = YAMLOutputParser(data_class_for_yaml=Joke)
self.generator = Generator(
model_client=GroqAPIClient,
model_kwargs={"model": "llama3-8b-8192"},
model_kwargs={"model": "llama3-8b-8192", "temperature": 1.0},
preset_prompt_kwargs={
"task_desc_str": "Answer user query. "
+ yaml_parser.format_instructions()
"output_format_str": yaml_parser.format_instructions()
},
output_processors=yaml_parser,
)

def call(self, query: str) -> str:
return self.generator.call(input=query)
def call(self, query: str, model_kwargs: dict = {}) -> dict:
return self.generator.call(input=query, model_kwargs=model_kwargs)


if __name__ == "__main__":
Expand All @@ -39,4 +38,6 @@ def call(self, query: str) -> str:
print("show the system prompt")
joke_generator.generator.print_prompt()
print("Answer:")
print(joke_generator.call("Tell me a joke."))
answer = joke_generator.call("Tell me two jokes.", model_kwargs={"temperature": 1})
print(answer)
print(f"typeof answer: {type(answer)}")

0 comments on commit 8c3b6ae

Please sign in to comment.