Skip to content

Commit

Permalink
Add from_file class method to the Prompt object
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Jan 8, 2025
1 parent 3cc399d commit c03a2da
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 250 deletions.
4 changes: 3 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outlines is a Generative Model Programming Framework."""

import outlines.generate
import outlines.grammars
import outlines.models
Expand All @@ -7,14 +8,15 @@
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt
from outlines.prompts import Prompt, prompt

__all__ = [
"clear_cache",
"disable_cache",
"get_cache",
"Function",
"prompt",
"Prompt",
"vectorize",
"grammars",
]
242 changes: 97 additions & 145 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import functools
import inspect
import json
import os
import re
import textwrap
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, cast
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, cast

from jinja2 import Environment, StrictUndefined
from pydantic import BaseModel
import jinja2
import pydantic


@dataclass
Expand All @@ -19,12 +21,8 @@ class Prompt:
"""

template: str
signature: inspect.Signature

def __post_init__(self):
self.parameters: List[str] = list(self.signature.parameters.keys())
self.jinja_environment = create_jinja_template(self.template)
template: jinja2.Template
signature: Optional[inspect.Signature]

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -34,12 +32,93 @@ def __call__(self, *args, **kwargs) -> str:
The rendered template as a Python ``str``.
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.jinja_environment.render(**bound_arguments.arguments)
if self.signature is not None:
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return self.template.render(**bound_arguments.arguments)
else:
return self.template.render(**kwargs)

def __str__(self):
return self.template
@classmethod
def from_str(cls, content: str):
"""
Create an instance of the class from a string.
Parameters
----------
content : str
The string content to be converted into a template.
Returns
-------
An instance of the class with the provided content as a template.
"""
return cls(cls._template_from_str(content), None)

@classmethod
def from_file(cls, path: Path):
"""
Create a Prompt instance from a file containing a Jinja template.
Note: This method does not allow to include and inheritance to reference files
that are outside the folder or subfolders of the file given to `from_file`.
Parameters
----------
path : Path
The path to the file containing the Jinja template.
Returns
-------
Prompt
An instance of the Prompt class with the template loaded from the file.
"""
# We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is
# split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance)
return cls(cls._template_from_file(path), None)

@classmethod
def _template_from_str(_, content: str) -> jinja2.Template:
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(content)

# Add linebreak if there were any extra linebreaks that
# `cleandoc` would have removed
ends_with_linebreak = content.replace(" ", "").endswith("\n\n")
if ends_with_linebreak:
cleaned_template += "\n"

# Remove extra whitespaces, except those that immediately follow a newline symbol.
# This is necessary to avoid introducing whitespaces after backslash `\` characters
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
env.filters["source"] = get_fn_source
env.filters["signature"] = get_fn_signature
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

return env.from_string(cleaned_template)

@classmethod
def _template_from_file(_, path: Path) -> jinja2.Template:
file_directory = os.path.dirname(os.path.abspath(path))
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(file_directory),
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)
return env.get_template(os.path.basename(path))


def prompt(fn: Callable) -> Prompt:
Expand Down Expand Up @@ -87,138 +166,11 @@ def prompt(fn: Callable) -> Prompt:
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = cast(str, docstring)
template = Prompt._template_from_str(cast(str, docstring))

return Prompt(template, signature)


def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
r"""Parse a Jinaj2 template and translate it into an Outlines graph.
This function removes extra whitespaces and linebreaks from templates to
allow users to enter prompts more naturally than if they used Python's
constructs directly. See the examples for a detailed explanation.
Examples
--------
Outlines follow Jinja2's syntax
>>> import outlines
>>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis")
I like tomatoes and tennis
If the first line of the template is empty, `render` removes it
>>> from outlines import render
>>>
>>> tpl = '''
... A new string'''
>>> tpl
... '\nA new string'
>>> render(tpl)
... 'a new string'
Similarly, `render` ignores linebreaks introduced by placing the closing quotes
underneath the text:
>>> tpl = '''
... A new string
... '''
>>> tpl
... '\nA new string\n'
>>> render(tpl)
... 'A new string'
If you want to insert a linebreak at the end of the rendered template, you will
need to leave an empty line at the end of the template:
>>> tpl = '''
... A new string
...
... '''
>>> tpl
... '\nA new string\n\n'
>>> render(tpl)
... 'A new string\n'
`render` removes the identation in docstrings. This is particularly important
when using prompt functions
>>> tpl = '''
... a string
... and another string'''
>>> tpl
... '\n a string\n and another string'
>>> render(tpl)
... 'a string\nand another string'
The indentation of the first line is assumed to be the same as the second line's
>>> tpl = '''a string
... and another'''
>>> tpl
... 'a string\n and another'
>>> render(tpl)
... 'a string\nand another'
To get a different indentation for the first and the second line, we can start the
prompt on the string's second line:
>>> tpl = '''
... First line
... Second line'''
>>> render(tpl)
... 'First Line\n Second Line'
Parameters
----------
template
A string that contains a template written with the Jinja2 syntax.
**values
Map from the variables in the template to their value.
Returns
-------
A string that contains the rendered template.
"""
jinja_template = create_jinja_template(template)
return jinja_template.render(**values)


def create_jinja_template(template: str):
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(template)

# Add linebreak if there were any extra linebreaks that
# `cleandoc` would have removed
ends_with_linebreak = template.replace(" ", "").endswith("\n\n")
if ends_with_linebreak:
cleaned_template += "\n"

# Remove extra whitespaces, except those that immediately follow a newline symbol.
# This is necessary to avoid introducing whitespaces after backslash `\` characters
# used to continue to the next line without linebreak.
cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template)

env = Environment(
trim_blocks=True,
lstrip_blocks=True,
keep_trailing_newline=True,
undefined=StrictUndefined,
)
env.filters["name"] = get_fn_name
env.filters["description"] = get_fn_description
env.filters["source"] = get_fn_source
env.filters["signature"] = get_fn_signature
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

jinja_template = env.from_string(cleaned_template)
return jinja_template


def get_fn_name(fn: Callable):
"""Returns the name of a callable."""
if not callable(fn):
Expand Down Expand Up @@ -301,10 +253,10 @@ def get_schema_dict(model: Dict):
return json.dumps(model, indent=2)


@get_schema.register(type(BaseModel))
def get_schema_pydantic(model: Type[BaseModel]):
@get_schema.register(type(pydantic.BaseModel))
def get_schema_pydantic(model: Type[pydantic.BaseModel]):
"""Return the schema of a Pydantic model."""
if not type(model) == type(BaseModel):
if not isinstance(model, type(pydantic.BaseModel)):
raise TypeError("The `schema` filter only applies to Pydantic models.")

if hasattr(model, "model_json_schema"):
Expand Down
Loading

0 comments on commit c03a2da

Please sign in to comment.