diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..79f5763 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,31 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "mix_task", + "name": "mix (Default task)", + "request": "launch", + "projectDir": "${workspaceRoot}" + }, + { + "type": "mix_task", + "name": "mix test", + "request": "launch", + "task": "test", + "taskArgs": [ + "--trace" + ], + "debugAutoInterpretAllModules": false, + "debugInterpretModulesPatterns": ["Instructor.*"], + "startApps": true, + "projectDir": "${workspaceRoot}", + "requireFiles": [ + "test/**/test_helper.exs", + "test/**/*_test.exs" + ] + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 5b99295..6e38047 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,15 @@ end ## TODO +- [x] Top-level array support +- [x] Gemini + - [x] tools mode + - [ ] json mode + - [x] json_schema mode +- [x] Figure out a solution for OpenAI's json_schema mode not supporting arbitrary maps. +- [ ] Partial Schemaless doesn't work since fields are set to required in Ecto. + + - [ ] llamacpp adapter broken, needs to support openai input/output API - [ ] GBNF should enforce required properties on objects, currently they're optional. - [ ] GBNF limit the number of digits in number tokens -- small models can sometimes run off to infinit digits diff --git a/lib/instructor.ex b/lib/instructor.ex index cc9579b..50e99bd 100644 --- a/lib/instructor.ex +++ b/lib/instructor.ex @@ -138,6 +138,12 @@ defmodule Instructor do {{:array, response_model}, true} -> do_streaming_array_chat_completion(response_model, params, config) + {{:array, response_model}, false} -> + params = Keyword.put(params, :stream, true) + + do_streaming_array_chat_completion(response_model, params, config) + |> Enum.to_list() + {response_model, false} -> do_chat_completion(response_model, params, config) @@ -279,7 +285,6 @@ defmodule Instructor do end adapter(config).chat_completion(params, config) - |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) |> Instructor.JSONStreamParser.parse() |> Stream.transform( fn -> {nil, []} end, @@ -342,7 +347,6 @@ defmodule Instructor do params = params_for_mode(mode, wrapped_model, params) adapter(config).chat_completion(params, config) - |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) |> Instructor.JSONStreamParser.parse() |> Stream.transform( fn -> nil end, @@ -389,7 +393,6 @@ defmodule Instructor do params = params_for_mode(mode, wrapped_model, params) adapter(config).chat_completion(params, config) - |> Stream.map(&parse_stream_chunk_for_mode(mode, &1)) |> Jaxon.Stream.from_enumerable() |> Jaxon.Stream.query([:root, "value", :all]) |> Stream.map(fn params -> @@ -425,20 +428,14 @@ defmodule Instructor do {%{}, response_model} end - with {:llm, {:ok, response}} <- {:llm, adapter(config).chat_completion(params, config)}, - {:valid_json, {:ok, params}} <- {:valid_json, parse_response_for_mode(mode, response)}, - changeset <- cast_all(model, params), - {:validation, %Ecto.Changeset{valid?: true} = changeset, _response} <- - {:validation, call_validate(response_model, changeset, validation_context), response} do + with {:ok, raw_response, params} <- do_adapter_chat_completion(params, config), + {%Ecto.Changeset{valid?: true} = changeset, raw_response} <- + 
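+           # raw_response is threaded through each step of this `with` so that a
+           # failed match falls to `else` with the response still in hand for
+           # building reask messages.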
{cast_all(model, params), raw_response}, + {%Ecto.Changeset{valid?: true} = changeset, _raw_response} <- + {call_validate(response_model, changeset, validation_context), raw_response} do {:ok, changeset |> Ecto.Changeset.apply_changes()} else - {:llm, {:error, error}} -> - {:error, "LLM Adapter Error: #{inspect(error)}"} - - {:valid_json, {:error, error}} -> - {:error, "Invalid JSON returned from LLM: #{inspect(error)}"} - - {:validation, changeset, response} -> + {%Ecto.Changeset{} = changeset, raw_response} -> if max_retries > 0 do errors = Instructor.ErrorFormatter.format_errors(changeset) @@ -451,7 +448,7 @@ defmodule Instructor do |> Keyword.put(:max_retries, max_retries - 1) |> Keyword.update(:messages, [], fn messages -> messages ++ - echo_response(response) ++ + reask_messages(raw_response, params, config) ++ [ %{ role: "system", @@ -477,57 +474,25 @@ defmodule Instructor do end end - defp parse_response_for_mode(:md_json, %{"choices" => [%{"message" => %{"content" => content}}]}), - do: Jason.decode(content) - - defp parse_response_for_mode(:json, %{"choices" => [%{"message" => %{"content" => content}}]}), - do: Jason.decode(content) - - defp parse_response_for_mode(:tools, %{ - "choices" => [ - %{"message" => %{"tool_calls" => [%{"function" => %{"arguments" => args}}]}} - ] - }), - do: Jason.decode(args) - - defp parse_stream_chunk_for_mode(:md_json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}), - do: chunk - - defp parse_stream_chunk_for_mode(:json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}), - do: chunk - - defp parse_stream_chunk_for_mode(:tools, %{ - "choices" => [ - %{"delta" => %{"tool_calls" => [%{"function" => %{"arguments" => chunk}}]}} - ] - }), - do: chunk - - defp parse_stream_chunk_for_mode(_, %{"choices" => [%{"finish_reason" => "stop"}]}), do: "" - - defp echo_response(%{ - "choices" => [ - %{ - "message" => - %{ - "tool_calls" => [ - %{"id" => tool_call_id, "function" => %{"name" => name, "arguments" => args}} = - function - ] - } = message - } - ] - }) do - [ - Map.put(message, "content", function |> Jason.encode!()) - |> Map.new(fn {k, v} -> {String.to_atom(k), v} end), - %{ - role: "tool", - tool_call_id: tool_call_id, - name: name, - content: args - } - ] + defp do_adapter_chat_completion(params, config) do + case adapter(config).chat_completion(params, config) do + {:ok, response, content} -> + {:ok, response, content} + + {:error, reason} -> + {:error, "LLM Adapter Error: #{inspect(reason)}"} + end + end + + defp reask_messages(raw_response, params, config) do + adp = adapter(config) + + if function_exported?(adp, :reask_messages, 3) do + adp.reask_messages(raw_response, params, config) + else + Logger.debug("Adapter #{inspect(adp)} does not implement reask_messages/3") + [] + end end defp params_for_mode(mode, response_model, params) do @@ -552,6 +517,8 @@ defmodule Instructor do #{json_schema} #{additional_definitions} + + Make sure to return an instance of the JSON, not the schema itself. 
""" } @@ -568,6 +535,9 @@ defmodule Instructor do :json -> [sys_message | messages] + :json_schema -> + messages + :tools -> messages end @@ -583,6 +553,17 @@ defmodule Instructor do type: "json_object" }) + :json_schema -> + params + |> Keyword.put(:response_format, %{ + type: "json_schema", + json_schema: %{ + schema: Jason.decode!(json_schema), + name: "schema", + strict: true + } + }) + :tools -> params |> Keyword.put(:tools, [ @@ -619,6 +600,10 @@ defmodule Instructor do end end - defp adapter(%{adapter: adapter}) when is_atom(adapter), do: adapter - defp adapter(_), do: Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI) + defp adapter(config) do + case config[:adapter] do + nil -> Application.get_env(:instructor, :adapter, Instructor.Adapters.OpenAI) + adapter -> adapter + end + end end diff --git a/lib/instructor/adapter.ex b/lib/instructor/adapter.ex index 5a00249..76a6cf3 100644 --- a/lib/instructor/adapter.ex +++ b/lib/instructor/adapter.ex @@ -2,5 +2,13 @@ defmodule Instructor.Adapter do @moduledoc """ Behavior for `Instructor.Adapter`. """ - @callback chat_completion([Keyword.t()], any()) :: any() + + @type params :: [Keyword.t()] + @type config :: any() + @type raw_response :: any() + + @callback chat_completion(params(), config()) :: + Stream.t() | {:ok, raw_response(), String.t()} | {:error, String.t()} + + @callback reask_messages(raw_response(), params(), config()) :: [map()] end diff --git a/lib/instructor/adapters/anthropic.ex b/lib/instructor/adapters/anthropic.ex new file mode 100644 index 0000000..e21d809 --- /dev/null +++ b/lib/instructor/adapters/anthropic.ex @@ -0,0 +1,144 @@ +defmodule Instructor.Adapters.Anthropic do + @moduledoc """ + Anthropic adapter for Instructor. + """ + @behaviour Instructor.Adapter + + alias Instructor.SSEStreamParser + + @impl true + def chat_completion(params, user_config \\ nil) do + config = config(user_config) + + # Peel off instructor only parameters + {_, params} = Keyword.pop(params, :response_model) + {_, params} = Keyword.pop(params, :validation_context) + {_, params} = Keyword.pop(params, :max_retries) + {mode, params} = Keyword.pop(params, :mode) + stream = Keyword.get(params, :stream, false) + params = Enum.into(params, %{}) + + {system_prompt, messages} = params.messages |> Enum.split_with(&(&1[:role] == "system")) + system_prompt = system_prompt |> Enum.map(& &1[:content]) |> Enum.join("\n") + + [tool] = params.tools + tool = tool.function + + tool = + tool + |> Map.put("input_schema", tool["parameters"]) + |> Map.delete("parameters") + + params = + params + |> Map.put(:messages, messages) + |> Map.put(:tools, [tool]) + |> Map.put(:tool_choice, %{"type" => "tool", "name" => tool["name"]}) + |> Map.put(:system, system_prompt) + + if stream do + do_streaming_chat_completion(mode, params, config) + else + do_chat_completion(mode, params, config) + end + end + + defp do_chat_completion(mode, params, config) do + options = get_anthropic_http_opts(config) |> Keyword.merge(json: params) + + case Req.post(url(config), options) do + {:ok, %Req.Response{status: 200, body: body} = response} -> + {:ok, response, parse_response_for_mode(mode, body)} + + {:ok, %Req.Response{status: status, body: body}} -> + {:error, "Unexpected HTTP response code: #{status}\n#{inspect(body)}"} + + e -> + e + end + end + + defp do_streaming_chat_completion(mode, params, config) do + pid = self() + options = get_anthropic_http_opts(config) |> Keyword.merge(json: params) + + Stream.resource( + fn -> + Task.async(fn -> + options = + 
Keyword.merge(options, + into: fn {:data, data}, {req, resp} -> + send(pid, data) + {:cont, {req, resp}} + end + ) + + Req.post!(url(config), options) + send(pid, :done) + end) + end, + fn task -> + receive do + :done -> + {:halt, task} + + data -> + {[data], task} + after + 15_000 -> + {:halt, task} + end + end, + fn task -> Task.await(task) end + ) + |> SSEStreamParser.parse() + |> Stream.map(fn chunk -> parse_stream_chunk_for_mode(mode, chunk) end) + end + + defp get_anthropic_http_opts(config) do + Keyword.merge(http_options(config), + headers: [{"x-api-key", api_key(config)}, {"anthropic-version", " 2023-06-01"}] + ) + end + + defp parse_stream_chunk_for_mode(:tools, %{"type" => event}) + when event in [ + "message_start", + "ping", + "content_block_start", + "content_block_stop", + "message_stop", + "message_delta", + "completion" + ] do + "" + end + + defp parse_stream_chunk_for_mode(:tools, %{ + "type" => "content_block_delta", + "delta" => %{"partial_json" => delta, "type" => "input_json_delta"} + }) do + delta + end + + defp parse_response_for_mode(:tools, %{"content" => [%{"input" => args, "type" => "tool_use"}]}) do + args + end + + defp url(config), do: api_url(config) <> "/v1/messages" + + defp api_url(config), do: Keyword.fetch!(config, :api_url) + defp api_key(config), do: Keyword.fetch!(config, :api_key) + defp http_options(config), do: Keyword.fetch!(config, :http_options) + + defp config(nil), do: config(Application.get_env(:instructor, :anthropic, [])) + + defp config(base_config) do + default_config = [ + api_url: "https://api.anthropic.com/", + http_options: [receive_timeout: 60_000] + ] + + Keyword.merge(default_config, base_config) + end +end diff --git a/lib/instructor/adapters/gemini.ex b/lib/instructor/adapters/gemini.ex new file mode 100644 index 0000000..711419b --- /dev/null +++ b/lib/instructor/adapters/gemini.ex @@ -0,0 +1,310 @@ +defmodule Instructor.Adapters.Gemini do + @moduledoc """ + Adapter for Google Gemini. + + """ + @behaviour Instructor.Adapter + @supported_modes [:tools, :json_schema] + + alias Instructor.JSONSchema + + def gemini_req, + do: + Req.new() + |> Req.Request.register_options([:rpc_function]) + |> Req.Request.append_request_steps( + append_rpc_function: fn request -> + rpc_function = request.options[:rpc_function] + + if rpc_function do + update_in(request.url.path, fn + nil -> nil + path -> path <> inspect(rpc_function) + end) + else + request + end + end + ) + + @doc """ + Run a completion against Google's Gemini API + Accepts OpenAI API arguments and converts to Gemini Args to perform the completion. + Defaults to JSON mode within the Gemini API + """ + @impl true + def chat_completion(params, user_config \\ nil) do + config = config(user_config) + + # Peel off instructor only parameters + {_, params} = Keyword.pop(params, :response_model) + {_, params} = Keyword.pop(params, :validation_context) + {_, params} = Keyword.pop(params, :max_retries) + {mode, params} = Keyword.pop(params, :mode) + stream = Keyword.get(params, :stream, false) + params = Enum.into(params, %{}) + + if mode not in @supported_modes do + raise "Unsupported Gemini mode #{mode}. 
Supported modes: #{inspect(@supported_modes)}" + end + + # Format the messages into the correct format for Geminic + {messages, params} = Map.pop!(params, :messages) + + {system_instruction, contents} = + messages + |> Enum.reduce({%{role: "system", parts: []}, []}, fn + %{role: "assistant", content: content}, {system_instructions, history} -> + {system_instructions, [%{role: "model", parts: [%{text: content}]} | history]} + + %{role: "user", content: content}, {system_instructions, history} -> + {system_instructions, [%{role: "user", parts: [%{text: content}]} | history]} + + %{role: "system", content: content}, {system_instructions, history} -> + part = %{text: content} + {Map.update!(system_instructions, :parts, fn parts -> [part | parts] end), history} + end) + + system_instruction = Map.update!(system_instruction, :parts, &Enum.reverse/1) + contents = Enum.reverse(contents) + + # Split out the model config params from the rest of the params + {model_config_params, params} = + Map.split(params, [:top_k, :top_p, :max_tokens, :temperature, :n, :stop]) + + generation_config = + model_config_params + |> Enum.into(%{}, fn + {:stop, stops} -> + {"stopSequences", stops} + + {:n, n} -> + {"candidateCount", n} + + {:max_tokens, max_tokens} -> + {"maxOutputTokens", max_tokens} + + {other_key, value} -> + {Atom.to_string(other_key) |> snake_to_camel(), value} + end) + + params = + if system_instruction.parts != [], + do: Map.put(params, :systemInstruction, system_instruction), + else: params + + params = Map.put(params, :contents, contents) + + params = + case params do + %{response_format: %{json_schema: %{schema: schema}}} -> + generation_config = + generation_config + |> Map.put("response_mime_type", "application/json") + |> Map.put("response_schema", map_schema(schema)) + + params + |> Map.put(:generationConfig, generation_config) + |> Map.delete(:response_format) + + %{tools: tools} -> + tools = [ + %{ + function_declarations: + Enum.map(tools, fn %{function: tool} -> + %{ + name: tool["name"], + description: tool["description"], + parameters: map_schema(tool["parameters"]) + } + end) + } + ] + + params + |> Map.put(:generationConfig, generation_config) + |> Map.put(:tools, tools) + |> Map.delete(:tool_choice) + + _ -> + params + end + + if stream do + do_streaming_chat_completion(mode, params, config) + else + do_chat_completion(mode, params, config) + end + end + + defp do_streaming_chat_completion(mode, params, config) do + pid = self() + options = http_options(config) + {model, params} = Map.pop!(params, :model) + {_, params} = Map.pop!(params, :stream) + + Stream.resource( + fn -> + Task.async(fn -> + options = + Keyword.merge(options, + url: url(config) <> "?alt=sse", + path_params: [model: model, api_version: api_version(config)], + headers: %{"x-goog-api-key" => api_key(config)}, + json: params, + rpc_function: :streamGenerateContent, + into: fn {:data, data}, {req, resp} -> + chunks = + data + |> String.split("\n") + |> Enum.filter(fn line -> + String.starts_with?(line, "data: {") + end) + |> Enum.map(fn line -> + line + |> String.replace_prefix("data: ", "") + |> Jason.decode!() + |> then(&parse_stream_chunk_for_mode(mode, &1)) + end) + + for chunk <- chunks do + send(pid, chunk) + end + + {:cont, {req, resp}} + end + ) + + Req.merge(gemini_req(), options) + |> Req.post!() + + send(pid, :done) + end) + end, + fn task -> + receive do + :done -> + {:halt, task} + + data -> + {[data], task} + after + 15_000 -> + {:halt, task} + end + end, + fn task -> Task.await(task) end + ) + end + 
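+
+  # Non-streaming path: issue a single :generateContent RPC and parse the
+  # response body according to the mode.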
+ defp do_chat_completion(mode, params, config) do + {model, params} = Map.pop!(params, :model) + + response = + Req.merge(gemini_req(), http_options(config)) + |> Req.post( + url: url(config), + path_params: [model: model, api_version: api_version(config)], + headers: %{"x-goog-api-key" => api_key(config)}, + json: params, + rpc_function: :generateContent + ) + + with {:ok, %Req.Response{status: 200, body: body} = response} <- response, + {:ok, body} <- parse_response_for_mode(mode, body) do + {:ok, response, body} + else + {:ok, %Req.Response{status: status, body: body}} -> + {:error, "Unexpected HTTP response code: #{status}\n#{inspect(body)}"} + + e -> + e + end + end + + defp parse_response_for_mode(:tools, %{ + "candidates" => [ + %{"content" => %{"parts" => [%{"functionCall" => %{"args" => args}}]}} + ] + }) do + {:ok, args} + end + + defp parse_response_for_mode(:json_schema, %{ + "candidates" => [ + %{"content" => %{"parts" => [%{"text" => text}]}} + ] + }) do + Jason.decode(text) + end + + defp parse_stream_chunk_for_mode( + :tools, + %{ + "candidates" => [ + %{ + "content" => %{ + "parts" => [%{"functionCall" => %{"args" => args}}] + } + } + ] + } + ) do + args + end + + defp parse_stream_chunk_for_mode(:json_schema, %{ + "candidates" => [ + %{ + "content" => %{ + "parts" => [%{"text" => chunk}] + } + } + ] + }) do + chunk + end + + defp map_schema(schema) do + JSONSchema.traverse_and_update(schema, fn + %{"type" => _} = x + when is_map_key(x, "format") or is_map_key(x, "pattern") or + is_map_key(x, "title") or is_map_key(x, "additionalProperties") -> + Map.drop(x, ["format", "pattern", "title", "additionalProperties"]) + + x -> + x + end) + end + + defp snake_to_camel(snake_case_string) do + snake_case_string + |> String.split("_") + |> Enum.with_index() + |> Enum.map(fn {word, index} -> + if index == 0 do + String.downcase(word) + else + String.capitalize(word) + end + end) + |> Enum.join("") + end + + defp url(config), do: api_url(config) <> ":api_version/models/:model" + defp api_url(config), do: Keyword.fetch!(config, :api_url) + defp api_key(config), do: Keyword.fetch!(config, :api_key) + defp api_version(config), do: Keyword.fetch!(config, :api_version) + defp http_options(config), do: Keyword.fetch!(config, :http_options) + defp config(nil), do: config(Application.get_env(:instructor, :gemini, [])) + + defp config(base_config) do + default_config = [ + api_version: :v1beta, + api_url: "https://generativelanguage.googleapis.com/", + http_options: [receive_timeout: 60_000] + ] + + Keyword.merge(default_config, base_config) + end +end diff --git a/lib/instructor/adapters/llamacpp.ex b/lib/instructor/adapters/llamacpp.ex index 8c30214..a7e97d4 100644 --- a/lib/instructor/adapters/llamacpp.ex +++ b/lib/instructor/adapters/llamacpp.ex @@ -6,8 +6,7 @@ defmodule Instructor.Adapters.Llamacpp do You can read more about it here: https://github.com/ggerganov/llama.cpp/tree/master/examples/server """ - alias Instructor.JSONSchema - alias Instructor.GBNF + alias Instructor.Adapters @behaviour Instructor.Adapter @@ -22,7 +21,7 @@ defmodule Instructor.Adapters.Llamacpp do ## Examples iex> Instructor.chat_completion( - ...> model: "mistral-7b-instruct", + ...> model: "llama3.1-8b-instruct", ...> messages: [ ...> %{ role: "user", content: "Classify the following text: Hello I am a Nigerian prince and I would like to send you money!" 
}, ...> ], @@ -31,143 +30,28 @@ defmodule Instructor.Adapters.Llamacpp do ...> ) """ @impl true - def chat_completion(params, _config \\ nil) do - {response_model, _} = Keyword.pop!(params, :response_model) - {messages, _} = Keyword.pop!(params, :messages) + def chat_completion(params, config \\ nil) do + mode = params[:mode] - json_schema = JSONSchema.from_ecto_schema(response_model) - grammar = GBNF.from_json_schema(json_schema) - prompt = apply_chat_template(chat_template(), messages) - stream = Keyword.get(params, :stream, false) + params = + case mode do + :json_schema -> + update_in(params, [:response_format], fn response_format -> + %{ + type: "json_object", + schema: response_format.json_schema.schema + } + end) - if stream do - do_streaming_chat_completion(prompt, grammar) - else - do_chat_completion(prompt, grammar) - end - end - - defp do_streaming_chat_completion(prompt, grammar) do - pid = self() - - Stream.resource( - fn -> - Task.async(fn -> - Req.post(url(), - json: %{ - grammar: grammar, - prompt: prompt, - stream: true - }, - receive_timeout: 60_000, - into: fn {:data, data}, {req, resp} -> - send(pid, data) - {:cont, {req, resp}} - end - ) - - send(pid, :done) - end) - end, - fn acc -> - receive do - :done -> - {:halt, acc} - - "data: " <> data -> - data = Jason.decode!(data) - {[data], acc} - end - end, - fn acc -> acc end - ) - |> Stream.map(fn %{"content" => chunk} -> - to_openai_streaming_response(chunk) - end) - end - - defp to_openai_streaming_response(chunk) when is_binary(chunk) do - %{ - "choices" => [ - %{"delta" => %{"tool_calls" => [%{"function" => %{"arguments" => chunk}}]}} - ] - } - end - - defp do_chat_completion(prompt, grammar) do - response = - Req.post(url(), - json: %{ - grammar: grammar, - prompt: prompt - }, - receive_timeout: 60_000 - ) - - case response do - {:ok, %{status: 200, body: %{"content" => params}}} -> - {:ok, to_openai_response(params)} - - {:ok, %{status: status}} -> - {:error, "Unexpected HTTP response code: #{status}"} + _ -> + raise "Unsupported mode: #{mode}" + end - {:error, reason} -> - {:error, reason} - end + default_config = [api_url: "http://localhost:8080", api_key: "llamacpp"] + config = Keyword.merge(default_config, config || []) + Adapters.OpenAI.chat_completion(params, config) end - defp to_openai_response(params) do - %{ - "choices" => [ - %{ - "message" => %{ - "tool_calls" => [ - %{"id" => "schema", "function" => %{"name" => "schema", "arguments" => params}} - ] - } - } - ] - } - end - - defp apply_chat_template(:mistral_instruct, messages) do - prompt = - messages - |> Enum.map_join("\n\n", fn - %{role: "assistant", content: content} -> "#{content} " - %{content: content} -> "[INST] #{content} [/INST]" - end) - - "#{prompt}" - end - - defp apply_chat_template(:tiny_llama, messages) do - prompt = - messages - |> Enum.map_join("\n\n", fn - %{role: role, content: content} -> "<|#{role}|>\n#{content} " - %{content: content} -> "<|user|>\n#{content} " - end) - - "#{prompt}" - end - - defp url() do - Keyword.get(config(), :url, "http://localhost:8080/completion") - end - - defp chat_template() do - Keyword.get(config(), :chat_template, :mistral_instruct) - end - - defp config() do - base_config = Application.get_env(:instructor, :llamacpp, []) - - default_config = [ - chat_template: :mistral_instruct, - api_url: "http://localhost:8080/completion" - ] - - Keyword.merge(default_config, base_config) - end + @impl true + defdelegate reask_messages(raw_response, params, config), to: Adapters.OpenAI end diff --git 
a/lib/instructor/adapters/ollama.ex b/lib/instructor/adapters/ollama.ex new file mode 100644 index 0000000..6f9145d --- /dev/null +++ b/lib/instructor/adapters/ollama.ex @@ -0,0 +1,25 @@ +defmodule Instructor.Adapters.Ollama do + @moduledoc """ + Ollama adapter for Instructor. + """ + @behaviour Instructor.Adapter + alias Instructor.Adapters + + @supported_modes [:json, :tools] + + @impl true + def chat_completion(params, config \\ nil) do + default_config = [api_url: "http://localhost:11434", api_key: "ollama"] + config = Keyword.merge(default_config, config || []) + mode = params[:mode] + + if mode not in @supported_modes do + raise "Unsupported mode: #{mode}" + end + + Adapters.OpenAI.chat_completion(params, config) + end + + @impl true + defdelegate reask_messages(raw_response, params, config), to: Adapters.OpenAI +end diff --git a/lib/instructor/adapters/openai.ex b/lib/instructor/adapters/openai.ex index 4b110a3..6adc7d0 100644 --- a/lib/instructor/adapters/openai.ex +++ b/lib/instructor/adapters/openai.ex @@ -3,27 +3,87 @@ defmodule Instructor.Adapters.OpenAI do Documentation for `Instructor.Adapters.OpenAI`. """ @behaviour Instructor.Adapter + @supported_modes [:tools, :json, :md_json, :json_schema] + + alias Instructor.JSONSchema + alias Instructor.SSEStreamParser @impl true - def chat_completion(params, config) do - config = if config, do: config, else: config() + def chat_completion(params, user_config \\ nil) do + config = config(user_config) # Peel off instructor only parameters {_, params} = Keyword.pop(params, :response_model) {_, params} = Keyword.pop(params, :validation_context) {_, params} = Keyword.pop(params, :max_retries) - {_, params} = Keyword.pop(params, :mode) + {mode, params} = Keyword.pop(params, :mode) stream = Keyword.get(params, :stream, false) params = Enum.into(params, %{}) + if mode not in @supported_modes do + raise "Unsupported OpenAI mode #{mode}. 
Supported modes: #{inspect(@supported_modes)}" + end + + params = + case params do + # OpenAI's json_schema mode doesn't support format or pattern attributes + %{"response_format" => %{"json_schema" => %{"schema" => _schema}}} -> + update_in(params, [:response_format, :json_schema, :schema], fn schema -> + JSONSchema.traverse_and_update(schema, fn + %{"type" => _} = x when is_map_key(x, "format") or is_map_key(x, "pattern") -> + Map.drop(x, ["format", "pattern"]) + + x -> + x + end) + end) + + _ -> + params + end + if stream do - do_streaming_chat_completion(params, config) + do_streaming_chat_completion(mode, params, config) else - do_chat_completion(params, config) + do_chat_completion(mode, params, config) end end - defp do_streaming_chat_completion(params, config) do + @impl true + def reask_messages(raw_response, params, _config) do + reask_messages_for_mode(params[:mode], raw_response) + end + + defp reask_messages_for_mode(:tools, %{ + "choices" => [ + %{ + "message" => + %{ + "tool_calls" => [ + %{"id" => tool_call_id, "function" => %{"name" => name, "arguments" => args}} = + function + ] + } = message + } + ] + }) do + [ + Map.put(message, "content", function |> Jason.encode!()) + |> Map.new(fn {k, v} -> {String.to_atom(k), v} end), + %{ + role: "tool", + tool_call_id: tool_call_id, + name: name, + content: args + } + ] + end + + defp reask_messages_for_mode(_mode, _raw_response) do + [] + end + + defp do_streaming_chat_completion(mode, params, config) do pid = self() options = http_options(config) @@ -35,22 +95,7 @@ defmodule Instructor.Adapters.OpenAI do auth_header(config), json: params, into: fn {:data, data}, {req, resp} -> - chunks = - data - |> String.split("\n") - |> Enum.filter(fn line -> - String.starts_with?(line, "data: {") - end) - |> Enum.map(fn line -> - line - |> String.replace_prefix("data: ", "") - |> Jason.decode!() - end) - - for chunk <- chunks do - send(pid, chunk) - end - + send(pid, data) {:cont, {req, resp}} end ]) @@ -73,18 +118,75 @@ defmodule Instructor.Adapters.OpenAI do end, fn task -> Task.await(task) end ) + |> SSEStreamParser.parse() + |> Stream.map(fn chunk -> parse_stream_chunk_for_mode(mode, chunk) end) end - defp do_chat_completion(params, config) do + defp do_chat_completion(mode, params, config) do options = Keyword.merge(http_options(config), [auth_header(config), json: params]) - case Req.post(url(config), options) do - {:ok, %{status: 200, body: body}} -> {:ok, body} - {:ok, %{status: status}} -> {:error, "Unexpected HTTP response code: #{status}"} - {:error, reason} -> {:error, reason} + with {:ok, %Req.Response{status: 200, body: body} = response} <- + Req.post(url(config), options), + {:ok, content} <- parse_response_for_mode(mode, body) do + {:ok, response, content} + else + {:ok, %Req.Response{status: status, body: body}} -> + {:error, "Unexpected HTTP response code: #{status}\n#{inspect(body)}"} + + e -> + e end end + defp parse_response_for_mode(:tools, %{ + "choices" => [ + %{"message" => %{"tool_calls" => [%{"function" => %{"arguments" => args}}]}} + ] + }), + do: Jason.decode(args) + + defp parse_response_for_mode(:md_json, %{"choices" => [%{"message" => %{"content" => content}}]}), + do: Jason.decode(content) + + defp parse_response_for_mode(:json, %{"choices" => [%{"message" => %{"content" => content}}]}), + do: Jason.decode(content) + + defp parse_response_for_mode(:json_schema, %{ + "choices" => [%{"message" => %{"content" => content}}] + }), + do: Jason.decode(content) + + defp parse_response_for_mode(mode, response) do + 
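+    # Catch-all clause: a mode/response shape not matched above is returned as
+    # an error tuple rather than raising.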
{:error, "Unsupported OpenAI mode #{mode} with response #{inspect(response)}"} + end + + defp parse_stream_chunk_for_mode(:md_json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}), + do: chunk + + defp parse_stream_chunk_for_mode(:json, %{"choices" => [%{"delta" => %{"content" => chunk}}]}), + do: chunk + + defp parse_stream_chunk_for_mode(:json_schema, %{ + "choices" => [%{"delta" => %{"content" => chunk}}] + }), + do: chunk + + defp parse_stream_chunk_for_mode(:tools, %{ + "choices" => [ + %{"delta" => %{"tool_calls" => [%{"function" => %{"arguments" => chunk}}]}} + ] + }), + do: chunk + + defp parse_stream_chunk_for_mode(:tools, %{ + "choices" => [ + %{"delta" => %{"content" => chunk}} + ] + }), + do: chunk + + defp parse_stream_chunk_for_mode(_, %{"choices" => [%{"finish_reason" => "stop"}]}), do: "" + defp url(config), do: api_url(config) <> api_path(config) defp api_url(config), do: Keyword.fetch!(config, :api_url) defp api_path(config), do: Keyword.fetch!(config, :api_path) @@ -106,9 +208,9 @@ defmodule Instructor.Adapters.OpenAI do defp http_options(config), do: Keyword.fetch!(config, :http_options) - defp config() do - base_config = Application.get_env(:instructor, :openai, []) + defp config(nil), do: config(Application.get_env(:instructor, :openai, [])) + defp config(base_config) do default_config = [ api_url: "https://api.openai.com", api_path: "/v1/chat/completions", diff --git a/lib/instructor/adapters/vllm.ex b/lib/instructor/adapters/vllm.ex new file mode 100644 index 0000000..3ac6929 --- /dev/null +++ b/lib/instructor/adapters/vllm.ex @@ -0,0 +1,34 @@ +defmodule Instructor.Adapters.VLLM do + @moduledoc """ + VLLM adapter for Instructor. + """ + @behaviour Instructor.Adapter + alias Instructor.Adapters + + @supported_modes [:json_schema] + + @impl true + def chat_completion(params, config \\ nil) do + default_config = [api_url: "http://localhost:8000", api_key: "vllm"] + config = Keyword.merge(default_config, config || []) + mode = params[:mode] + + if mode not in @supported_modes do + raise "Unsupported mode: #{mode}" + end + + params = + case Enum.into(params, %{}) do + %{response_format: %{json_schema: %{schema: schema}}} -> + Keyword.put(params, :guided_json, schema) + + _ -> + params + end + + Adapters.OpenAI.chat_completion(params, config) + end + + @impl true + defdelegate reask_messages(raw_response, params, config), to: Adapters.OpenAI +end diff --git a/lib/instructor/json_schema.ex b/lib/instructor/json_schema.ex index afdc6ae..dbe3126 100644 --- a/lib/instructor/json_schema.ex +++ b/lib/instructor/json_schema.ex @@ -130,6 +130,7 @@ defmodule Instructor.JSONSchema do %{ title: title, type: "object", + additionalProperties: false, required: required, properties: properties, description: fetch_ecto_schema_doc(ecto_schema) || "" @@ -164,6 +165,7 @@ defmodule Instructor.JSONSchema do %{ title: "root", type: "object", + additionalProperties: false, required: required, properties: properties } @@ -200,30 +202,77 @@ defmodule Instructor.JSONSchema do defp find_all_values(_, _pred), do: [] - defp for_type(:id), do: %{type: "integer"} + defp for_type(:id), do: %{type: "integer", description: "Integer, e.g. 1"} defp for_type(:binary_id), do: %{type: "string"} - defp for_type(:integer), do: %{type: "integer"} - defp for_type(:float), do: %{type: "number", format: "float"} - defp for_type(:boolean), do: %{type: "boolean"} - defp for_type(:string), do: %{type: "string"} + defp for_type(:integer), do: %{type: "integer", description: "Integer, e.g. 
1"} + defp for_type(:float), do: %{type: "number", description: "Float, e.g. 1.27", format: "float"} + defp for_type(:boolean), do: %{type: "boolean", description: "Boolean, e.g. true"} + defp for_type(:string), do: %{type: "string", description: "String, e.g. 'hello'"} # defp for_type(:binary), do: %{type: "unsupported"} defp for_type({:array, type}), do: %{type: "array", items: for_type(type)} - defp for_type(:map), do: %{type: "object", additionalProperties: %{}} + + defp for_type(:map), + do: %{ + type: "object", + properties: %{}, + additionalProperties: false, + description: "An object with arbitrary keys and values, e.g. { key: value }" + } defp for_type({:map, type}), - do: %{type: "object", additionalProperties: for_type(type)} + do: %{ + type: "object", + properties: %{}, + additionalProperties: for_type(type), + description: "An object with values of a type #{inspect(type)}, e.g. { key: value }" + } defp for_type(:decimal), do: %{type: "number", format: "float"} - defp for_type(:date), do: %{type: "string", format: "date"} - defp for_type(:time), do: %{type: "string", pattern: "^[0-9]{2}:?[0-9]{2}:?[0-9]{2}$"} + + defp for_type(:date), + do: %{type: "string", description: "ISO8601 Date, e.g. \"2024-07-20\"", format: "date"} + + defp for_type(:time), + do: %{ + type: "string", + description: "ISO8601 Time, e.g. \"12:00:00\"", + pattern: "^[0-9]{2}:?[0-9]{2}:?[0-9]{2}$" + } defp for_type(:time_usec), - do: %{type: "string", pattern: "^[0-9]{2}:?[0-9]{2}:?[0-9]{2}.[0-9]{6}$"} + do: %{ + type: "string", + description: "ISO8601 Time with microseconds, e.g. \"12:00:00.000000\"", + pattern: "^[0-9]{2}:?[0-9]{2}:?[0-9]{2}.[0-9]{6}$" + } + + defp for_type(:naive_datetime), + do: %{ + type: "string", + description: "ISO8601 DateTime, e.g. \"2024-07-20T12:00:00\"", + format: "date-time" + } - defp for_type(:naive_datetime), do: %{type: "string", format: "date-time"} - defp for_type(:naive_datetime_usec), do: %{type: "string", format: "date-time"} - defp for_type(:utc_datetime), do: %{type: "string", format: "date-time"} - defp for_type(:utc_datetime_usec), do: %{type: "string", format: "date-time"} + defp for_type(:naive_datetime_usec), + do: %{ + type: "string", + description: "ISO8601 DateTime with microseconds, e.g. \"2024-07-20T12:00:00.000000\"", + format: "date-time" + } + + defp for_type(:utc_datetime), + do: %{ + type: "string", + description: "ISO8601 DateTime, e.g. \"2024-07-20T12:00:00Z\"", + format: "date-time" + } + + defp for_type(:utc_datetime_usec), + do: %{ + type: "string", + description: "ISO8601 DateTime with microseconds, e.g. \"2024-07-20T12:00:00.000000Z\"", + format: "date-time" + } defp for_type( {:parameterized, {Ecto.Embedded, %Ecto.Embedded{cardinality: :many, related: related}}} @@ -280,7 +329,8 @@ defmodule Instructor.JSONSchema do %{ type: "object", required: required, - properties: properties + properties: properties, + additionalProperties: false } end @@ -298,4 +348,34 @@ defmodule Instructor.JSONSchema do raise "Unsupported type: #{inspect(mod)}, please implement `to_json_schema/0` via `use Instructor.EctoType`" end end + + @doc """ + Traverses a tree structure of maps and lists, allowing the user to update or remove elements. 
+ + ## Parameters + - tree: The tree structure to traverse (can be a map, list, or any other type) + - fun: A function that takes an element and returns either: + - An updated element + - nil to remove the element + - The original element if no changes are needed + + ## Returns + The updated tree structure + """ + def traverse_and_update(tree, fun) when is_map(tree) do + tree + |> Enum.map(fn {k, v} -> {k, traverse_and_update(v, fun)} end) + |> Enum.filter(fn {_, v} -> v != nil end) + |> Enum.into(%{}) + |> fun.() + end + + def traverse_and_update(tree, fun) when is_list(tree) do + tree + |> Enum.map(fn elem -> traverse_and_update(elem, fun) end) + |> Enum.filter(&(&1 != nil)) + |> fun.() + end + + def traverse_and_update(tree, fun), do: fun.(tree) end diff --git a/lib/instructor/sse_stream_parser.ex b/lib/instructor/sse_stream_parser.ex new file mode 100644 index 0000000..582f5e2 --- /dev/null +++ b/lib/instructor/sse_stream_parser.ex @@ -0,0 +1,20 @@ +defmodule Instructor.SSEStreamParser do + @moduledoc false + + def parse(stream) do + stream + |> Stream.transform("", fn + data, acc -> + {chunks, [remaining]} = (acc <> data) |> String.split("\n", trim: false) |> Enum.split(-1) + {chunks, remaining} + end) + |> Stream.flat_map(fn chunk -> + chunk + |> String.split("data: ") + |> Enum.filter(&String.starts_with?(&1, "{")) + |> Enum.map(fn json_string -> + Jason.decode!(json_string) + end) + end) + end +end diff --git a/pages/cookbook/files/receipt.jpg b/pages/cookbook/files/receipt.jpg new file mode 100644 index 0000000..8a1909e Binary files /dev/null and b/pages/cookbook/files/receipt.jpg differ diff --git a/pages/cookbook/ocr-receipts-with-gpt4o.livemd b/pages/cookbook/ocr-receipts-with-gpt4o.livemd new file mode 100644 index 0000000..93a499d --- /dev/null +++ b/pages/cookbook/ocr-receipts-with-gpt4o.livemd @@ -0,0 +1,144 @@ + + +# Extracting data from images of Receipts + +```elixir +Mix.install( + [ + {:instructor, path: Path.expand("../../", __DIR__)}, + {:kino, "~> 0.12.3"} + ], + config: [ + instructor: [ + adapter: Instructor.Adapters.OpenAI, + openai: [ + api_key: System.fetch_env!("LB_OPENAI_API_KEY"), + http_options: [receive_timeout: 60_000, connect_options: [protocols: [:http2]]] + ] + ] + ] +) +``` + +## Our Object Model + +```elixir +image = Kino.FS.file_path("receipt.jpg") |> File.read!() +``` + + + +``` +<<255, 216, 255, 224, 0, 16, 74, 70, 73, 70, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 255, 226, 2, 40, 73, 67, + 67, 95, 80, 82, 79, 70, 73, 76, 69, 0, 1, 1, 0, 0, 2, 24, 0, 0, 0, 0, 2, 16, 0, 0, ...>> +``` + +We can look at the image of the receipt and outline the fields that we want to extract using an Ecto schema, as shown below. 
+ +```elixir +defmodule Receipt do + use Ecto.Schema + use Instructor.Validator + + import Ecto.Changeset + + @primary_key false + embedded_schema do + field(:total, :decimal) + field(:subtotal, :decimal) + + embeds_many :items, Item do + field(:name, :string) + field(:price, :decimal) + field(:quantity, :integer) + end + end + + @impl true + def validate_changeset(changeset) do + changeset + |> validate_required([:total, :subtotal]) + |> validate_items_total() + end + + defp validate_items_total(changeset) do + items = get_field(changeset, :items) || [] + subtotal = get_field(changeset, :subtotal) + + items_total = Enum.reduce(items, Decimal.new(0), fn item, acc -> + item_total = Decimal.mult(item.price, Decimal.new(item.quantity)) + Decimal.add(acc, item_total) + end) + + if Decimal.equal?(items_total, subtotal) do + changeset + else + add_error(changeset, :subtotal, """ + Subtotal does not match sum of item prices. + Subtotal is #{subtotal} and items total is #{items_total}") + """) + end + end +end + +Kino.nothing() +``` + +## Validating the Subtotal + +Notice how we used the `validate_changeset` callback to check that the items we extract sum up to the subtotal on the receipt. This process allows us to gain confidence that the OCR application is functioning correctly. In fact, later on, we can even use the re-ask feature of the instructor to have the LLM correct any errors itself. + +## Calling the LLM + +```elixir +base64_image = "data:image/jpeg;base64," <> Base.encode64(image) + +{:ok, receipt} = Instructor.chat_completion( + model: "gpt-4o", + response_model: Receipt, + messages: [ + %{ + role: "user", + content: [ + %{ + type: "image_url", + image_url: %{url: base64_image} + }, + %{ + type: "text", + text: "Analyze the image and return the items in the receipt and the total amount." + } + ] + } + ] +) + +Kino.Layout.grid([Kino.Image.new(image, :jpeg), receipt], columns: 2, boxed: true) +``` + + + +``` +%Receipt{ + total: Decimal.new("107.6"), + subtotal: Decimal.new("107.6"), + items: [ + %Receipt.Item{id: "1", name: "Lorem ipsum", price: Decimal.new("9.2"), quantity: 1}, + %Receipt.Item{id: "2", name: "Lorem ipsum dolor sit", price: Decimal.new("19.2"), quantity: 1}, + %Receipt.Item{ + id: "3", + name: "Lorem ipsum dolor sit amet", + price: Decimal.new("15"), + quantity: 1 + }, + %Receipt.Item{id: "4", name: "Lorem ipsum", price: Decimal.new("15"), quantity: 1}, + %Receipt.Item{id: "5", name: "Lorem ipsum", price: Decimal.new("15"), quantity: 1}, + %Receipt.Item{id: "6", name: "Lorem ipsum dolor sit", price: Decimal.new("15"), quantity: 1}, + %Receipt.Item{id: "7", name: "Lorem ipsum", price: Decimal.new("19.2"), quantity: 1} + ] +} +``` + +Now, we simply using Instructor to call gpt-4o with the base64 encoded image and the response model, and we can get back the results. We can have confidence that the results returned match all of our validations, and we've reduced the effects of any hallucinations. + + diff --git a/pages/cookbook/streaming_ui.exs b/pages/cookbook/streaming_ui.exs new file mode 100644 index 0000000..abdb17c --- /dev/null +++ b/pages/cookbook/streaming_ui.exs @@ -0,0 +1,222 @@ +Mix.install([ + {:phoenix_playground, github: "gmile/phoenix_playground", branch: "handle-async-callback"}, + {:instructor, path: Path.expand("~/code/instructor_ex")} +]) + +defmodule Recipe do + use Ecto.Schema + + @doc """ + A recipe with ingredients and instructions. 
+ + Ingredients should include quantity, unit, and name, for example: + + %{ + name: "salt", + quantity: 1, + unit: "cup" + } + + Instructions should include step and optional notes. + + %{ + step: "Mix the water and flour", + } + """ + @primary_key false + embedded_schema do + field(:name, :string) + field(:description, :string) + field(:prep_time, :integer) + field(:cook_time, :integer) + field(:total_time, :integer) + field(:servings, :integer) + + embeds_many(:ingredients, Ingredient) do + field(:name, :string) + field(:quantity, :decimal) + field(:unit, :string) + end + + embeds_many(:instructions, Instruction) do + field(:step, :string) + end + end +end + +defmodule StreamingUILive do + use Phoenix.LiveView + use Phoenix.Component + + alias Phoenix.LiveView.AsyncResult + + @instructor_config [ + adapter: Instructor.Adapters.OpenAI, + api_key: System.get_env("OPENAI_API_KEY") + ] + + def mount(_params, _session, socket) do + {:ok, + socket + |> assign(:prompt, "") + |> assign(:recipe, AsyncResult.ok(nil))} + end + + def handle_event("submit", %{"prompt" => prompt}, socket) do + pid = self() + + {:noreply, + socket + |> assign(:prompt, "") + |> start_async(:generate_recipe, fn -> + {:ok, recipe} = + Instructor.chat_completion( + [ + model: "gpt-4o-mini", + response_model: {:partial, Recipe}, + stream: true, + messages: [ + %{role: "user", content: prompt} + ] + ], + @instructor_config + ) + |> Stream.each(fn {_state, recipe} -> + send(pid, {:partial, Recipe, recipe}) + end) + |> Enum.to_list() + |> List.last() + + recipe + end)} + end + + def handle_info({:partial, Recipe, recipe}, socket) do + # Update the AsyncResult with the new partial recipe, but keep in existing loading state + {:noreply, assign(socket, :recipe, %{socket.assigns.recipe | result: recipe})} + end + + def handle_async(:generate_recipe, {:ok, recipe}, socket) do + {:noreply, assign(socket, :recipe, AsyncResult.ok(recipe))} + end + + def handle_async(:generate_recipe, {:error, error}, socket) do + {:noreply, assign(socket, :recipe, AsyncResult.error(error))} + end + + def search_unsplash(query) do + Req.get( + "https://api.unsplash.com/search/photos", + params: [query: "#{query}", per_page: 1, client_id: System.get_env("UNSPLASH_ACCESS_KEY")] + ) + |> case do + {:ok, %Req.Response{body: %{"results" => [%{"urls" => %{"regular" => image_url}}]}}} -> + image_url + + {:error, error} -> + IO.inspect(error) + nil + end + end + + def image_for_recipe(%Recipe{name: name} = _recipe) when not is_nil(name) do + cache_key = {:unsplash_image, name} + + case Process.get(cache_key) do + nil -> + image_url = search_unsplash(name) + Process.put(cache_key, image_url) + image_url + + cached_url -> + cached_url + end + end + + def image_for_recipe(_), do: "" + + def render(assigns) do + ~H""" + + +
+    <div>
+      <h1>✨ AI Recipe Generator</h1>
+
+      <form phx-submit="submit">
+        <input type="text" name="prompt" value={@prompt} />
+        <button type="submit">Generate</button>
+      </form>
+
+      <.recipe {assigns} />
+    </div>
+    """
+  end
+
+  defp recipe(assigns) do
+    ~H"""
+    <div :if={@recipe.result}>
+      <img src={image_for_recipe(@recipe.result)} alt={@recipe.result.name} />
+      <h2><%= @recipe.result.name %></h2>
+
+      <p><%= @recipe.result.description %></p>
+
+      <div>
+        <span>Prep: <%= @recipe.result.prep_time %> min</span>
+        <span>Cook: <%= @recipe.result.cook_time %> min</span>
+        <span>Total: <%= @recipe.result.total_time %> min</span>
+        <span>Serves: <%= @recipe.result.servings %></span>
+      </div>
+
+      <h3>Ingredients</h3>
+      <ul>
+        <%= for ingredient <- @recipe.result.ingredients do %>
+          <li>
+            <%= if ingredient.quantity != nil do %>
+              <%= to_string(ingredient.quantity) %>
+            <% end %>
+            <%= ingredient.unit %> <%= ingredient.name %>
+          </li>
+        <% end %>
+      </ul>
+
+      <h3>Instructions</h3>
+      <div>
+        <%= for {step, index} <- Enum.with_index(@recipe.result.instructions) do %>
+          <div>
+            <span><%= index + 1 %>.</span>
+            <span><%= step.step %></span>
+          </div>
+        <% end %>
+      </div>
+    </div>
+ """ + end +end + +PhoenixPlayground.start(live: StreamingUILive, live_reload: false) diff --git a/test/instructor_test.exs b/test/instructor_test.exs index 8b44060..9723105 100644 --- a/test/instructor_test.exs +++ b/test/instructor_test.exs @@ -15,6 +15,20 @@ defmodule InstructorTest do :llamacpp -> Application.put_env(:instructor, :adapter, Instructor.Adapters.Llamacpp) + :ollama -> + Application.put_env(:instructor, :adapter, Instructor.Adapters.Ollama) + + :anthropic -> + Application.put_env(:instructor, :adapter, Instructor.Adapters.Anthropic) + + Application.put_env(:instructor, :anthropic, + api_key: System.fetch_env!("ANTHROPIC_API_KEY") + ) + + :gemini -> + Application.put_env(:instructor, :adapter, Instructor.Adapters.Gemini) + Application.put_env(:instructor, :gemini, api_key: System.fetch_env!("GOOGLE_API_KEY")) + :openai -> Application.put_env(:instructor, :adapter, Instructor.Adapters.OpenAI) Application.put_env(:instructor, :openai, api_key: System.fetch_env!("OPENAI_API_KEY")) @@ -36,9 +50,17 @@ defmodule InstructorTest do def mock_stream_response(_, _, _), do: nil - for adapter <- [:openai_mock, :openai, :llamacpp] do - # for adapter <- [:openai] do - describe "#{inspect(adapter)}" do + for {adapter, params} <- [ + {:mock_openai, [mode: :tools, model: "gpt-4o-mini"]}, + {:openai, [mode: :tools, model: "gpt-4o-mini"]}, + {:openai, [mode: :json, model: "gpt-4o-mini"]}, + {:openai, [mode: :json_schema, model: "gpt-4o-mini"]}, + {:llamacpp, [mode: :json_schema, model: "llama3.1-8b-instruct"]}, + {:gemini, [mode: :json_schema, model: "gemini-1.5-flash-latest"]}, + {:ollama, [mode: :tools, model: "llama3.1"]}, + {:anthropic, [mode: :tools, model: "claude-3-5-sonnet-20240620", max_tokens: 1024]} + ] do + describe "#{inspect(adapter)} #{params[:mode]} #{params[:model]}" do @tag adapter: adapter test "schemaless ecto" do expected = %{name: "George Washington", birth_date: ~D[1732-02-22]} @@ -46,11 +68,12 @@ defmodule InstructorTest do result = Instructor.chat_completion( - model: "gpt-3.5-turbo", - response_model: %{name: :string, birth_date: :date}, - messages: [ - %{role: "user", content: "Who was the first president of the USA?"} - ] + Keyword.merge(unquote(params), + response_model: %{name: :string, birth_date: :date}, + messages: [ + %{role: "user", content: "Who was the first president of the USA?"} + ] + ) ) assert {:ok, %{name: name, birth_date: birth_date}} = result @@ -75,15 +98,16 @@ defmodule InstructorTest do result = Instructor.chat_completion( - model: "gpt-3.5-turbo", - response_model: SpamPrediction, - messages: [ - %{ - role: "user", - content: - "Classify the following text: Hello, I am a Nigerian prince and I would like to give you $1,000,000." - } - ] + Keyword.merge(unquote(params), + response_model: SpamPrediction, + messages: [ + %{ + role: "user", + content: + "Classify the following text: Hello, I am a Nigerian prince and I would like to give you $1,000,000." 
+ } + ] + ) ) assert {:ok, %SpamPrediction{class: :spam, score: score}} = result @@ -103,8 +127,8 @@ defmodule InstructorTest do field(:string, :string) # field(:binary, :binary) field(:array, {:array, :string}) - field(:map, :map) - field(:map_two, {:map, :string}) + field(:nested_object, :map) + field(:nested_object_two, {:map, :string}) field(:decimal, :decimal) field(:date, :date) field(:time, :time) @@ -125,8 +149,8 @@ defmodule InstructorTest do boolean: true, string: "string", array: ["array"], - map: %{"map" => "map"}, - map_two: %{"map_two" => "map_two"}, + nested_object: %{"map" => "map"}, + nested_object_two: %{"map_two" => "map_two"}, decimal: 1.0, date: "2021-08-01", time: "12:00:00", @@ -141,15 +165,36 @@ defmodule InstructorTest do result = Instructor.chat_completion( - model: "gpt-3.5-turbo", - response_model: AllEctoTypes, - messages: [ - %{ - role: "user", - content: - "What are the types of the following fields: binary_id, integer, float, boolean, string, array, map, map_two, decimal, date, time, time_usec, naive_datetime, naive_datetime_usec, utc_datetime, utc_datetime_usec?" - } - ] + Keyword.merge(unquote(params), + response_model: AllEctoTypes, + messages: [ + %{ + role: "user", + content: """ + Return the exact object below, nothing else. + + { + "integer": 1, + "date": "2021-08-01", + "float": 1.0, + "time": "12:00:00", + "string": "string", + "boolean": true, + "array": [ "array_value" ], + "decimal": 1.0, + "binary_id": "binary_id", + "naive_datetime": "2021-08-01T12:00:00", + "naive_datetime_usec": "2021-08-01T12:00:00.000000", + "utc_datetime": "2021-08-01T12:00:00Z", + "utc_datetime_usec": "2021-08-01T12:00:00.000000Z", + "time_usec": "12:00:00.000000", + "nested_object": { "key": "value" }, + "nested_object_two": { "key_two": "value_two" } + } + """ + } + ] + ) ) assert {:ok, @@ -160,8 +205,8 @@ defmodule InstructorTest do boolean: boolean, string: string, array: array, - map: map, - map_two: map_two, + nested_object: nested_object, + nested_object_two: nested_object_two, decimal: decimal, date: date, time: time, @@ -178,8 +223,8 @@ defmodule InstructorTest do assert is_boolean(boolean) assert is_binary(string) assert is_list(array) - assert is_map(map) - assert is_map(map_two) + assert is_map(nested_object) + assert is_map(nested_object_two) assert %Decimal{} = decimal assert %Date{} = date assert %Time{} = time @@ -212,12 +257,13 @@ defmodule InstructorTest do result = Instructor.chat_completion( - model: "gpt-3.5-turbo", - stream: true, - response_model: {:array, President}, - messages: [ - %{role: "user", content: "What are the first 3 presidents of the United States?"} - ] + Keyword.merge(unquote(params), + stream: true, + response_model: {:array, President}, + messages: [ + %{role: "user", content: "What are the first 3 presidents of the United States?"} + ] + ) ) assert TestHelpers.is_stream?(result) @@ -237,12 +283,13 @@ defmodule InstructorTest do result = Instructor.chat_completion( - model: "gpt-3.5-turbo", - stream: true, - response_model: {:partial, President}, - messages: [ - %{role: "user", content: "Who was the first president of the United States"} - ] + Keyword.merge(unquote(params), + stream: true, + response_model: {:partial, President}, + messages: [ + %{role: "user", content: "Who was the first president of the United States"} + ] + ) ) assert TestHelpers.is_stream?(result) @@ -272,12 +319,13 @@ defmodule InstructorTest do result = Instructor.chat_completion( - model: "gpt-3.5-turbo", - stream: true, - response_model: {:partial, {:array, 
President}}, - messages: [ - %{role: "user", content: "Who were the first 2 presidents of the United States"} - ] + Keyword.merge(unquote(params), + stream: true, + response_model: {:partial, {:array, President}}, + messages: [ + %{role: "user", content: "Who were the first 2 presidents of the United States"} + ] + ) ) assert TestHelpers.is_stream?(result) @@ -345,6 +393,7 @@ defmodule InstructorTest do @tag adapter: :openai_mock test "retry upto n times" do TestHelpers.mock_openai_response(:tools, %{wrong_field: "foobar"}) + TestHelpers.mock_openai_reask_messages() TestHelpers.mock_openai_response(:tools, %{wrong_field: "foobar"}) result = @@ -360,7 +409,9 @@ defmodule InstructorTest do assert {:error, %Ecto.Changeset{valid?: false}} = result TestHelpers.mock_openai_response(:tools, %{wrong_field: "foobar"}) + TestHelpers.mock_openai_reask_messages() TestHelpers.mock_openai_response(:tools, %{field: 123}) + TestHelpers.mock_openai_reask_messages() TestHelpers.mock_openai_response(:tools, %{field: "foobar"}) result = diff --git a/test/sse_stream_parser_test.exs b/test/sse_stream_parser_test.exs new file mode 100644 index 0000000..17b9bc5 --- /dev/null +++ b/test/sse_stream_parser_test.exs @@ -0,0 +1,35 @@ +defmodule Instructor.SSEStreamParserTest do + use ExUnit.Case + + alias Instructor.SSEStreamParser + + test "parses a stream" do + tokens = [ + "data: { \"number\": 1 }\n", + "data: { \"number\": 2 }\n", + "data: { \"number\": 3 }\n", + ] + + + assert SSEStreamParser.parse(tokens) |> Enum.to_list() == [ + %{"number" => 1}, + %{"number" => 2}, + %{"number" => 3} + ] + end + + test "parses a stream where lines are split across chunks" do + tokens = [ + "data: { \"number\": 1 }\n", + "data: { \"number\":", + " 2 }\n", + "data: { \"number\": 3 }\n", + ] + + assert SSEStreamParser.parse(tokens) |> Enum.to_list() == [ + %{"number" => 1}, + %{"number" => 2}, + %{"number" => 3} + ] + end +end diff --git a/test/support/test_helpers.ex b/test/support/test_helpers.ex index 44848c8..3deeb7e 100644 --- a/test/support/test_helpers.ex +++ b/test/support/test_helpers.ex @@ -37,7 +37,7 @@ defmodule Instructor.TestHelpers do "object" => "chat.completion", "created" => 1_704_579_055, "system_fingerprint" => nil - }} + }, result} end) end @@ -68,7 +68,8 @@ defmodule Instructor.TestHelpers do "object" => "chat.completion", "created" => 1_704_579_055, "system_fingerprint" => nil - } + }, + result } end) end @@ -79,40 +80,9 @@ defmodule Instructor.TestHelpers do |> String.graphemes() |> Enum.chunk_every(12) |> Enum.map(fn chunk -> - chunk = Enum.join(chunk, "") - - %{ - "choices" => [ - %{ - "delta" => %{"tool_calls" => [%{"function" => %{"arguments" => chunk}}]}, - "finish_reason" => nil, - "index" => 0, - "logprobs" => nil - } - ], - "created" => 1_704_666_072, - "id" => "chatcmpl-8eVo0dIB83q0IzSvrZeO4tM1CO9y8", - "model" => "gpt-3.5-turbo-0613", - "object" => "chat.completion.chunk", - "system_fingerprint" => nil - } + Enum.join(chunk, "") end) - chunks = - chunks ++ - [ - %{ - "choices" => [ - %{"delta" => %{}, "finish_reason" => "stop", "index" => 0, "logprobs" => nil} - ], - "created" => 1_704_666_072, - "id" => "chatcmpl-8eVo0dIB83q0IzSvrZeO4tM1CO9y8", - "model" => "gpt-3.5-turbo-0613", - "object" => "chat.completion.chunk", - "system_fingerprint" => nil - } - ] - InstructorTest.MockOpenAI |> expect(:chat_completion, fn _params, _config -> chunks @@ -125,46 +95,22 @@ defmodule Instructor.TestHelpers do |> String.graphemes() |> Enum.chunk_every(12) |> Enum.map(fn chunk -> - chunk = Enum.join(chunk, "") - 
- %{ - "choices" => [ - %{ - "delta" => %{"content" => chunk}, - "finish_reason" => nil, - "index" => 0, - "logprobs" => nil - } - ], - "created" => 1_704_666_072, - "id" => "chatcmpl-8eVo0dIB83q0IzSvrZeO4tM1CO9y8", - "model" => "gpt-3.5-turbo-0613", - "object" => "chat.completion.chunk", - "system_fingerprint" => nil - } + Enum.join(chunk, "") end) - chunks = - chunks ++ - [ - %{ - "choices" => [ - %{"delta" => %{}, "finish_reason" => "stop", "index" => 0, "logprobs" => nil} - ], - "created" => 1_704_666_072, - "id" => "chatcmpl-8eVo0dIB83q0IzSvrZeO4tM1CO9y8", - "model" => "gpt-3.5-turbo-0613", - "object" => "chat.completion.chunk", - "system_fingerprint" => nil - } - ] - InstructorTest.MockOpenAI |> expect(:chat_completion, fn _params, _config -> chunks end) end + def mock_openai_reask_messages() do + InstructorTest.MockOpenAI + |> expect(:reask_messages, fn _raw_response, _params, _config -> + [] + end) + end + def is_stream?(variable) do case variable do %Stream{} -> diff --git a/test/test_helper.exs b/test/test_helper.exs index 7de2ea5..38d1c54 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -1,6 +1,10 @@ Mox.defmock(InstructorTest.MockOpenAI, for: Instructor.Adapter) -# Exclude the unmocked tests by default +# Exclude the unmocked tests by default, to run them use: +# +# mix test --only adapter:llamacpp +# mix test --only adapter:openai +# ExUnit.configure(exclude: [adapter: :llamacpp, adapter: :openai]) ExUnit.start()
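
With the reworked `Instructor.Adapter` behaviour in this diff, a third-party adapter only needs `chat_completion/2` returning `{:ok, raw_response, content}` (or a stream), and may optionally implement `reask_messages/3` to participate in the retry flow. A minimal sketch, assuming a hypothetical OpenAI-compatible server at `http://localhost:9000` (the module name and URL are illustrative, not part of this diff):

```elixir
defmodule MyApp.Adapters.LocalServer do
  @moduledoc """
  Hypothetical adapter sketch: delegates to the OpenAI adapter, since the
  imagined local server speaks the OpenAI chat-completions API.
  """
  @behaviour Instructor.Adapter

  alias Instructor.Adapters

  @impl true
  def chat_completion(params, config \\ nil) do
    # User-supplied config overrides these defaults, mirroring the Ollama and
    # VLLM adapters added in this diff.
    default_config = [api_url: "http://localhost:9000", api_key: "unused"]
    config = Keyword.merge(default_config, config || [])

    Adapters.OpenAI.chat_completion(params, config)
  end

  # Optional callback: without it, Instructor logs a debug note and retries
  # without echoing the previous assistant response back to the model.
  @impl true
  defdelegate reask_messages(raw_response, params, config), to: Adapters.OpenAI
end
```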