diff --git a/outlines/__init__.py b/outlines/__init__.py index 307d2ba6f..c23f601ea 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" + import outlines.generate import outlines.grammars import outlines.models @@ -7,6 +8,7 @@ from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function +from outlines.outline import Outline from outlines.prompts import prompt __all__ = [ @@ -17,4 +19,5 @@ "prompt", "vectorize", "grammars", + "Outline", ] diff --git a/outlines/outline.py b/outlines/outline.py new file mode 100644 index 000000000..85057418e --- /dev/null +++ b/outlines/outline.py @@ -0,0 +1,52 @@ +import ast +from dataclasses import dataclass + + +@dataclass +class Outline: + """ + Outline is a class that creates a callable object to generate responses + based on a given model and a prompt template. + + Args: + model: The model to be used for generating responses. + template (function): A function that takes arguments and returns a prompt string. + output_type: The expected output type of the generated response. + + Example: + from outlines import models + + model = models.transformers("gpt2") + + def template(a: int) -> str: + return f"What is 2 times {a}?" + + fn = Outline(model, template, int) + + result = fn(3) + print(result) # Expected output: 6 + """ + + def __init__(self, model, template, output_type): + self.model = model + self.template = template + self.output_type = output_type + + def __call__(self, *args): + # Generate the prompt using the template function + prompt = self.template(*args) + + # Generate the response using the model + response = self.model.generate(prompt) + + # Process the response to match the expected output type + try: + parsed_response = ast.literal_eval(response.strip()) + if isinstance(parsed_response, self.output_type): + return parsed_response + else: + raise ValueError( + f"Response type {type(parsed_response)} does not match expected type {self.output_type}" + ) + except (ValueError, SyntaxError): + raise ValueError(f"Unable to parse response: {response.strip()}") diff --git a/tests/test_outline.py b/tests/test_outline.py new file mode 100644 index 000000000..eb8983810 --- /dev/null +++ b/tests/test_outline.py @@ -0,0 +1,56 @@ +from unittest.mock import MagicMock + +import pytest + +from outlines.outline import Outline + + +def test_outline_int_output(): + # Mock the model + model = MagicMock() + model.generate.return_value = "6" + + # Define the template function + def template(a: int) -> str: + return f"What is 2 times {a}?" + + # Create an instance of Outline + fn = Outline(model, template, int) + + # Test the callable object + result = fn(3) + assert result == 6 + + +def test_outline_str_output(): + # Mock the model + model = MagicMock() + model.generate.return_value = "'Hello, world!'" + + # Define the template function + def template(a: int) -> str: + return f"Say hello {a} times" + + # Create an instance of Outline + fn = Outline(model, template, str) + + # Test the callable object + result = fn(1) + assert result == "Hello, world!" + + +def test_outline_invalid_output(): + # Mock the model + model = MagicMock() + model.generate.return_value = "not a number" + + # Define the template function + def template(a: int) -> str: + return f"What is 2 times {a}?" + + # Create an instance of Outline + fn = Outline(model, template, int) + + # Test the callable object with invalid output + with pytest.raises(ValueError): + fn(3)