From 84aa8b0effedd1aa89b3f36d9704c57f804409f8 Mon Sep 17 00:00:00 2001 From: Yvan Sraka Date: Thu, 2 Jan 2025 14:50:00 +0100 Subject: [PATCH] Add `from_file` class method to the `Prompt` object --- outlines/__init__.py | 4 +- outlines/prompts.py | 166 +++++++++++++++++++++++------------ tests/test_prompts.py | 195 ++++++++++++++++++++---------------------- 3 files changed, 205 insertions(+), 160 deletions(-) diff --git a/outlines/__init__.py b/outlines/__init__.py index 307d2ba6f..eeba78dc7 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" + import outlines.generate import outlines.grammars import outlines.models @@ -7,7 +8,7 @@ from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function -from outlines.prompts import prompt +from outlines.prompts import Prompt, prompt __all__ = [ "clear_cache", @@ -15,6 +16,7 @@ "get_cache", "Function", "prompt", + "Prompt", "vectorize", "grammars", ] diff --git a/outlines/prompts.py b/outlines/prompts.py index a7824451a..9243f1c3a 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -1,13 +1,16 @@ import functools import inspect import json +import os import re import textwrap +import warnings from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Type, cast +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Type, cast -from jinja2 import Environment, StrictUndefined -from pydantic import BaseModel +import jinja2 +import pydantic @dataclass @@ -19,12 +22,8 @@ class Prompt: """ - template: str - signature: inspect.Signature - - def __post_init__(self): - self.parameters: List[str] = list(self.signature.parameters.keys()) - self.jinja_environment = create_jinja_template(self.template) + template: jinja2.Template + signature: Optional[inspect.Signature] def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -34,12 +33,93 @@ def __call__(self, *args, **kwargs) -> str: The rendered template as a Python ``str``. """ - bound_arguments = self.signature.bind(*args, **kwargs) - bound_arguments.apply_defaults() - return self.jinja_environment.render(**bound_arguments.arguments) + if self.signature is not None: + bound_arguments = self.signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return self.template.render(**bound_arguments.arguments) + else: + return self.template.render(**kwargs) + + @classmethod + def from_str(cls, content: str): + """ + Create an instance of the class from a string. + + Parameters + ---------- + content : str + The string content to be converted into a template. + + Returns + ------- + An instance of the class with the provided content as a template. + """ + return cls(cls._template_from_str(content), None) + + @classmethod + def from_file(cls, path: Path): + """ + Create a Prompt instance from a file containing a Jinja template. - def __str__(self): - return self.template + Note: This method does not allow to include and inheritance to reference files + that are outside the folder or subfolders of the file given to `from_file`. + + Parameters + ---------- + path : Path + The path to the file containing the Jinja template. + + Returns + ------- + Prompt + An instance of the Prompt class with the template loaded from the file. + """ + # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is + # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) + return cls(cls._template_from_file(path), None) + + @classmethod + def _template_from_str(_, content: str) -> jinja2.Template: + # Dedent, and remove extra linebreak + cleaned_template = inspect.cleandoc(content) + + # Add linebreak if there were any extra linebreaks that + # `cleandoc` would have removed + ends_with_linebreak = content.replace(" ", "").endswith("\n\n") + if ends_with_linebreak: + cleaned_template += "\n" + + # Remove extra whitespaces, except those that immediately follow a newline symbol. + # This is necessary to avoid introducing whitespaces after backslash `\` characters + # used to continue to the next line without linebreak. + cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + env.filters["name"] = get_fn_name + env.filters["description"] = get_fn_description + env.filters["source"] = get_fn_source + env.filters["signature"] = get_fn_signature + env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args + + return env.from_string(cleaned_template) + + @classmethod + def _template_from_file(_, path: Path) -> jinja2.Template: + file_directory = os.path.dirname(os.path.abspath(path)) + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(file_directory), + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + return env.get_template(os.path.basename(path)) def prompt(fn: Callable) -> Prompt: @@ -87,14 +167,18 @@ def prompt(fn: Callable) -> Prompt: if docstring is None: raise TypeError("Could not find a template in the function's docstring.") - template = cast(str, docstring) + template = Prompt._template_from_str(cast(str, docstring)) return Prompt(template, signature) -def render(template: str, **values: Optional[Dict[str, Any]]) -> str: +def render( + template: str, **values: Optional[Dict[str, Any]] +) -> str: # pragma: no cover r"""Parse a Jinaj2 template and translate it into an Outlines graph. + [DEPRECATED] Using `render(str)` is deprecated. + This function removes extra whitespaces and linebreaks from templates to allow users to enter prompts more naturally than if they used Python's constructs directly. See the examples for a detailed explanation. @@ -105,12 +189,12 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: Outlines follow Jinja2's syntax >>> import outlines - >>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis") + >>> outline = outlines.prompts.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis") I like tomatoes and tennis If the first line of the template is empty, `render` removes it - >>> from outlines import render + >>> from outlines.prompts import render >>> >>> tpl = ''' ... A new string''' @@ -174,7 +258,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: Parameters ---------- template - A string that contains a template written with the Jinja2 syntax. + A Jinja2 template. **values Map from the variables in the template to their value. @@ -183,40 +267,12 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: A string that contains the rendered template. """ - jinja_template = create_jinja_template(template) - return jinja_template.render(**values) - - -def create_jinja_template(template: str): - # Dedent, and remove extra linebreak - cleaned_template = inspect.cleandoc(template) - - # Add linebreak if there were any extra linebreaks that - # `cleandoc` would have removed - ends_with_linebreak = template.replace(" ", "").endswith("\n\n") - if ends_with_linebreak: - cleaned_template += "\n" - - # Remove extra whitespaces, except those that immediately follow a newline symbol. - # This is necessary to avoid introducing whitespaces after backslash `\` characters - # used to continue to the next line without linebreak. - cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) - - env = Environment( - trim_blocks=True, - lstrip_blocks=True, - keep_trailing_newline=True, - undefined=StrictUndefined, + warnings.warn( + "Using `render(str)` is deprecated.", + DeprecationWarning, ) - env.filters["name"] = get_fn_name - env.filters["description"] = get_fn_description - env.filters["source"] = get_fn_source - env.filters["signature"] = get_fn_signature - env.filters["schema"] = get_schema - env.filters["args"] = get_fn_args - - jinja_template = env.from_string(cleaned_template) - return jinja_template + template = Prompt._template_from_str(template) + return template.render(**values) def get_fn_name(fn: Callable): @@ -301,10 +357,10 @@ def get_schema_dict(model: Dict): return json.dumps(model, indent=2) -@get_schema.register(type(BaseModel)) -def get_schema_pydantic(model: Type[BaseModel]): +@get_schema.register(type(pydantic.BaseModel)) +def get_schema_pydantic(model: Type[pydantic.BaseModel]): """Return the schema of a Pydantic model.""" - if not type(model) == type(BaseModel): + if not isinstance(model, type(pydantic.BaseModel)): raise TypeError("The `schema` filter only applies to Pydantic models.") if hasattr(model, "model_json_schema"): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index a0433c0e5..6e22ea68d 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,108 +1,12 @@ +import os +import tempfile from typing import Dict, List import pytest from pydantic import BaseModel, Field import outlines -from outlines.prompts import render - - -def test_render(): - tpl = """ - A test string""" - assert render(tpl) == "A test string" - - tpl = """ - A test string - """ - assert render(tpl) == "A test string" - - tpl = """ - A test - Another test - """ - assert render(tpl) == "A test\nAnother test" - - tpl = """A test - Another test - """ - assert render(tpl) == "A test\nAnother test" - - tpl = """ - A test line - An indented line - """ - assert render(tpl) == "A test line\n An indented line" - - tpl = """ - A test line - An indented line - - """ - assert render(tpl) == "A test line\n An indented line\n" - - -def test_render_escaped_linebreak(): - tpl = """ - A long test \ - that we break \ - in several lines - """ - assert render(tpl) == "A long test that we break in several lines" - - tpl = """ - Break in \ - several lines \ - But respect the indentation - on line breaks. - And after everything \ - Goes back to normal - """ - assert ( - render(tpl) - == "Break in several lines But respect the indentation\n on line breaks.\nAnd after everything Goes back to normal" - ) - - -def test_render_jinja(): - """Make sure that we can use basic Jinja2 syntax, and give examples - of how we can use it for basic use cases. - """ - - # Notice the newline after the end of the loop - examples = ["one", "two"] - prompt = render( - """ - {% for e in examples %} - Example: {{e}} - {% endfor -%}""", - examples=examples, - ) - assert prompt == "Example: one\nExample: two\n" - - # We can remove the newline by cloing with -%} - examples = ["one", "two"] - prompt = render( - """ - {% for e in examples %} - Example: {{e}} - {% endfor -%} - - Final""", - examples=examples, - ) - assert prompt == "Example: one\nExample: two\nFinal" - - # Same for conditionals - tpl = """ - {% if is_true %} - true - {% endif -%} - - final - """ - assert render(tpl, is_true=True) == "true\nfinal" - assert render(tpl, is_true=False) == "final" +from outlines.prompts import Prompt def test_prompt_basic(): @@ -110,8 +14,7 @@ def test_prompt_basic(): def test_tpl(variable): """{{variable}} test""" - assert test_tpl.template == "{{variable}} test" - assert test_tpl.parameters == ["variable"] + assert list(test_tpl.signature.parameters) == ["variable"] with pytest.raises(TypeError): test_tpl(v="test") @@ -126,6 +29,8 @@ def test_tpl(variable): def test_single_quote_tpl(variable): "${variable} test" + assert list(test_single_quote_tpl.signature.parameters) == ["variable"] + p = test_tpl("test") assert p == "test test" @@ -135,8 +40,7 @@ def test_prompt_kwargs(): def test_kwarg_tpl(var, other_var="other"): """{{var}} and {{other_var}}""" - assert test_kwarg_tpl.template == "{{var}} and {{other_var}}" - assert test_kwarg_tpl.parameters == ["var", "other_var"] + assert list(test_kwarg_tpl.signature.parameters) == ["var", "other_var"] p = test_kwarg_tpl("test") assert p == "test and other" @@ -145,7 +49,6 @@ def test_kwarg_tpl(var, other_var="other"): assert p == "test and kwarg" p = test_kwarg_tpl("test", "test") - assert p == "test and test" def test_no_prompt(): @@ -312,3 +215,87 @@ def args_prompt(fn): args_prompt(with_all) == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" ) + + +@pytest.fixture +def temp_prompt_file(): + test_dir = tempfile.mkdtemp() + + base_template_path = os.path.join(test_dir, "base_template.txt") + with open(base_template_path, "w") as f: + f.write( + """{% block content %}{% endblock %} +""" + ) + + include_file_path = os.path.join(test_dir, "include.txt") + with open(include_file_path, "w") as f: + f.write( + """{% for example in examples %} +- Q: {{ example.question }} +- A: {{ example.answer }} +{% endfor %} +""" + ) + + prompt_file_path = os.path.join(test_dir, "prompt.txt") + with open(prompt_file_path, "w") as f: + f.write( + """{% extends "base_template.txt" %} + +{% block content %} +Here is a prompt with examples: + +{% include "include.txt" %} + +Now please answer the following question: + +Q: {{ question }} +A: +{% endblock %} +""" + ) + yield prompt_file_path + + +def test_prompt_from_file(temp_prompt_file): + prompt = Prompt.from_file(temp_prompt_file) + assert prompt.signature is None + examples = [ + {"question": "What is the capital of France?", "answer": "Paris"}, + {"question": "What is 2 + 2?", "answer": "4"}, + ] + question = "What is the Earth's diameter?" + rendered = prompt(examples=examples, question=question) + expected = """Here is a prompt with examples: + +- Q: What is the capital of France? +- A: Paris +- Q: What is 2 + 2? +- A: 4 + +Now please answer the following question: + +Q: What is the Earth's diameter? +A: +""" + assert rendered.strip() == expected.strip() + + +def test_prompt_from_str(): + content = """ + Hello, {{ name }}! + """ + prompt = Prompt.from_str(content) + assert prompt.signature is None + assert prompt(name="World") == "Hello, World!" + + +def test_template_from_str_with_extra_linebreaks(): + content = """ + Hello, {{ name }}! + + + """ + template = Prompt._template_from_str(content) + assert template.render(name="World") == "Hello, World!\n"