From a62490716d6439edbac32d1dfddbd4041c0e29e3 Mon Sep 17 00:00:00 2001 From: Yvan Sraka Date: Wed, 8 Jan 2025 09:18:33 +0100 Subject: [PATCH] Add `Outline`s --- outlines/__init__.py | 3 ++ outlines/function.py | 13 ++++++ outlines/outline.py | 57 ++++++++++++++++++++++++++ tests/test_outline.py | 93 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 166 insertions(+) create mode 100644 outlines/outline.py create mode 100644 tests/test_outline.py diff --git a/outlines/__init__.py b/outlines/__init__.py index 307d2ba6f..c23f601ea 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -1,4 +1,5 @@ """Outlines is a Generative Model Programming Framework.""" + import outlines.generate import outlines.grammars import outlines.models @@ -7,6 +8,7 @@ from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache from outlines.function import Function +from outlines.outline import Outline from outlines.prompts import prompt __all__ = [ @@ -17,4 +19,5 @@ "prompt", "vectorize", "grammars", + "Outline", ] diff --git a/outlines/function.py b/outlines/function.py index 48577be8f..33b4a427d 100644 --- a/outlines/function.py +++ b/outlines/function.py @@ -1,4 +1,5 @@ import importlib.util +import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union @@ -11,6 +12,14 @@ from outlines.prompts import Prompt +# Display a deprecation warning +warnings.warn( + "The 'function' module is deprecated and will be removed in a future release.", + DeprecationWarning, + stacklevel=2, +) + + @dataclass class Function: """Represents an Outlines function. @@ -20,6 +29,10 @@ class Function: the function can be called with arguments that will be used to render the prompt template. + Note: + This class is part of the deprecated 'function' module and will be removed + in a future release. + """ prompt_template: "Prompt" diff --git a/outlines/outline.py b/outlines/outline.py new file mode 100644 index 000000000..9f9983022 --- /dev/null +++ b/outlines/outline.py @@ -0,0 +1,57 @@ +import ast +from dataclasses import dataclass + + +@dataclass +class Outline: + """ + Outline is a class that creates a callable object to generate responses + based on a given model, a prompt template (a function that returns a `str`) and an expected output type. + + Parameters + ---------- + model : object + The model to be used for generating responses. + template : function + A function that takes arguments and returns a prompt string. + output_type : type + The expected output type of the generated response. + + Examples + -------- + from outlines import models + + model = models.transformers("gpt2") + + def template(a: int) -> str: + return f"What is 2 times {a}?" + + fn = Outline(model, template, int) + + result = fn(3) + print(result) # Expected output: 6 + """ + + def __init__(self, model, template, output_type): + self.model = model + self.template = template + self.output_type = output_type + + def __call__(self, *args): + # Generate the prompt using the template function + prompt = self.template(*args) + + # Generate the response using the model + response = self.model.generate(prompt) + + # Process the response to match the expected output type + try: + parsed_response = ast.literal_eval(response.strip()) + if isinstance(parsed_response, self.output_type): + return parsed_response + else: + raise ValueError( + f"Response type {type(parsed_response)} does not match expected type {self.output_type}" + ) + except (ValueError, SyntaxError): + raise ValueError(f"Unable to parse response: {response.strip()}") diff --git a/tests/test_outline.py b/tests/test_outline.py new file mode 100644 index 000000000..e3b6383e4 --- /dev/null +++ b/tests/test_outline.py @@ -0,0 +1,93 @@ +from unittest.mock import MagicMock + +import pytest + +from outlines.outline import Outline + + +def test_outline_int_output(): + # Mock the model + model = MagicMock() + model.generate.return_value = "6" + + # Define the template function + def template(a: int) -> str: + return f"What is 2 times {a}?" + + # Create an instance of Outline + fn = Outline(model, template, int) + + # Test the callable object + result = fn(3) + assert result == 6 + + +def test_outline_str_output(): + # Mock the model + model = MagicMock() + model.generate.return_value = "'Hello, world!'" + + # Define the template function + def template(a: int) -> str: + return f"Say 'Hello, world!' {a} times" + + # Create an instance of Outline + fn = Outline(model, template, str) + + # Test the callable object + result = fn(1) + assert result == "Hello, world!" + + +def test_outline_str_input(): + # Mock the model + model = MagicMock() + model.generate.return_value = "'Hi, Mark!'" + + # Define the template function + def template(a: str) -> str: + return f"Say hi to {a}" + + # Create an instance of Outline + fn = Outline(model, template, str) + + # Test the callable object + result = fn(1) + assert result == "Hi, Mark!" + + +def test_outline_invalid_output(): + # Mock the model + model = MagicMock() + model.generate.return_value = "not a number" + + # Define the template function + def template(a: int) -> str: + return f"What is 2 times {a}?" + + # Create an instance of Outline + fn = Outline(model, template, int) + + # Test the callable object with invalid output + with pytest.raises(ValueError): + fn(3) + + +def test_outline_mismatched_output_type(): + # Mock the model + model = MagicMock() + model.generate.return_value = "'Hello, world!'" + + # Define the template function + def template(a: int) -> str: + return f"What is 2 times {a}?" + + # Create an instance of Outline + fn = Outline(model, template, int) + + # Test the callable object with mismatched output type + with pytest.raises( + ValueError, + match="Unable to parse response: 'Hello, world!'", + ): + fn(3)