Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I cooked, idk what to say #74

Merged
merged 5 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "mix_task",
"name": "mix (Default task)",
"request": "launch",
"projectDir": "${workspaceRoot}"
},
{
"type": "mix_task",
"name": "mix test",
"request": "launch",
"task": "test",
"taskArgs": [
"--trace"
],
"debugAutoInterpretAllModules": false,
"debugInterpretModulesPatterns": ["Instructor.*"],
"startApps": true,
"projectDir": "${workspaceRoot}",
"requireFiles": [
"test/**/test_helper.exs",
"test/**/*_test.exs"
]
}
]
}
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ end

## TODO

- [x] Top-level array support
- [x] Gemini
- [x] tools mode
- [ ] json mode
- [x] json_schema mode
- [x] Figure out a solution for OpenAI's json_schema mode not supporting arbitrary maps.
- [ ] Partial Schemaless doesn't work since fields are set to required in Ecto.


- [ ] llamacpp adapter is broken; it needs to support the OpenAI input/output API
- [ ] GBNF should enforce required properties on objects, currently they're optional.
- [ ] GBNF should limit the number of digits in number tokens -- small models can sometimes run off to infinite digits
Expand Down
123 changes: 54 additions & 69 deletions lib/instructor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ defmodule Instructor do
{{:array, response_model}, true} ->
do_streaming_array_chat_completion(response_model, params, config)

{{:array, response_model}, false} ->
params = Keyword.put(params, :stream, true)

do_streaming_array_chat_completion(response_model, params, config)
|> Enum.to_list()

{response_model, false} ->
do_chat_completion(response_model, params, config)

