Skip to content

Commit

Permalink
Handle function enums in JSON Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
g-prz authored and rlouf committed Nov 27, 2024
1 parent 7a9baad commit 2312c82
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 4 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,33 @@ print(add(**result))

A great advantage of passing functions directly to specify the structure is that the structure of the LLM will change with the function's definition. No need to change the code at several places!

You can also embed various functions into an enum to generate params:

```python
from enum import Enum
from functools import partial

import outlines


def add(a: int, b: int) -> int:
return a + b

def mul(c: float, d: float) -> float:
return c * d

class Operation(Enum):
add = partial(add)
mul = partial(mul)

model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1")
generator = outlines.generate.json(model, add)
result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.")

print(result)
# {'c': -3.14, 'd': 1.5}
```

## Prompting

Building prompts can get messy. **Outlines** makes it easier to write and manage
Expand Down
18 changes: 17 additions & 1 deletion outlines/fsm/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import re
import warnings
from enum import Enum
from typing import Callable, Optional, Tuple, Type, Union

from jsonschema.protocols import Validator
Expand Down Expand Up @@ -306,6 +307,8 @@ def to_regex(
for choice in instance["enum"]:
if type(choice) in [int, float, bool, type(None), str]:
choices.append(re.escape(json.dumps(choice)))
elif isinstance(choice, dict):
choices.append(to_regex(resolver, choice, whitespace_pattern))
else:
raise TypeError(f"Unsupported data type in enum: {type(choice)}")
return f"({'|'.join(choices)})"
Expand Down Expand Up @@ -524,7 +527,7 @@ def to_regex(
)


def get_schema_from_signature(fn: Callable) -> str:
def get_schema_from_signature(fn: Callable) -> dict:
"""Turn a function signature into a JSON schema.
Every JSON object valid to the output JSON Schema can be passed
Expand All @@ -550,3 +553,16 @@ def get_schema_from_signature(fn: Callable) -> str:
model = create_model(fn_name, **arguments)

return model.model_json_schema()


def get_schema_from_enum(myenum: type[Enum]) -> dict:
if len(myenum) == 0:
raise ValueError(
f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
)
choices = [
get_schema_from_signature(elt.value.func) if callable(elt.value) else elt.value
for elt in myenum
]
schema = {"title": myenum.__name__, "enum": choices}
return schema
12 changes: 11 additions & 1 deletion outlines/generate/json.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import json as pyjson
from enum import Enum
from functools import singledispatch
from typing import Callable, Optional, Union

from pydantic import BaseModel

from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
from outlines.fsm.json_schema import (
build_regex_from_schema,
get_schema_from_enum,
get_schema_from_signature,
)
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial
Expand Down Expand Up @@ -48,6 +53,11 @@ def json(
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
elif isinstance(schema_object, type(Enum)):
schema = pyjson.dumps(get_schema_from_enum(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: pyjson.loads(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
Expand Down
59 changes: 57 additions & 2 deletions tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import json
import re
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import List, Literal, Union

import interegular
Expand All @@ -19,6 +22,7 @@
UUID,
WHITESPACE,
build_regex_from_schema,
get_schema_from_enum,
get_schema_from_signature,
to_regex,
)
Expand Down Expand Up @@ -237,8 +241,26 @@ def test_match_number(pattern, does_match):
),
# Enum mix of types
(
{"title": "Foo", "enum": [6, 5.3, "potato", True, None]},
r'(6|5\.3|"potato"|true|null)',
{
"title": "Foo",
"enum": [
6,
5.3,
"potato",
True,
None,
{
"properties": {
"a": {"title": "A", "type": "number"},
"b": {"title": "B", "type": "number"},
},
"required": ["a", "b"],
"title": "add",
"type": "object",
},
],
},
r'(6|5\.3|"potato"|true|null|\{[ ]?"a"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?,[ ]?"b"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\})',
[
("6", True),
("5.3", True),
Expand All @@ -248,6 +270,8 @@ def test_match_number(pattern, does_match):
("523", False),
("True", False),
("None", False),
('{"a": -1.0, "b": 1.1}', True),
('{"a": "a", "b": 1.1}', False),
],
),
# integer
Expand Down Expand Up @@ -1039,3 +1063,34 @@ class Model(BaseModel):

# check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm()
interegular.parse_pattern(pattern).to_fsm()


def add(a: float, b: float) -> float:
return a + b


class MyEnum(Enum):
add = partial(add)
a = "a"
b = 2


# if you don't register your function as callable, you will get an empty enum
class EmptyEnum(Enum):
add = add


@pytest.mark.parametrize(
"enum,expectation",
[
(MyEnum, nullcontext()),
(EmptyEnum, pytest.raises(ValueError)),
],
)
def test_enum_schema(enum, expectation):
with expectation:
result = get_schema_from_enum(enum)
assert result["title"] == enum.__name__
assert len(result["enum"]) == len(enum)
for elt in result["enum"]:
assert type(elt) in [int, float, bool, type(None), str, dict]
24 changes: 24 additions & 0 deletions tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import re
from enum import Enum
from functools import partial
from typing import List, Union

import pytest
Expand Down Expand Up @@ -354,6 +355,29 @@ class User(BaseModel):
assert result.user_id in [1, 2]


def add(a: int, b: int) -> int:
return a + b


def mul(c: float, d: float) -> float:
return c * d


def test_transformers_json_function_enum(model):
prompt = "Output some JSON "

class Operation(Enum):
add = partial(add)
mul = partial(mul)

result = generate.json(model, Operation)(prompt, seed=0)
assert isinstance(result, dict)
assert len(result) == 2
for k, v in result.items():
assert k in ["a", "b", "c", "d"]
assert isinstance(v, (int, float))


def test_transformers_json_array(model):
prompt = "Output some JSON "

Expand Down

0 comments on commit 2312c82

Please sign in to comment.