Skip to content

Commit

Permalink
Add from_file class method to the Prompt object
Browse files Browse the repository at this point in the history
Fix #1345
  • Loading branch information
yvan-sraka committed Jan 2, 2025
1 parent 6a8612b commit b2605b3
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 3 deletions.
4 changes: 3 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outlines is a Generative Model Programming Framework."""

import outlines.generate
import outlines.grammars
import outlines.models
Expand All @@ -7,14 +8,15 @@
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt
from outlines.prompts import Prompt, prompt

__all__ = [
"clear_cache",
"disable_cache",
"get_cache",
"Function",
"prompt",
"Prompt",
"vectorize",
"grammars",
]
30 changes: 29 additions & 1 deletion outlines/prompts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import functools
import inspect
import json
import os
import re
import textwrap
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Type, cast

from jinja2 import Environment, StrictUndefined
from jinja2 import Environment, FileSystemLoader, StrictUndefined, meta
from pydantic import BaseModel


Expand Down Expand Up @@ -41,6 +42,33 @@ def __call__(self, *args, **kwargs) -> str:
def __str__(self):
return self.template

@classmethod
def from_file(cls, filename: str):
"""Create a Prompt instance from a file containing a Jinja template."""
with open(filename) as file:
template = file.read()

# Determine the directory of the file
file_directory = os.path.dirname(os.path.abspath(filename))

# Create a Jinja environment to parse the template
env = Environment(
loader=FileSystemLoader(file_directory),
trim_blocks=True,
lstrip_blocks=True,
)
parsed_content = env.parse(template)
variables = meta.find_undeclared_variables(parsed_content)

# Create a signature with the variables as parameters
parameters = [
inspect.Parameter(var, inspect.Parameter.POSITIONAL_OR_KEYWORD)
for var in variables
]
signature = inspect.Signature(parameters)

return cls(template, signature)


def prompt(fn: Callable) -> Prompt:
"""Decorate a function that contains a prompt template.
Expand Down
60 changes: 59 additions & 1 deletion tests/test_prompts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import tempfile
from typing import Dict, List

import pytest
from pydantic import BaseModel, Field

import outlines
from outlines.prompts import render
from outlines.prompts import Prompt, render


def test_render():
Expand Down Expand Up @@ -312,3 +314,59 @@ def args_prompt(fn):
args_prompt(with_all)
== "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}"
)


@pytest.fixture
def temp_prompt_file():
# Create a temporary directory
with tempfile.TemporaryDirectory() as test_dir:
# Create a temporary prompt file
prompt_file_path = os.path.join(test_dir, "prompt.txt")
with open(prompt_file_path, "w") as f:
f.write(
"""Here is a prompt with examples:
{% for example in examples %}
- Q: {{ example.question }}
- A: {{ example.answer }}
{% endfor %}
Now please answer the following question:
Q: {{ question }}
A:
"""
)
yield prompt_file_path


def test_prompt_from_file(temp_prompt_file):
# Create a Prompt instance from the file
template = Prompt.from_file(temp_prompt_file)

# Define example data
examples = [
{"question": "What is the capital of France?", "answer": "Paris"},
{"question": "What is 2 + 2?", "answer": "4"},
]
question = "What is the Earth's diameter?"

# Render the template
rendered_prompt = template(examples=examples, question=question)

# Expected output
expected_output = """Here is a prompt with examples:
- Q: What is the capital of France?
- A: Paris
- Q: What is 2 + 2?
- A: 4
Now please answer the following question:
Q: What is the Earth's diameter?
A:
"""

# Assert the rendered prompt matches the expected output
assert rendered_prompt.strip() == expected_output.strip()

0 comments on commit b2605b3

Please sign in to comment.