diff --git a/outlines/__init__.py b/outlines/__init__.py index 307d2ba6f..eeba78dc7 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" + import outlines.generate import outlines.grammars import outlines.models @@ -7,7 +8,7 @@ from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function -from outlines.prompts import prompt +from outlines.prompts import Prompt, prompt __all__ = [ "clear_cache", @@ -15,6 +16,7 @@ "get_cache", "Function", "prompt", + "Prompt", "vectorize", "grammars", ] diff --git a/outlines/prompts.py b/outlines/prompts.py index a7824451a..f176bbc16 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -1,12 +1,13 @@ import functools import inspect import json +import os import re import textwrap from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Type, cast -from jinja2 import Environment, StrictUndefined +from jinja2 import Environment, FileSystemLoader, StrictUndefined, meta from pydantic import BaseModel @@ -41,6 +42,33 @@ def __call__(self, *args, **kwargs) -> str: def __str__(self): return self.template + @classmethod + def from_file(cls, filename: str): + """Create a Prompt instance from a file containing a Jinja template.""" + with open(filename) as file: + template = file.read() + + # Determine the directory of the file + file_directory = os.path.dirname(os.path.abspath(filename)) + + # Create a Jinja environment to parse the template + env = Environment( + loader=FileSystemLoader(file_directory), + trim_blocks=True, + lstrip_blocks=True, + ) + parsed_content = env.parse(template) + variables = meta.find_undeclared_variables(parsed_content) + + # Create a signature with the variables as parameters + parameters = [ + inspect.Parameter(var, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for var in variables + ] + signature = inspect.Signature(parameters) + + return cls(template, signature) + def prompt(fn: Callable) -> Prompt: """Decorate a function that contains a prompt template. diff --git a/tests/test_prompts.py b/tests/test_prompts.py index a0433c0e5..2200acc66 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -1,10 +1,12 @@ +import os +import tempfile from typing import Dict, List import pytest from pydantic import BaseModel, Field import outlines -from outlines.prompts import render +from outlines.prompts import Prompt, render def test_render(): @@ -312,3 +314,59 @@ def args_prompt(fn): args_prompt(with_all) == "args: x1, y1, z1, x2: bool, y2: str, z2: Dict[int, List[str]], x3=True, y3='Hi', z3={4: ['I', 'love', 'outlines']}, x4: bool = True, y4: str = 'Hi', z4: Dict[int, List[str]] = {4: ['I', 'love', 'outlines']}" ) + + +@pytest.fixture +def temp_prompt_file(): + # Create a temporary directory + with tempfile.TemporaryDirectory() as test_dir: + # Create a temporary prompt file + prompt_file_path = os.path.join(test_dir, "prompt.txt") + with open(prompt_file_path, "w") as f: + f.write( + """Here is a prompt with examples: + + {% for example in examples %} + - Q: {{ example.question }} + - A: {{ example.answer }} + {% endfor %} + + Now please answer the following question: + + Q: {{ question }} + A: + """ + ) + yield prompt_file_path + + +def test_prompt_from_file(temp_prompt_file): + # Create a Prompt instance from the file + template = Prompt.from_file(temp_prompt_file) + + # Define example data + examples = [ + {"question": "What is the capital of France?", "answer": "Paris"}, + {"question": "What is 2 + 2?", "answer": "4"}, + ] + question = "What is the Earth's diameter?" + + # Render the template + rendered_prompt = template(examples=examples, question=question) + + # Expected output + expected_output = """Here is a prompt with examples: + + - Q: What is the capital of France? + - A: Paris + - Q: What is 2 + 2? + - A: 4 + + Now please answer the following question: + + Q: What is the Earth's diameter? + A: + """ + + # Assert the rendered prompt matches the expected output + assert rendered_prompt.strip() == expected_output.strip()