Skip to content

Commit

Permalink
Add from_file class method to the Prompt object
Browse files Browse the repository at this point in the history
Fix #1345
  • Loading branch information
yvan-sraka committed Jan 6, 2025
1 parent 3cc399d commit f75319f
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 28 deletions.
4 changes: 3 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outlines is a Generative Model Programming Framework."""

import outlines.generate
import outlines.grammars
import outlines.models
Expand All @@ -7,14 +8,15 @@
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt
from outlines.prompts import Prompt, prompt

__all__ = [
"clear_cache",
"disable_cache",
"get_cache",
"Function",
"prompt",
"Prompt",
"vectorize",
"grammars",
]
82 changes: 56 additions & 26 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import functools
import inspect
import json
import os
import re
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, cast

from jinja2 import Environment, StrictUndefined
from pydantic import BaseModel
import jinja2
import pydantic


@dataclass
Expand All @@ -19,12 +21,12 @@ class Prompt:
"""

template: str
signature: inspect.Signature
template: jinja2.Template
signature: Optional[inspect.Signature]

def __post_init__(self):
self.parameters: List[str] = list(self.signature.parameters.keys())
self.jinja_environment = create_jinja_template(self.template)
if self.signature is not None:
self.parameters: List[str] = list(self.signature.parameters.keys())

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -34,12 +36,23 @@ def __call__(self, *args, **kwargs) -> str:
The rendered template as a Python ``str``.
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.jinja_environment.render(**bound_arguments.arguments)
if self.signature is not None:
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.template.render(**bound_arguments.arguments)
else:
return self.template.render(**kwargs)

@classmethod
def from_file(cls, path: Path):
"""Create a Prompt instance from a file containing a Jinja template.
def __str__(self):
return self.template
Note: This method does not allow to include and inheritance to reference files
that are outside the folder or subfolders of the file given to `from_file`.
"""
# We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is
# split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance)
return cls(template_from_file(path), None)


def prompt(fn: Callable) -> Prompt:
Expand Down Expand Up @@ -87,12 +100,12 @@ def prompt(fn: Callable) -> Prompt:
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = cast(str, docstring)
template = template_from_str(cast(str, docstring))

return Prompt(template, signature)


def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
def render(template: jinja2.Template, **values: Optional[Dict[str, Any]]) -> str:
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
This function removes extra whitespaces and linebreaks from templates to
Expand Down Expand Up @@ -174,7 +187,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
Parameters
----------
template
A string that contains a template written with the Jinja2 syntax.
A Jinja2 template.
**values
Map from the variables in the template to their value.
Expand All @@ -183,17 +196,16 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
A string that contains the rendered template.
"""
jinja_template = create_jinja_template(template)
return jinja_template.render(**values)
return template.render(**values)


def create_jinja_template(template: str):
def template_from_str(content: str) -> jinja2.Template:
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(template)
cleaned_template = inspect.cleandoc(content)

# Add linebreak if there were any extra linebreaks that
# `cleandoc` would have removed
ends_with_linebreak = template.replace(" ", "").endswith("\n\n")
ends_with_linebreak = content.replace(" ", "").endswith("\n\n")
if ends_with_linebreak:
cleaned_template += "\n"

Expand All @@ -202,11 +214,30 @@ def create_jinja_template(template: str):
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

env = Environment(
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
env.filters["source"] = get_fn_source
env.filters["signature"] = get_fn_signature
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

return env.from_string(cleaned_template)


def template_from_file(path: Path) -> jinja2.Template:
file_directory = os.path.dirname(os.path.abspath(path))
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(file_directory),
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
undefined=jinja2.StrictUndefined,
)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
Expand All @@ -215,8 +246,7 @@ def create_jinja_template(template: str):
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

jinja_template = env.from_string(cleaned_template)
return jinja_template
return env.get_template(os.path.basename(path))


def get_fn_name(fn: Callable):
Expand Down Expand Up @@ -301,10 +331,10 @@ def get_schema_dict(model: Dict):
return json.dumps(model, indent=2)


@get_schema.register(type(BaseModel))
def get_schema_pydantic(model: Type[BaseModel]):
@get_schema.register(type(pydantic.BaseModel))
def get_schema_pydantic(model: Type[pydantic.BaseModel]):
"""Return the schema of a Pydantic model."""
if not type(model) == type(BaseModel):
if not isinstance(model, type(pydantic.BaseModel)):
raise TypeError("The `schema` filter only applies to Pydantic models.")

if hasattr(model, "model_json_schema"):
Expand Down
68 changes: 67 additions & 1 deletion tests/test_prompts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import tempfile
from typing import Dict, List

import pytest
from pydantic import BaseModel, Field

import outlines
from outlines.prompts import render
from outlines.prompts import Prompt, render


def test_render():
Expand Down Expand Up @@ -312,3 +314,67 @@ def args_prompt(fn):
args_prompt(with_all)
== "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}"
)


@pytest.fixture
def temp_prompt_file():
test_dir = tempfile.mkdtemp()

base_template_path = os.path.join(test_dir, "base_template.txt")
with open(base_template_path, "w") as f:
f.write(
"""{% block content %}{% endblock %}
"""
)

include_file_path = os.path.join(test_dir, "include.txt")
with open(include_file_path, "w") as f:
f.write(
"""{% for example in examples %}
- Q: {{ example.question }}
- A: {{ example.answer }}
{% endfor %}
"""
)

prompt_file_path = os.path.join(test_dir, "prompt.txt")
with open(prompt_file_path, "w") as f:
f.write(
"""{% extends "base_template.txt" %}
{% block content %}
Here is a prompt with examples:
{% include "include.txt" %}
Now please answer the following question:
Q: {{ question }}
A:
{% endblock %}
"""
)
yield prompt_file_path


def test_prompt_from_file(temp_prompt_file):
prompt = Prompt.from_file(temp_prompt_file)
examples = [
{"question": "What is the capital of France?", "answer": "Paris"},
{"question": "What is 2 + 2?", "answer": "4"},
]
question = "What is the Earth's diameter?"
rendered = prompt(examples=examples, question=question)
expected = """Here is a prompt with examples:
- Q: What is the capital of France?
- A: Paris
- Q: What is 2 + 2?
- A: 4
Now please answer the following question:
Q: What is the Earth's diameter?
A:
"""
assert rendered.strip() == expected.strip()

0 comments on commit f75319f

Please sign in to comment.