diff --git a/outlines/__init__.py b/outlines/__init__.py index 307d2ba6f..eeba78dc7 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" + import outlines.generate import outlines.grammars import outlines.models @@ -7,7 +8,7 @@ from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function -from outlines.prompts import prompt +from outlines.prompts import Prompt, prompt __all__ = [ "clear_cache", @@ -15,6 +16,7 @@ "get_cache", "Function", "prompt", + "Prompt", "vectorize", "grammars", ] diff --git a/outlines/prompts.py b/outlines/prompts.py index a7824451a..66e723205 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -1,12 +1,14 @@ import functools import inspect import json +import os import re import textwrap from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Type, cast +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Type, cast -from jinja2 import Environment, StrictUndefined +import jinja2 from pydantic import BaseModel @@ -19,12 +21,8 @@ class Prompt: """ - template: str - signature: inspect.Signature - - def __post_init__(self): - self.parameters: List[str] = list(self.signature.parameters.keys()) - self.jinja_environment = create_jinja_template(self.template) + template: jinja2.Template + signature: Optional[inspect.Signature] def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -34,12 +32,93 @@ def __call__(self, *args, **kwargs) -> str: The rendered template as a Python ``str``. """ - bound_arguments = self.signature.bind(*args, **kwargs) - bound_arguments.apply_defaults() - return self.jinja_environment.render(**bound_arguments.arguments) + if self.signature is not None: + bound_arguments = self.signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return self.template.render(**bound_arguments.arguments) + else: + return self.template.render(**kwargs) - def __str__(self): - return self.template + @classmethod + def from_str(cls, content: str): + """ + Create an instance of the class from a string. + + Parameters + ---------- + content : str + The string content to be converted into a template. + + Returns + ------- + An instance of the class with the provided content as a template. + """ + return cls(cls._template_from_str(content), None) + + @classmethod + def from_file(cls, path: Path): + """ + Create a Prompt instance from a file containing a Jinja template. + + Note: This method does not allow to include and inheritance to reference files + that are outside the folder or subfolders of the file given to `from_file`. + + Parameters + ---------- + path : Path + The path to the file containing the Jinja template. + + Returns + ------- + Prompt + An instance of the Prompt class with the template loaded from the file. + """ + # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is + # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) + return cls(cls._template_from_file(path), None) + + @classmethod + def _template_from_str(_, content: str) -> jinja2.Template: + # Dedent, and remove extra linebreak + cleaned_template = inspect.cleandoc(content) + + # Add linebreak if there were any extra linebreaks that + # `cleandoc` would have removed + ends_with_linebreak = content.replace(" ", "").endswith("\n\n") + if ends_with_linebreak: + cleaned_template += "\n" + + # Remove extra whitespaces, except those that immediately follow a newline symbol. + # This is necessary to avoid introducing whitespaces after backslash `\` characters + # used to continue to the next line without linebreak. + cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + env.filters["name"] = get_fn_name + env.filters["description"] = get_fn_description + env.filters["source"] = get_fn_source + env.filters["signature"] = get_fn_signature + env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args + + return env.from_string(cleaned_template) + + @classmethod + def _template_from_file(_, path: Path) -> jinja2.Template: + file_directory = os.path.dirname(os.path.abspath(path)) + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(file_directory), + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + return env.get_template(os.path.basename(path)) def prompt(fn: Callable) -> Prompt: @@ -87,138 +166,11 @@ def prompt(fn: Callable) -> Prompt: if docstring is None: raise TypeError("Could not find a template in the function's docstring.") - template = cast(str, docstring) + template = Prompt._template_from_str(cast(str, docstring)) return Prompt(template, signature) -def render(template: str, **values: Optional[Dict[str, Any]]) -> str: - r"""Parse a Jinaj2 template and translate it into an Outlines graph. - - This function removes extra whitespaces and linebreaks from templates to - allow users to enter prompts more naturally than if they used Python's - constructs directly. See the examples for a detailed explanation. - - Examples - -------- - - Outlines follow Jinja2's syntax - - >>> import outlines - >>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis") - I like tomatoes and tennis - - If the first line of the template is empty, `render` removes it - - >>> from outlines import render - >>> - >>> tpl = ''' - ... A new string''' - >>> tpl - ... '\nA new string' - >>> render(tpl) - ... 'a new string' - - Similarly, `render` ignores linebreaks introduced by placing the closing quotes - underneath the text: - - >>> tpl = ''' - ... A new string - ... ''' - >>> tpl - ... '\nA new string\n' - >>> render(tpl) - ... 'A new string' - - If you want to insert a linebreak at the end of the rendered template, you will - need to leave an empty line at the end of the template: - - >>> tpl = ''' - ... A new string - ... - ... ''' - >>> tpl - ... '\nA new string\n\n' - >>> render(tpl) - ... 'A new string\n' - - `render` removes the identation in docstrings. This is particularly important - when using prompt functions - - >>> tpl = ''' - ... a string - ... and another string''' - >>> tpl - ... '\n a string\n and another string' - >>> render(tpl) - ... 'a string\nand another string' - - The indentation of the first line is assumed to be the same as the second line's - - >>> tpl = '''a string - ... and another''' - >>> tpl - ... 'a string\n and another' - >>> render(tpl) - ... 'a string\nand another' - - To get a different indentation for the first and the second line, we can start the - prompt on the string's second line: - - >>> tpl = ''' - ... First line - ... Second line''' - >>> render(tpl) - ... 'First Line\n Second Line' - - Parameters - ---------- - template - A string that contains a template written with the Jinja2 syntax. - **values - Map from the variables in the template to their value. - - Returns - ------- - A string that contains the rendered template. - - """ - jinja_template = create_jinja_template(template) - return jinja_template.render(**values) - - -def create_jinja_template(template: str): - # Dedent, and remove extra linebreak - cleaned_template = inspect.cleandoc(template) - - # Add linebreak if there were any extra linebreaks that - # `cleandoc` would have removed - ends_with_linebreak = template.replace(" ", "").endswith("\n\n") - if ends_with_linebreak: - cleaned_template += "\n" - - # Remove extra whitespaces, except those that immediately follow a newline symbol. - # This is necessary to avoid introducing whitespaces after backslash `\` characters - # used to continue to the next line without linebreak. - cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) - - env = Environment( - trim_blocks=True, - lstrip_blocks=True, - keep_trailing_newline=True, - undefined=StrictUndefined, - ) - env.filters["name"] = get_fn_name - env.filters["description"] = get_fn_description - env.filters["source"] = get_fn_source - env.filters["signature"] = get_fn_signature - env.filters["schema"] = get_schema - env.filters["args"] = get_fn_args - - jinja_template = env.from_string(cleaned_template) - return jinja_template - - def get_fn_name(fn: Callable): """Returns the name of a callable.""" if not callable(fn): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index a0433c0e5..4cc4d8ff1 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,10 +1,17 @@ +import os +import tempfile from typing import Dict, List import pytest from pydantic import BaseModel, Field import outlines -from outlines.prompts import render +from outlines.prompts import Prompt + + +def render(content: str, **kwargs): + template = Prompt._template_from_str(content) + return template.render(kwargs) def test_render(): @@ -110,8 +117,7 @@ def test_prompt_basic(): def test_tpl(variable): """{{variable}} test""" - assert test_tpl.template == "{{variable}} test" - assert test_tpl.parameters == ["variable"] + assert list(test_tpl.signature.parameters) == ["variable"] with pytest.raises(TypeError): test_tpl(v="test") @@ -126,6 +132,8 @@ def test_tpl(variable): def test_single_quote_tpl(variable): "${variable} test" + assert list(test_single_quote_tpl.signature.parameters) == ["variable"] + p = test_tpl("test") assert p == "test test" @@ -135,8 +143,7 @@ def test_prompt_kwargs(): def test_kwarg_tpl(var, other_var="other"): """{{var}} and {{other_var}}""" - assert test_kwarg_tpl.template == "{{var}} and {{other_var}}" - assert test_kwarg_tpl.parameters == ["var", "other_var"] + assert list(test_kwarg_tpl.signature.parameters) == ["var", "other_var"] p = test_kwarg_tpl("test") assert p == "test and other" @@ -312,3 +319,87 @@ def args_prompt(fn): args_prompt(with_all) == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" ) + + +@pytest.fixture +def temp_prompt_file(): + test_dir = tempfile.mkdtemp() + + base_template_path = os.path.join(test_dir, "base_template.txt") + with open(base_template_path, "w") as f: + f.write( + """{% block content %}{% endblock %} +""" + ) + + include_file_path = os.path.join(test_dir, "include.txt") + with open(include_file_path, "w") as f: + f.write( + """{% for example in examples %} +- Q: {{ example.question }} +- A: {{ example.answer }} +{% endfor %} +""" + ) + + prompt_file_path = os.path.join(test_dir, "prompt.txt") + with open(prompt_file_path, "w") as f: + f.write( + """{% extends "base_template.txt" %} + +{% block content %} +Here is a prompt with examples: + +{% include "include.txt" %} + +Now please answer the following question: + +Q: {{ question }} +A: +{% endblock %} +""" + ) + yield prompt_file_path + + +def test_prompt_from_file(temp_prompt_file): + prompt = Prompt.from_file(temp_prompt_file) + assert prompt.signature is None + examples = [ + {"question": "What is the capital of France?", "answer": "Paris"}, + {"question": "What is 2 + 2?", "answer": "4"}, + ] + question = "What is the Earth's diameter?" + rendered = prompt(examples=examples, question=question) + expected = """Here is a prompt with examples: + +- Q: What is the capital of France? +- A: Paris +- Q: What is 2 + 2? +- A: 4 + +Now please answer the following question: + +Q: What is the Earth's diameter? +A: +""" + assert rendered.strip() == expected.strip() + + +def test_prompt_from_str(): + content = """ + Hello, {{ name }}! + """ + prompt = Prompt.from_str(content) + assert prompt.signature is None + assert prompt(name="World") == "Hello, World!" + + +def test_template_from_str_with_extra_linebreaks(): + content = """ + Hello, {{ name }}! + + + """ + template = Prompt._template_from_str(content) + assert template.render(name="World") == "Hello, World!\n"