Expand Down Expand Up @@ -279,7 +285,6 @@ defmodule Instructor do
end

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Instructor.JSONStreamParser.parse()
|> Stream.transform(
fn -> {nil, []} end,
Expand Down Expand Up @@ -342,7 +347,6 @@ defmodule Instructor do
params = params_for_mode(mode, wrapped_model, params)

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Instructor.JSONStreamParser.parse()
|> Stream.transform(
fn -> nil end,
Expand Down Expand Up @@ -389,7 +393,6 @@ defmodule Instructor do
params = params_for_mode(mode, wrapped_model, params)

adapter(config).chat_completion(params, config)
|> Stream.map(&parse_stream_chunk_for_mode(mode, &1))
|> Jaxon.Stream.from_enumerable()
|> Jaxon.Stream.query([:root, "value", :all])
|> Stream.map(fn params ->
Expand Down Expand Up @@ -425,20 +428,14 @@ defmodule Instructor do
{%{}, response_model}
end

with {:llm, {:ok, response}} <- {:llm, adapter(config).chat_completion(params, config)},
{:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)},
changeset <- cast_all(model, params),
{:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <-
{:validation, call_validate(response_model, changeset, validation_context), response} do
with {:ok, raw_response, params} <- do_adapter_chat_completion(params, config),
{%Ecto.Changeset{valid?: true} = changeset, raw_response} <-
{cast_all(model, params), raw_response},
{%Ecto.Changeset{valid?: true} = changeset, _raw_response} <-
{call_validate(response_model, changeset, validation_context), raw_response} do
{:ok, changeset |> Ecto.Changeset.apply_changes()}
else
{:llm, {:error, error}} ->
{:error, "LLM Adapter Error: #{inspect(error)}"}

{:valid_json, {:error, error}} ->
{:error, "Invalid JSON returned from LLM: #{inspect(error)}"}

{:validation, changeset, response} ->
{%Ecto.Changeset{} = changeset, raw_response} ->
if max_retries > 0 do
errors = Instructor.ErrorFormatter.format_errors(changeset)

Expand All @@ -451,7 +448,7 @@ defmodule Instructor do
|> Keyword.put(:max_retries, max_retries - 1)
|> Keyword.update(:messages, [], fn messages ->
messages ++
echo_response(response) ++
reask_messages(raw_response, params, config) ++
[
%{
role: "system",
Expand All @@ -477,57 +474,25 @@ defmodule Instructor do
end
end

defp parse_response_for_mode(:md_json, %{"choices" => [%{"message" => %{"content" => content}}]}),
do: Jason.decode(content)

defp parse_response_for_mode(:json, %{"choices" => [%{"message" => %{"content" => content}}]}),
do: Jason.decode(content)

defp parse_response_for_mode(:tools, %{
"choices" => [
%{"message" => %{"tool_calls" => [%{"function" => %{"arguments" => args}}]}}
]
}),
do: Jason.decode(args)

defp parse_stream_chunk_for_mode(:md_json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}),
do: chunk

defp parse_stream_chunk_for_mode(:json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}),
do: chunk

defp parse_stream_chunk_for_mode(:tools, %{
"choices" => [
%{"delta" => %{"tool_calls" => [%{"function" => %{"arguments" => chunk}}]}}
]
}),
do: chunk

defp parse_stream_chunk_for_mode(_, %{"choices" => [%{"finish_reason" => "stop"}]}), do: ""

defp echo_response(%{
"choices" => [
%{
"message" =>
%{
"tool_calls" => [
%{"id" => tool_call_id, "function" => %{"name" => name, "arguments" => args}} =
function
]
} = message
}
]
}) do
[
Map.put(message, "content", function |> Jason.encode!())
|> Map.new(fn {k, v} -> {String.to_atom(k), v} end),
%{
role: "tool",
tool_call_id: tool_call_id,
name: name,
content: args
}
]
# Runs the configured adapter's `chat_completion/2` and normalizes the
# outcome: successes pass through untouched as `{:ok, raw_response, content}`,
# while failures are wrapped in a human-readable `{:error, message}` string.
defp do_adapter_chat_completion(params, config) do
  result = adapter(config).chat_completion(params, config)

  case result do
    {:ok, _raw_response, _content} = success ->
      success

    {:error, reason} ->
      {:error, "LLM Adapter Error: #{inspect(reason)}"}
  end
end

# Builds the follow-up messages sent back to the LLM on a validation retry.
# Delegates to the adapter's optional `reask_messages/3` callback when the
# adapter defines it; otherwise logs at debug level and contributes nothing.
# NOTE(review): assumes `require Logger` exists at module top — not visible
# in this chunk; confirm before relying on it.
defp reask_messages(raw_response, params, config) do
  adapter_mod = adapter(config)

  case function_exported?(adapter_mod, :reask_messages, 3) do
    true ->
      adapter_mod.reask_messages(raw_response, params, config)

    false ->
      Logger.debug("Adapter #{inspect(adapter_mod)} does not implement reask_messages/3")
      []
  end
end

defp params_for_mode(mode, response_model, params) do
Expand All @@ -552,6 +517,8 @@ defmodule Instructor do
#{json_schema}

#{additional_definitions}

Make sure to return an instance of the JSON, not the schema itself.
"""
}

Expand All @@ -568,6 +535,9 @@ defmodule Instructor do
:json ->
[sys_message | messages]

:json_schema ->
messages

:tools ->
messages
end
Expand All @@ -583,6 +553,17 @@ defmodule Instructor do
type: "json_object"
})

:json_schema ->
params
|> Keyword.put(:response_format, %{
type: "json_schema",
json_schema: %{
schema: Jason.decode!(json_schema),
name: "schema",
strict: true
}
})

:tools ->
params
|> Keyword.put(:tools, [
Expand Down Expand Up @@ -619,6 +600,10 @@ defmodule Instructor do
end
end

defp adapter(%{adapter: adapter}) when is_atom(adapter), do: adapter
defp adapter(_), do: Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
# Resolves the adapter module for this call: an `:adapter` entry in the
# per-call config wins; otherwise fall back to the application environment,
# defaulting to the built-in OpenAI adapter.
defp adapter(config) do
  case Access.get(config, :adapter) do
    nil -> Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI)
    adapter_mod -> adapter_mod
  end
end
end
10 changes: 9 additions & 1 deletion lib/instructor/adapter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,13 @@ defmodule Instructor.Adapter do
@moduledoc """
Behavior for `Instructor.Adapter`.
"""
@callback chat_completion([Keyword.t()], any()) :: any()

# Request options forwarded to the adapter (model, messages, stream, ...).
# A single keyword list — callers build it with `Keyword.put/3` — so the
# correct spec is `Keyword.t()`, not `[Keyword.t()]` (a list of keyword lists).
@type params :: Keyword.t()

# Adapter-specific configuration (API keys, base URLs, ...). Opaque to
# Instructor, so it stays `any()`.
@type config :: any()

# The unparsed response returned by the adapter; only the adapter itself
# interprets its shape.
@type raw_response :: any()

@callback chat_completion(params(), config()) ::
            Stream.t() | {:ok, raw_response(), String.t()} | {:error, String.t()}

@callback reask_messages(raw_response(), params(), config()) :: [map()]

# The caller probes `reask_messages/3` with `function_exported?/3` before
# invoking it, so adapters may legitimately omit it; marking it optional
# silences the "required callback not implemented" compiler warning.
@optional_callbacks reask_messages: 3
end
Loading
Loading