I cooked, idk what to say (#74)
* wip

* Add Anthropic support and refactor tests.

* Add Gemini, make streaming more resilient

* Add an OCR example with validations.

* Please add the streaming UI example.
thmsmlr authored Sep 24, 2024
1 parent 9a18473 commit 0db54d0
Showing 19 changed files with 1,404 additions and 370 deletions.
31 changes: 31 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,31 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "mix_task",
"name": "mix (Default task)",
"request": "launch",
"projectDir": "${workspaceRoot}"
},
{
"type": "mix_task",
"name": "mix test",
"request": "launch",
"task": "test",
"taskArgs": [
"--trace"
],
"debugAutoInterpretAllModules": false,
"debugInterpretModulesPatterns": ["Instructor.*"],
"startApps": true,
"projectDir": "${workspaceRoot}",
"requireFiles": [
"test/**/test_helper.exs",
"test/**/*_test.exs"
]
}
]
}
9 changes: 9 additions & 0 deletions README.md
@@ -123,6 +123,15 @@ end

## TODO

- [x] Top-level array support
- [x] Gemini
- [x] tools mode
- [ ] json mode
- [x] json_schema mode
- [x] Figure out a solution for OpenAI's json_schema mode not supporting arbitrary maps.
- [ ] Partial Schemaless doesn't work since fields are set to required in Ecto.


- [ ] llamacpp adapter broken, needs to support openai input/output API
- [ ] GBNF should enforce required properties on objects, currently they're optional.
- [ ] GBNF limit the number of digits in number tokens -- small models can sometimes run off to infinite digits
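
For context, a hedged usage sketch of the top-level array support and the new `json_schema` mode listed above. The model name, prompt, and schemaless field map are illustrative assumptions, not taken from this commit:

```elixir
# Illustrative only: stream a top-level array of schemaless results using the
# new json_schema mode. Field names and the model string are assumptions.
Instructor.chat_completion(
  model: "gpt-4o-mini",
  mode: :json_schema,
  stream: true,
  response_model: {:array, %{name: :string, birth_year: :integer}},
  messages: [
    %{role: "user", content: "List three well-known scientists."}
  ]
)
|> Enum.each(&IO.inspect/1)
```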
123 changes: 54 additions & 69 deletions lib/instructor.ex
@@ -138,6 +138,12 @@ defmodule Instructor do
{{:array, response_model}, true} ->
do_streaming_array_chat_completion(response_model, params, config)

{{:array, response_model}, false} ->
params = Keyword.put(params, :stream, true)

do_streaming_array_chat_completion(response_model, params, config)
|> Enum.to_list()

{response_model, false} ->
do_chat_completion(response_model, params, config)

@@ -279,7 +285,6 @@
end

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Instructor.JSONStreamParser.parse()
|> Stream.transform(
fn -> {nil, []} end,
@@ -342,7 +347,6 @@
params = params_for_mode(mode, wrapped_model, params)

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Instructor.JSONStreamParser.parse()
|> Stream.transform(
fn -> nil end,
@@ -389,7 +393,6 @@
params = params_for_mode(mode, wrapped_model, params)

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Jaxon.Stream.from_enumerable()
|> Jaxon.Stream.query([:root, "value", :all])
|> Stream.map(fn params ->
@@ -425,20 +428,14 @@
{%{}, response_model}
end

with {:llm, {:ok, response}} <- {:llm, adapter(config).chat_completion(params, config)},
{:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)},
changeset <- cast_all(model, params),
{:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <-
{:validation, call_validate(response_model, changeset, validation_context), response} do
with {:ok, raw_response, params} <- do_adapter_chat_completion(params, config),
{%Ecto.Changeset{valid?: true} = changeset, raw_response} <-
{cast_all(model, params), raw_response},
{%Ecto.Changeset{valid?: true} = changeset, _raw_response} <-
{call_validate(response_model, changeset, validation_context), raw_response} do
{:ok, changeset |> Ecto.Changeset.apply_changes()}
else
{:llm, {:error, error}} ->
{:error, "LLM Adapter Error: #{inspect(error)}"}

{:valid_json, {:error, error}} ->
{:error, "Invalid JSON returned from LLM: #{inspect(error)}"}

{:validation, changeset, response} ->
{%Ecto.Changeset{} = changeset, raw_response} ->
if max_retries > 0 do
errors = Instructor.ErrorFormatter.format_errors(changeset)

@@ -451,7 +448,7 @@
|> Keyword.put(:max_retries, max_retries - 1)
|> Keyword.update(:messages, [], fn messages ->
messages ++
echo_response(response) ++
reask_messages(raw_response, params, config) ++
[
%{
role: "system",
@@ -477,57 +474,25 @@
end
end

defp parse_response_for_mode(:md_json, %{"choices" => [%{"message" => %{"content" => content}}]}),
do: Jason.decode(content)

defp parse_response_for_mode(:json, %{"choices" => [%{"message" => %{"content" => content}}]}),
do: Jason.decode(content)

defp parse_response_for_mode(:tools, %{
"choices" => [
%{"message" => %{"tool_calls" => [%{"function" => %{"arguments" => args}}]}}
]
}),
do: Jason.decode(args)

defp parse_stream_chunk_for_mode(:md_json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}),
do: chunk

defp parse_stream_chunk_for_mode(:json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}),
do: chunk

defp parse_stream_chunk_for_mode(:tools, %{
"choices" => [
%{"delta" => %{"tool_calls" => [%{"function" => %{"arguments" => chunk}}]}}
]
}),
do: chunk

defp parse_stream_chunk_for_mode(_, %{"choices" => [%{"finish_reason" => "stop"}]}), do: ""

defp echo_response(%{
"choices" => [
%{
"message" =>
%{
"tool_calls" => [
%{"id" => tool_call_id, "function" => %{"name" => name, "arguments" => args}} =
function
]
} = message
}
]
}) do
[
Map.put(message, "content", function |> Jason.encode!())
|> Map.new(fn {k, v} -> {String.to_atom(k), v} end),
%{
role: "tool",
tool_call_id: tool_call_id,
name: name,
content: args
}
]
defp do_adapter_chat_completion(params, config) do
case adapter(config).chat_completion(params, config) do
{:ok, response, content} ->
{:ok, response, content}

{:error, reason} ->
{:error, "LLM Adapter Error: #{inspect(reason)}"}
end
end

defp reask_messages(raw_response, params, config) do
adp = adapter(config)

if function_exported?(adp, :reask_messages, 3) do
adp.reask_messages(raw_response, params, config)
else
Logger.debug("Adapter #{inspect(adp)} does not implement reask_messages/3")
[]
end
end

defp params_for_mode(mode, response_model, params) do
@@ -552,6 +517,8 @@
#{json_schema}
#{additional_definitions}
Make sure to return an instance of the JSON, not the schema itself.
"""
}

@@ -568,6 +535,9 @@
:json ->
[sys_message | messages]

:json_schema ->
messages

:tools ->
messages
end
@@ -583,6 +553,17 @@
type: "json_object"
})

:json_schema ->
params
|> Keyword.put(:response_format, %{
type: "json_schema",
json_schema: %{
schema: Jason.decode!(json_schema),
name: "schema",
strict: true
}
})

:tools ->
params
|> Keyword.put(:tools, [
@@ -619,6 +600,10 @@
end
end

defp adapter(%{adapter: adapter}) when is_atom(adapter), do: adapter
defp adapter(_), do: Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
defp adapter(config) do
case config[:adapter] do
nil -> Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
adapter -> adapter
end
end
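
To illustrate the refactored `adapter/1` lookup above: the adapter still falls back to the application environment (`config :instructor, adapter: Instructor.Adapters.OpenAI`), but any `:adapter` key in the per-call config wins. The sketch below assumes the Anthropic adapter module name from the commit message plus an illustrative model string and schemaless response model; it is not code from this commit:

```elixir
# Hypothetical per-call override: the :adapter entry in the config argument is
# read via config[:adapter] before falling back to Application.get_env/3.
{:ok, result} =
  Instructor.chat_completion(
    [
      model: "claude-3-5-sonnet-20240620",
      response_model: %{name: :string, age: :integer},
      messages: [%{role: "user", content: "Hi, I'm Thomas and I'm 31 years old."}]
    ],
    adapter: Instructor.Adapters.Anthropic
  )
```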
end
10 changes: 9 additions & 1 deletion lib/instructor/adapter.ex
@@ -2,5 +2,13 @@ defmodule Instructor.Adapter do
@moduledoc """
Behavior for `Instructor.Adapter`.
"""
@callback chat_completion([Keyword.t()], any()) :: any()

@type params :: [Keyword.t()]
@type config :: any()
@type raw_response :: any()

@callback chat_completion(params(), config()) ::
Stream.t() | {:ok, raw_response(), String.t()} | {:error, String.t()}

@callback reask_messages(raw_response(), params(), config()) :: [map()]
end
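
As a rough sketch of what the expanded behaviour above implies for adapter authors. The module name, canned response, and reask strategy are hypothetical, not part of this commit; `reask_messages/3` is effectively optional since Instructor checks `function_exported?/3` before calling it:

```elixir
# Hypothetical adapter: returns a canned response in the non-streaming shape
# {:ok, raw_response, content} and implements reask_messages/3, which
# Instructor uses to rebuild conversation context on a validation retry.
defmodule MyApp.CannedAdapter do
  @behaviour Instructor.Adapter

  @impl true
  def chat_completion(_params, _config) do
    content = ~s({"name": "Thomas", "age": 31})
    raw_response = %{"choices" => [%{"message" => %{"content" => content}}]}
    {:ok, raw_response, content}
  end

  @impl true
  def reask_messages(raw_response, _params, _config) do
    # Echo the previous assistant answer back so the model can correct it.
    content = get_in(raw_response, ["choices", Access.at(0), "message", "content"])
    [%{role: "assistant", content: content}]
  end
end
```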
