From 24c189e9a9f663de32946ab46befb314d3c2eb61 Mon Sep 17 00:00:00 2001
From: Dallin Romney
Date: Fri, 10 Jan 2025 10:11:31 -0800
Subject: [PATCH 1/4] ollama tools kickoff

---
 .vscode/launch.json              |   2 +-
 core/llm/autodetect.ts           |  15 ++-
 core/llm/constructMessages.ts    |  23 ++---
 core/llm/llms/Ollama.ts          | 167 +++++++++++++++++++------------
 extensions/intellij/CHANGELOG.md |   2 +-
 5 files changed, 126 insertions(+), 83 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index dfe474f40b..15720f315f 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -23,7 +23,7 @@
       "CONTINUE_GLOBAL_DIR": "${workspaceFolder}/extensions/.continue-debug"
       // "staging" for the preview deployment
       // "local" for entirely local development of control plane/proxy
-      // "CONTROL_PLANE_ENV": "staging"
+      // "CONTROL_PLANE_ENV": "local"
     }
   },
   {
diff --git a/core/llm/autodetect.ts b/core/llm/autodetect.ts
index 2a9986e6cb..80c2e095f9 100644
--- a/core/llm/autodetect.ts
+++ b/core/llm/autodetect.ts
@@ -86,12 +86,17 @@ const MODEL_SUPPORTS_IMAGES: string[] = [
   "llama3.2",
 ];
 
+const PROVIDER_TOOL_SUPPORT: Record<string, string[]> = {
+  anthropic: ["claude-3-5", "claude-3.5"],
+  ollama: ["llama3.1"],
+};
+
 function modelSupportsTools(modelName: string, provider: string) {
-  return (
-    provider === "anthropic" &&
-    modelName.includes("claude") &&
-    (modelName.includes("3-5") || modelName.includes("3.5"))
-  );
+  const providerSupport = PROVIDER_TOOL_SUPPORT[provider];
+  if (!providerSupport) {
+    return false;
+  }
+  return !!providerSupport.some((part) => modelName.includes(part));
 }
 
 function modelSupportsImages(
diff --git a/core/llm/constructMessages.ts b/core/llm/constructMessages.ts
index 24f44279da..190ede74ad 100644
--- a/core/llm/constructMessages.ts
+++ b/core/llm/constructMessages.ts
@@ -71,20 +71,17 @@ function constructSystemPrompt(
   provider: string,
   useTools: boolean,
 ): string | null {
-  if (
-    useTools &&
-    CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))
-  ) {
-    return SYSTEM_MESSAGE + "\n\n" + TOOL_USE_RULES;
-  } else if (
-    CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))
-  ) {
-    return SYSTEM_MESSAGE;
-  } else if (useTools && modelSupportsTools(model, provider)) {
-    return TOOL_USE_RULES;
+  let systemMessage = "";
+  if(CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))){
+    systemMessage = SYSTEM_MESSAGE;
   }
-
-  return null;
+  if(useTools && modelSupportsTools(model, provider)){
+    if(systemMessage){
+      systemMessage += "\n\n";
+    }
+    systemMessage += TOOL_USE_RULES;
+  }
+  return systemMessage || null;
 }
 
 const CANCELED_TOOL_CALL_MESSAGE =
diff --git a/core/llm/llms/Ollama.ts b/core/llm/llms/Ollama.ts
index 3e9af9d38a..f13a46ec88 100644
--- a/core/llm/llms/Ollama.ts
+++ b/core/llm/llms/Ollama.ts
@@ -1,12 +1,17 @@
-import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
+import {
+  ChatMessage,
+  CompletionOptions,
+  LLMOptions,
+  Tool,
+} from "../../index.js";
 import { renderChatMessage } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
 import { streamResponse } from "../stream.js";
 
-type OllamaChatMessage = ChatMessage & { images?: string[] };
+type OllamaChatMessage = ChatMessage & { images?: string[] | null };
 
 // See https://github.com/ollama/ollama/blob/main/docs/modelfile.md for details on each parameter
-interface ModelFileParams {
+interface OllamaModelFileParams {
   mirostat?: number;
   mirostat_eta?: number;
   mirostat_tau?: number;
@@ -21,23 +26,35 @@ interface ModelFileParams {
   top_k?: number;
   top_p?: number;
   min_p?: number;
-  // deprecated?
+
+  // deprecated or not directly supported here:
   num_thread?: number;
   use_mmap?: boolean;
   num_gqa?: number;
   num_gpu?: number;
+  num_keep?: number;
+  typical_p?: number;
+  presence_penalty?: number;
+  frequency_penalty?: number;
+  penalize_newline?: boolean;
+  numa?: boolean;
+  num_batch?: number;
+  main_gpu?: number;
+  low_vram?: boolean;
+  vocab_only?: boolean;
+  use_mlock?: boolean;
 }
 
 // See https://github.com/ollama/ollama/blob/main/docs/api.md
-interface BaseOptions {
+interface OllamaBaseOptions {
   model: string; // the model name
-  options?: ModelFileParams; // additional model parameters listed in the documentation for the Modelfile such as temperature
+  options?: OllamaModelFileParams; // additional model parameters listed in the documentation for the Modelfile such as temperature
   format?: "json"; // the format to return a response in. Currently, the only accepted value is json
   stream?: boolean; // if false the response will be returned as a single response object, rather than a stream of objects
   keep_alive?: number; // controls how long the model will stay loaded into memory following the request (default: 5m)
 }
 
-interface GenerateOptions extends BaseOptions {
+interface OllamaGenerateOptions extends OllamaBaseOptions {
   prompt: string; // the prompt to generate a response for
   suffix?: string; // the text after the model response
   images?: string[]; // a list of base64-encoded images (for multimodal models such as llava)
@@ -47,8 +64,18 @@ interface GenerateOptions extends BaseOptions {
   raw?: boolean; // if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
 }
 
-interface ChatOptions extends BaseOptions {
+interface OllamaTool {
+  type: "function";
+  function: {
+    name: string;
+    description: string;
+    parameters: any;
+  };
+}
+
+interface OllamaChatOptions extends OllamaBaseOptions {
   messages: OllamaChatMessage[]; // the messages of the chat, this can be used to keep a chat memory
+  tools?: OllamaTool[]; // the tools of the chat, this can be used to keep a tool memory
   // Not supported yet - tools: tools for the model to use if supported. Requires stream to be set to false
   // And correspondingly, tool calls in OllamaChatMessage
 }
 
@@ -130,56 +157,59 @@ class Ollama extends BaseLLM {
     });
   }
 
+  // Map of "continue model name" to Ollama actual model name
+  private modelMap: Record<string, string> = {
+    "mistral-7b": "mistral:7b",
+    "mixtral-8x7b": "mixtral:8x7b",
+    "llama2-7b": "llama2:7b",
+    "llama2-13b": "llama2:13b",
+    "codellama-7b": "codellama:7b",
+    "codellama-13b": "codellama:13b",
+    "codellama-34b": "codellama:34b",
+    "codellama-70b": "codellama:70b",
+    "llama3-8b": "llama3:8b",
+    "llama3-70b": "llama3:70b",
+    "llama3.1-8b": "llama3.1:8b",
+    "llama3.1-70b": "llama3.1:70b",
+    "llama3.1-405b": "llama3.1:405b",
+    "llama3.2-1b": "llama3.2:1b",
+    "llama3.2-3b": "llama3.2:3b",
+    "llama3.2-11b": "llama3.2:11b",
+    "llama3.2-90b": "llama3.2:90b",
+    "phi-2": "phi:2.7b",
+    "phind-codellama-34b": "phind-codellama:34b-v2",
+    "qwen2.5-coder-0.5b": "qwen2.5-coder:0.5b",
+    "qwen2.5-coder-1.5b": "qwen2.5-coder:1.5b",
+    "qwen2.5-coder-3b": "qwen2.5-coder:3b",
+    "qwen2.5-coder-7b": "qwen2.5-coder:7b",
+    "qwen2.5-coder-14b": "qwen2.5-coder:14b",
+    "qwen2.5-coder-32b": "qwen2.5-coder:32b",
+    "wizardcoder-7b": "wizardcoder:7b-python",
+    "wizardcoder-13b": "wizardcoder:13b-python",
+    "wizardcoder-34b": "wizardcoder:34b-python",
+    "zephyr-7b": "zephyr:7b",
+    "codeup-13b": "codeup:13b",
+    "deepseek-1b": "deepseek-coder:1.3b",
+    "deepseek-7b": "deepseek-coder:6.7b",
+    "deepseek-33b": "deepseek-coder:33b",
+    "neural-chat-7b": "neural-chat:7b-v3.3",
+    "starcoder-1b": "starcoder:1b",
+    "starcoder-3b": "starcoder:3b",
+    "starcoder2-3b": "starcoder2:3b",
+    "stable-code-3b": "stable-code:3b",
+    "granite-code-3b": "granite-code:3b",
+    "granite-code-8b": "granite-code:8b",
+    "granite-code-20b": "granite-code:20b",
+    "granite-code-34b": "granite-code:34b",
+  };
+
   private _getModel() {
-    return (
-      {
-        "mistral-7b": "mistral:7b",
-        "mixtral-8x7b": "mixtral:8x7b",
-        "llama2-7b": "llama2:7b",
-        "llama2-13b": "llama2:13b",
-        "codellama-7b": "codellama:7b",
-        "codellama-13b": "codellama:13b",
-        "codellama-34b": "codellama:34b",
-        "codellama-70b": "codellama:70b",
-        "llama3-8b": "llama3:8b",
-        "llama3-70b": "llama3:70b",
-        "llama3.1-8b": "llama3.1:8b",
-        "llama3.1-70b": "llama3.1:70b",
-        "llama3.1-405b": "llama3.1:405b",
-        "llama3.2-1b": "llama3.2:1b",
-        "llama3.2-3b": "llama3.2:3b",
-        "llama3.2-11b": "llama3.2:11b",
-        "llama3.2-90b": "llama3.2:90b",
-        "phi-2": "phi:2.7b",
-        "phind-codellama-34b": "phind-codellama:34b-v2",
-        "qwen2.5-coder-0.5b": "qwen2.5-coder:0.5b",
-        "qwen2.5-coder-1.5b": "qwen2.5-coder:1.5b",
-        "qwen2.5-coder-3b": "qwen2.5-coder:3b",
-        "qwen2.5-coder-7b": "qwen2.5-coder:7b",
-        "qwen2.5-coder-14b": "qwen2.5-coder:14b",
-        "qwen2.5-coder-32b": "qwen2.5-coder:32b",
-        "wizardcoder-7b": "wizardcoder:7b-python",
-        "wizardcoder-13b": "wizardcoder:13b-python",
-        "wizardcoder-34b": "wizardcoder:34b-python",
-        "zephyr-7b": "zephyr:7b",
-        "codeup-13b": "codeup:13b",
-        "deepseek-1b": "deepseek-coder:1.3b",
-        "deepseek-7b": "deepseek-coder:6.7b",
-        "deepseek-33b": "deepseek-coder:33b",
-        "neural-chat-7b": "neural-chat:7b-v3.3",
-        "starcoder-1b": "starcoder:1b",
-        "starcoder-3b": "starcoder:3b",
-        "starcoder2-3b": "starcoder2:3b",
-        "stable-code-3b": "stable-code:3b",
-        "granite-code-3b": "granite-code:3b",
-        "granite-code-8b": "granite-code:8b",
-        "granite-code-20b": "granite-code:20b",
-        "granite-code-34b": "granite-code:34b",
-      }[this.model] ?? this.model
-    );
+    return this.modelMap[this.model] ?? this.model;
   }
 
-  private _getModelFileParams(options: CompletionOptions): ModelFileParams {
+  private _getModelFileParams(
+    options: CompletionOptions,
+  ): OllamaModelFileParams {
     return {
       temperature: options.temperature,
       top_p: options.topP,
@@ -194,10 +224,16 @@ class Ollama extends BaseLLM {
     };
   }
 
-  private _convertMessage(message: ChatMessage) {
-    if (message.role === "tool") {
-      return null;
-    }
+  private _convertMessage(message: ChatMessage): OllamaChatMessage {
+    // if (message.role === "tool") {
+    //   return null;
+    // }
+    // if (message.role === "tool") {
+    //   return {
+    //     role: "tool",
+
+    //   };
+    // }
 
     if (typeof message.content === "string") {
       return message;
@@ -222,22 +258,27 @@ class Ollama extends BaseLLM {
   private _getChatOptions(
     options: CompletionOptions,
     messages: ChatMessage[],
-  ): ChatOptions {
-    return {
+  ): OllamaChatOptions {
+    const chatOptions: OllamaChatOptions = {
       model: this._getModel(),
-      messages: messages.map(this._convertMessage).filter(Boolean) as any,
+      messages: messages.map(this._convertMessage).filter(Boolean),
       options: this._getModelFileParams(options),
       keep_alive: options.keepAlive ?? 60 * 30, // 30 minutes
       stream: options.stream,
       // format: options.format, // Not currently in base completion options
     };
+    if (options.tools?.length) {
+      chatOptions.tools = options.tools;
+      chatOptions.stream = false; // Cannot set stream = true for tools calls
+    }
+    return chatOptions;
   }
 
   private _getGenerateOptions(
     options: CompletionOptions,
     prompt: string,
     suffix?: string,
-  ): GenerateOptions {
+  ): OllamaGenerateOptions {
     return {
       model: this._getModel(),
       prompt,
diff --git a/extensions/intellij/CHANGELOG.md b/extensions/intellij/CHANGELOG.md
index 7745d5c91e..391d722c72 100644
--- a/extensions/intellij/CHANGELOG.md
+++ b/extensions/intellij/CHANGELOG.md
@@ -6,7 +6,7 @@ and is generated by [Changie](https://github.com/miniscruff/changie).
 
 Pre-release Changes
 
-### Fixed 
+### Fixed
 * Off-screen rendering to solve white flash on load and lack of changing cursor type
 * OSR-related fixes for non-Mac users
 * Fixes for inline edit in JetBrains

From d0909fb11ab3d5041031cf7a1246c68196a43f0e Mon Sep 17 00:00:00 2001
From: Dallin Romney
Date: Fri, 10 Jan 2025 20:13:49 -0800
Subject: [PATCH 2/4] ollama tool support and docs

---
 core/index.d.ts                               |   1 -
 core/llm/autodetect.ts                        |   8 +-
 core/llm/constructMessages.ts                 |   6 +-
 core/llm/llms/Ollama.ts                       | 245 +++++++++++-------
 core/llm/toolSupport.ts                       |  46 ++++
 docs/docs/customize/tools.md                  |   7 +-
 extensions/vscode/config_schema.json          |   5 -
 .../components/mainInput/ContinueInputBox.tsx |   6 +-
 gui/src/context/IdeMessenger.ts               |   1 +
 gui/src/redux/slices/sessionSlice.ts          | 108 +++++---
 gui/src/redux/store.ts                        |   1 +
 gui/src/redux/thunks/streamNormalInput.ts     |  15 +-
 12 files changed, 290 insertions(+), 159 deletions(-)
 create mode 100644 core/llm/toolSupport.ts

diff --git a/core/index.d.ts b/core/index.d.ts
index e4f3ff143d..8c859c8438 100644
--- a/core/index.d.ts
+++ b/core/index.d.ts
@@ -1066,7 +1066,6 @@ export interface ExperimentalConfig {
    * This is needed to crawl a large number of documentation sites that are dynamically rendered.
    */
   useChromiumForDocsCrawling?: boolean;
-  useTools?: boolean;
   modelContextProtocolServers?: MCPOptions[];
 }
 
diff --git a/core/llm/autodetect.ts b/core/llm/autodetect.ts
index 80c2e095f9..8232f82887 100644
--- a/core/llm/autodetect.ts
+++ b/core/llm/autodetect.ts
@@ -35,6 +35,7 @@ import {
   xWinCoderEditPrompt,
   zephyrEditPrompt,
 } from "./templates/edit.js";
+import { PROVIDER_TOOL_SUPPORT } from "./toolSupport.js";
 
 const PROVIDER_HANDLES_TEMPLATING: string[] = [
   "lmstudio",
@@ -86,17 +87,12 @@ const MODEL_SUPPORTS_IMAGES: string[] = [
   "llama3.2",
 ];
 
-const PROVIDER_TOOL_SUPPORT: Record<string, string[]> = {
-  anthropic: ["claude-3-5", "claude-3.5"],
-  ollama: ["llama3.1"],
-};
-
 function modelSupportsTools(modelName: string, provider: string) {
   const providerSupport = PROVIDER_TOOL_SUPPORT[provider];
   if (!providerSupport) {
     return false;
   }
-  return !!providerSupport.some((part) => modelName.includes(part));
+  return providerSupport(modelName) ?? false;
 }
 
 function modelSupportsImages(
diff --git a/core/llm/constructMessages.ts b/core/llm/constructMessages.ts
index 190ede74ad..5b0d1b60b0 100644
--- a/core/llm/constructMessages.ts
+++ b/core/llm/constructMessages.ts
@@ -72,11 +72,11 @@ function constructSystemPrompt(
   useTools: boolean,
 ): string | null {
   let systemMessage = "";
-  if(CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))){
+  if (CUSTOM_SYS_MSG_MODEL_FAMILIES.some((family) => model.includes(family))) {
     systemMessage = SYSTEM_MESSAGE;
   }
-  if(useTools && modelSupportsTools(model, provider)){
-    if(systemMessage){
+  if (useTools && modelSupportsTools(model, provider)) {
+    if (systemMessage) {
       systemMessage += "\n\n";
     }
     systemMessage += TOOL_USE_RULES;
diff --git a/core/llm/llms/Ollama.ts b/core/llm/llms/Ollama.ts
index f13a46ec88..69db6179b8 100644
--- a/core/llm/llms/Ollama.ts
+++ b/core/llm/llms/Ollama.ts
@@ -1,14 +1,21 @@
-import {
-  ChatMessage,
-  CompletionOptions,
-  LLMOptions,
-  Tool,
-} from "../../index.js";
+import { JSONSchema, JSONSchemaType } from "openai/lib/jsonschema.mjs";
+
+import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
 import { renderChatMessage } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
 import { streamResponse } from "../stream.js";
 
-type OllamaChatMessage = ChatMessage & { images?: string[] | null };
+type OllamaChatMessage = {
+  role: "tool" | "user" | "assistant" | "system";
+  content: string;
+  images?: string[] | null;
+  tool_calls?: {
+    function: {
+      name: string;
+      arguments: Record<string, any>;
+    };
+  }[];
+};
 
 // See https://github.com/ollama/ollama/blob/main/docs/modelfile.md for details on each parameter
 interface OllamaModelFileParams {
@@ -54,7 +61,7 @@ interface OllamaBaseOptions {
   keep_alive?: number; // controls how long the model will stay loaded into memory following the request (default: 5m)
 }
 
-interface OllamaGenerateOptions extends OllamaBaseOptions {
+interface OllamaRawOptions extends OllamaBaseOptions {
   prompt: string; // the prompt to generate a response for
   suffix?: string; // the text after the model response
   images?: string[]; // a list of base64-encoded images (for multimodal models such as llava)
@@ -64,15 +71,6 @@ interface OllamaGenerateOptions extends OllamaBaseOptions {
   raw?: boolean; // if true no formatting will be applied to the prompt. You may choose to use the raw parameter if you are specifying a full templated prompt in your request to the API
 }
 
-interface OllamaTool {
-  type: "function";
-  function: {
-    name: string;
-    description: string;
-    parameters: any;
-  };
-}
-
 interface OllamaChatOptions extends OllamaBaseOptions {
   messages: OllamaChatMessage[]; // the messages of the chat, this can be used to keep a chat memory
   tools?: OllamaTool[]; // the tools of the chat, this can be used to keep a tool memory
   // Not supported yet - tools: tools for the model to use if supported. Requires stream to be set to false
   // And correspondingly, tool calls in OllamaChatMessage
 }
 
@@ -80,6 +78,51 @@ interface OllamaChatOptions extends OllamaBaseOptions {
+type OllamaBaseResponse = {
+  model: string;
+  created_at: string;
+} & (
+  | {
+      done: false;
+    }
+  | {
+      done: true;
+      done_reason: string;
+      total_duration: number; // Time spent generating the response in nanoseconds
+      load_duration: number; // Time spent loading the model in nanoseconds
+      prompt_eval_count: number; // Number of tokens in the prompt
+      prompt_eval_duration: number; // Time spent evaluating the prompt in nanoseconds
+      eval_count: number; // Number of tokens in the response
+      eval_duration: number; // Time spent generating the response in nanoseconds
+      context: number[]; // An encoding of the conversation used in this response; can be sent in the next request to keep conversational memory
+    }
+);
+
+type OllamaErrorResponse = {
+  error: string;
+};
+
+type OllamaRawResponse =
+  | OllamaErrorResponse
+  | (OllamaBaseResponse & {
+      response: string; // the generated response
+    });
+
+type OllamaChatResponse =
+  | OllamaErrorResponse
+  | (OllamaBaseResponse & {
+      message: OllamaChatMessage;
+    });
+
+interface OllamaTool {
+  type: "function";
+  function: {
+    name: string;
+    description?: string;
+    parameters?: JSONSchema;
+  };
+}
+
 class Ollama extends BaseLLM {
   static providerName = "ollama";
   static defaultOptions: Partial<LLMOptions> = {
@@ -224,61 +267,38 @@ class Ollama extends BaseLLM {
     };
   }
 
-  private _convertMessage(message: ChatMessage): OllamaChatMessage {
-    // if (message.role === "tool") {
-    //   return null;
-    // }
-    // if (message.role === "tool") {
-    //   return {
-    //     role: "tool",
-
-    //   };
-    // }
+  private _convertToOllamaMessage(message: ChatMessage): OllamaChatMessage {
+    const ollamaMessage: OllamaChatMessage = {
+      role: message.role,
+      content: "",
+    };
 
     if (typeof message.content === "string") {
-      return message;
-    }
-    const images: string[] = [];
-    message.content.forEach((part) => {
-      if (part.type === "imageUrl" && part.imageUrl) {
-        const image = part.imageUrl?.url.split(",").at(-1);
-        if (image) {
-          images.push(image);
+      ollamaMessage.content = message.content;
+    } else {
+      ollamaMessage.content = renderChatMessage(message);
+      const images: string[] = [];
+      message.content.forEach((part) => {
+        if (part.type === "imageUrl" && part.imageUrl) {
+          const image = part.imageUrl?.url.split(",").at(-1);
+          if (image) {
+            images.push(image);
+          }
         }
+      });
+      if (images.length > 0) {
+        ollamaMessage.images = images;
       }
-    });
-
-    return {
-      role: message.role,
-      content: renderChatMessage(message),
-      images,
-    };
-  }
-
-  private _getChatOptions(
-    options: CompletionOptions,
-    messages: ChatMessage[],
-  ): OllamaChatOptions {
-    const chatOptions: OllamaChatOptions = {
-      model: this._getModel(),
-      messages: messages.map(this._convertMessage).filter(Boolean),
-      options: this._getModelFileParams(options),
-      keep_alive: options.keepAlive ?? 60 * 30, // 30 minutes
-      stream: options.stream,
-      // format: options.format, // Not currently in base completion options
-    };
-    if (options.tools?.length) {
-      chatOptions.tools = options.tools;
-      chatOptions.stream = false; // Cannot set stream = true for tools calls
-    }
-    return chatOptions;
+
+    return ollamaMessage;
   }
 
   private _getGenerateOptions(
     options: CompletionOptions,
     prompt: string,
     suffix?: string,
-  ): OllamaGenerateOptions {
+  ): OllamaRawOptions {
     return {
       model: this._getModel(),
       prompt,
@@ -327,12 +347,11 @@ class Ollama extends BaseLLM {
         const chunk = chunks[i];
         if (chunk.trim() !== "") {
           try {
-            const j = JSON.parse(chunk);
-            if ("response" in j) {
-              yield j.response;
-            } else if ("error" in j) {
+            const j = JSON.parse(chunk) as OllamaRawResponse;
+            if ("error" in j) {
               throw new Error(j.error);
             }
+            yield j.response;
           } catch (e) {
             throw new Error(`Error parsing Ollama response: ${e} ${chunk}`);
           }
@@ -346,39 +365,91 @@ class Ollama extends BaseLLM {
     signal: AbortSignal,
     options: CompletionOptions,
   ): AsyncGenerator<ChatMessage> {
+    const chatOptions: OllamaChatOptions = {
+      model: this._getModel(),
+      messages: messages.map(this._convertToOllamaMessage),
+      options: this._getModelFileParams(options),
+      keep_alive: options.keepAlive ?? 60 * 30, // 30 minutes
+      stream: options.stream,
+      // format: options.format, // Not currently in base completion options
+    };
+    if (options.tools?.length) {
+      chatOptions.tools = options.tools.map((tool) => ({
+        type: "function",
+        function: {
+          name: tool.function.name,
+          description: tool.function.description,
+          parameters: tool.function.parameters,
+        },
+      }));
+      chatOptions.stream = false; // Cannot set stream = true for tools calls
+    }
+
     const response = await this.fetch(this.getEndpoint("api/chat"), {
       method: "POST",
       headers: {
         "Content-Type": "application/json",
         Authorization: `Bearer ${this.apiKey}`,
       },
-      body: JSON.stringify(this._getChatOptions(options, messages)),
+      body: JSON.stringify(chatOptions),
       signal,
     });
 
-    let buffer = "";
-    for await (const value of streamResponse(response)) {
-      // Append the received chunk to the buffer
-      buffer += value;
-      // Split the buffer into individual JSON chunks
-      const chunks = buffer.split("\n");
-      buffer = chunks.pop() ?? "";
""; + function convertChatMessage(res: OllamaChatResponse): ChatMessage { + if ("error" in res) { + throw new Error(res.error); + } + if (res.message.role === "tool") { + throw new Error( + "Unexpected message received from ollama with role = tool", + ); + } + if (res.message.role === "assistant") { + const chatMessage: ChatMessage = { + role: "assistant", + content: res.message.content, + }; + if (res.message.tool_calls) { + // Continue handles the response as a tool call delta but + // But ollama returns the full object in one response with no streaming + chatMessage.toolCalls = res.message.tool_calls.map((tc) => ({ + type: "function", + function: { + name: tc.function.name, + arguments: JSON.stringify(tc.function.arguments), + }, + })); + } + return chatMessage; + } else { + return { + role: res.message.role, + content: res.message.content, + }; + } + } - for (let i = 0; i < chunks.length; i++) { - const chunk = chunks[i]; - if (chunk.trim() !== "") { - try { - const j = JSON.parse(chunk); - if (j.message?.content) { - yield { - role: "assistant", - content: j.message.content, - }; - } else if (j.error) { - throw new Error(j.error); + if (chatOptions.stream === false) { + const json = (await response.json()) as OllamaChatResponse; + yield convertChatMessage(json); + } else { + let buffer = ""; + for await (const value of streamResponse(response)) { + // Append the received chunk to the buffer + buffer += value; + // Split the buffer into individual JSON chunks + const chunks = buffer.split("\n"); + buffer = chunks.pop() ?? ""; + + for (let i = 0; i < chunks.length; i++) { + const chunk = chunks[i]; + if (chunk.trim() !== "") { + try { + const j = JSON.parse(chunk) as OllamaChatResponse; + const chatMessage = convertChatMessage(j); + } catch (e) { + throw new Error(`Error parsing Ollama response: ${e} ${chunk}`); } - } catch (e) { - throw new Error(`Error parsing Ollama response: ${e} ${chunk}`); } } } diff --git a/core/llm/toolSupport.ts b/core/llm/toolSupport.ts new file mode 100644 index 0000000000..da15aacc81 --- /dev/null +++ b/core/llm/toolSupport.ts @@ -0,0 +1,46 @@ +export const PROVIDER_TOOL_SUPPORT: Record< + string, + (model: string) => boolean | undefined +> = { + anthropic: (model) => { + if ( + ["claude-3-5", "claude-3.5"].some((part) => + model.toLowerCase().startsWith(part), + ) + ) { + return true; + } + }, + // https://ollama.com/search?c=tools + ollama: (model) => { + if ( + ["vision", "math", "guard", "mistrallite", "mistral-openorca"].some( + (part) => model.toLowerCase().includes(part), + ) + ) { + return false; + } + if ( + [ + "llama3.3", + "qwq", + "llama3.2", + "llama3.1", + "qwen2", + "mixtral", + "command-r", + "smollm2", + "hermes3", + "athene-v2", + "nemotron", + "llama3-groq", + "granite3", + "aya-expanse", + "firefunction-v2", + "mistral", + ].some((part) => model.toLowerCase().startsWith(part)) + ) { + return true; + } + }, +}; diff --git a/docs/docs/customize/tools.md b/docs/docs/customize/tools.md index 878432c11b..020c38459f 100644 --- a/docs/docs/customize/tools.md +++ b/docs/docs/customize/tools.md @@ -4,7 +4,12 @@ description: Tool use and customization keywords: [tool, use, function calling, claude, automatic] --- -Tools allow Continue to take action in your IDE and beyond (when you give permission). Currently, they are only supported with [Anthropic](./model-providers/top-level/anthropic.md). To use tools, click on the icon in the input toolbar like below. +Tools allow Continue to take action in your IDE and beyond (when you give permission). 
+
+- [Anthropic](./model-providers/top-level/anthropic.md) - see recommended models [here](https://docs.anthropic.com/en/docs/build-with-claude/tool-use#choosing-a-model).
+- [Ollama](./model-providers/top-level/ollama.md) - see recommended models [here](https://ollama.com/search?c=tools).
+
+To use tools, click on the icon in the input toolbar like below.
 
 ![tools](/img/tool-use-example.png)
 
diff --git a/extensions/vscode/config_schema.json b/extensions/vscode/config_schema.json
index d5e327ad44..d74b0976f7 100644
--- a/extensions/vscode/config_schema.json
+++ b/extensions/vscode/config_schema.json
@@ -3166,11 +3166,6 @@
       "description": "Use Chromium to crawl docs locally. Useful if the default Cheerio crawler fails on sites that require JavaScript rendering. Downloads and installs Chromium to `~/.continue/.utils`.",
       "default": false
     },
-    "useTools": {
-      "type": "boolean",
-      "description": "Allow tool use. Currently only supported with Claude 3.5 Sonnet",
-      "default": true
-    },
     "modelContextProtocolServers": {
       "type": "array",
       "items": {
diff --git a/gui/src/components/mainInput/ContinueInputBox.tsx b/gui/src/components/mainInput/ContinueInputBox.tsx
index 941045b31f..e7ca65deb0 100644
--- a/gui/src/components/mainInput/ContinueInputBox.tsx
+++ b/gui/src/components/mainInput/ContinueInputBox.tsx
@@ -81,9 +81,7 @@ function ContinueInputBox(props: ContinueInputBoxProps) {
   const availableContextProviders = useAppSelector(
     (state) => state.config.config.contextProviders,
   );
-  const useTools = useAppSelector(
-    (state) => state.config.config.experimental?.useTools !== false,
-  );
+  const useTools = useAppSelector((state) => state.ui.useTools);
   const editModeState = useAppSelector((state) => state.editModeState);
 
   const filteredSlashCommands = props.isEditMode ? [] : availableSlashCommands;
@@ -115,7 +113,7 @@ function ContinueInputBox(props: ContinueInputBoxProps) {
           enterText: editModeState.editStatus === "accepting" ? "Retry" : "Edit",
         }
       : {
-          hideTools: !useTools,
+          // hideTools: false,
         };
 
   return (
diff --git a/gui/src/context/IdeMessenger.ts b/gui/src/context/IdeMessenger.ts
index 1955f4f4c4..3ae73c2e1a 100644
--- a/gui/src/context/IdeMessenger.ts
+++ b/gui/src/context/IdeMessenger.ts
@@ -186,6 +186,7 @@ export class IdeMessenger implements IIdeMessenger {
       }) => {
         if (event.data.messageId === messageId) {
           const responseData = event.data.data;
+
           if (responseData.done) {
             window.removeEventListener("message", handler);
             done = true;
diff --git a/gui/src/redux/slices/sessionSlice.ts b/gui/src/redux/slices/sessionSlice.ts
index 6fcb0a6225..a0c05caf54 100644
--- a/gui/src/redux/slices/sessionSlice.ts
+++ b/gui/src/redux/slices/sessionSlice.ts
@@ -19,6 +19,8 @@ import {
   Session,
   SessionMetadata,
   ToolCall,
+  ToolCallDelta,
+  ToolCallState,
 } from "core";
 import { NEW_SESSION_TITLE } from "core/util/constants";
 import { incrementalParseJson } from "core/util/incrementalParseJson";
@@ -283,6 +285,28 @@ export const sessionSlice = createSlice({
       if (state.history.length) {
         const lastItem = state.history[state.history.length - 1];
         const lastMessage = lastItem.message;
+
+        function toolCallDeltaToState(
+          toolCallDelta: ToolCallDelta,
+        ): ToolCallState {
+          const [_, parsedArgs] = incrementalParseJson(
+            toolCallDelta.function?.arguments ?? "{}",
+          );
+          return {
+            status: "generating",
+            toolCall: {
+              id: toolCallDelta.id ?? "",
+              type: toolCallDelta.type ?? "function",
+              function: {
+                name: toolCallDelta.function?.name ?? "",
"", + arguments: toolCallDelta.function?.arguments ?? "", + }, + }, + toolCallId: toolCallDelta.id ?? "", + parsedArgs, + }; + } + for (const message of action.payload) { if ( message.role && @@ -301,20 +325,21 @@ export const sessionSlice = createSlice({ contextItems: [], }; if (message.role === "assistant" && message.toolCalls?.[0]) { - const toolCalls = message.toolCalls?.[0]; - if (toolCalls) { - const [_, parsedArgs] = incrementalParseJson( - message.toolCalls[0].function?.arguments ?? "{}", + const toolCallDelta = message.toolCalls[0]; + + if ( + toolCallDelta.id && + toolCallDelta.function?.arguments && + toolCallDelta.function?.name && + toolCallDelta.type + ) { + console.warn( + "Received streamed tool call without required fields", + toolCallDelta, ); - historyItem.toolCallState = { - status: "generating", - toolCall: message.toolCalls[0] as ToolCall, - toolCallId: message.toolCalls[0].id as ToolCall["id"], - parsedArgs, - }; } + historyItem.toolCallState = toolCallDeltaToState(toolCallDelta); } - state.history.push(historyItem); } else { // Add to the existing message @@ -322,40 +347,39 @@ export const sessionSlice = createSlice({ lastMessage.content += renderChatMessage(message); } else if ( message.role === "assistant" && - message.toolCalls && + message.toolCalls?.[0] && lastMessage.role === "assistant" ) { - if (!lastMessage.toolCalls) { - lastMessage.toolCalls = []; + // Intentionally only supporting one tool call for now. + const toolCallDelta = message.toolCalls[0]; + + // Update message tool call with delta data + const newArgs = + (lastMessage.toolCalls?.[0]?.function?.arguments ?? "") + + (toolCallDelta.function?.arguments ?? ""); + if (lastMessage.toolCalls?.[0]) { + lastMessage.toolCalls[0].function = { + name: + toolCallDelta.function?.name ?? + lastMessage.toolCalls[0].function?.name ?? 
+ "", + arguments: newArgs, + }; + } else { + lastMessage.toolCalls = [toolCallDelta]; + } + + // Update current tool call state + if (!lastItem.toolCallState) { + console.warn( + "Received streamed tool call response prior to initial tool call delta", + ); + lastItem.toolCallState = toolCallDeltaToState(toolCallDelta); } - message.toolCalls.forEach((toolCall, i) => { - if (lastMessage.toolCalls!.length <= i) { - lastMessage.toolCalls!.push(toolCall); - } else { - if ( - toolCall?.function?.arguments && - lastMessage?.toolCalls?.[i]?.function?.arguments && - lastItem.toolCallState - ) { - lastMessage.toolCalls[i].function!.arguments += - toolCall.function.arguments; - - const [_, parsedArgs] = incrementalParseJson( - lastMessage.toolCalls[i].function!.arguments!, - ); - - lastItem.toolCallState.parsedArgs = parsedArgs; - lastItem.toolCallState.toolCall.function.arguments += - toolCall.function.arguments; - } else { - console.error( - "Unexpected tool call format received - this message added during gui strict null checks", - message, - lastMessage, - ); - } - } - }); + + const [_, parsedArgs] = incrementalParseJson(newArgs); + lastItem.toolCallState.parsedArgs = parsedArgs; + lastItem.toolCallState.toolCall.function.arguments = newArgs; } } } diff --git a/gui/src/redux/store.ts b/gui/src/redux/store.ts index 8a617d2e89..4432b3be6b 100644 --- a/gui/src/redux/store.ts +++ b/gui/src/redux/store.ts @@ -35,6 +35,7 @@ const saveSubsetFilters = [ // Don't persist any of the edit state for now createFilter("editModeState", []), createFilter("config", ["defaultModelTitle"]), + createFilter("ui", ["toolSettings", "useTools"]), ]; const migrations: MigrationManifest = { diff --git a/gui/src/redux/thunks/streamNormalInput.ts b/gui/src/redux/thunks/streamNormalInput.ts index d09a59fa22..bf3cfbf6eb 100644 --- a/gui/src/redux/thunks/streamNormalInput.ts +++ b/gui/src/redux/thunks/streamNormalInput.ts @@ -10,6 +10,7 @@ import { } from "../slices/sessionSlice"; import { ThunkApiType } from "../store"; import { callTool } from "./callTool"; +import { modelSupportsTools } from "core/llm/autodetect"; export const streamNormalInput = createAsyncThunk< void, @@ -32,17 +33,11 @@ export const streamNormalInput = createAsyncThunk< defaultModel.title, streamAborter.signal, messages, - useTools + useTools && modelSupportsTools(defaultModel.title, defaultModel.provider) ? 
         ? {
-            tools: Object.keys(toolSettings)
-              .filter((tool) => toolSettings[tool] !== "disabled")
-              .map((toolName) =>
-                state.config.config.tools.find(
-                  (tool) => tool.function.name === toolName,
-                ),
-              )
-              .filter((tool) => !!tool)
-              .map((tool) => tool!), // for type safety
+            tools: state.config.config.tools.filter(
+              (tool) => toolSettings[tool.function.name] !== "disabled",
+            ),
           }
         : {},
     );

From 8536f7aabf7d896cd1b35a12b74a84e2ed28bc02 Mon Sep 17 00:00:00 2001
From: Dallin Romney
Date: Fri, 10 Jan 2025 20:19:39 -0800
Subject: [PATCH 3/4] toolsapalooza cleanup

---
 .vscode/launch.json                               | 2 +-
 gui/src/components/mainInput/ContinueInputBox.tsx | 5 +----
 gui/src/context/IdeMessenger.ts                   | 1 -
 3 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 15720f315f..dfe474f40b 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -23,7 +23,7 @@
       "CONTINUE_GLOBAL_DIR": "${workspaceFolder}/extensions/.continue-debug"
       // "staging" for the preview deployment
       // "local" for entirely local development of control plane/proxy
-      // "CONTROL_PLANE_ENV": "local"
+      // "CONTROL_PLANE_ENV": "staging"
     }
   },
   {
diff --git a/gui/src/components/mainInput/ContinueInputBox.tsx b/gui/src/components/mainInput/ContinueInputBox.tsx
index e7ca65deb0..ba6ccef9ca 100644
--- a/gui/src/components/mainInput/ContinueInputBox.tsx
+++ b/gui/src/components/mainInput/ContinueInputBox.tsx
@@ -81,7 +81,6 @@ function ContinueInputBox(props: ContinueInputBoxProps) {
   const availableContextProviders = useAppSelector(
     (state) => state.config.config.contextProviders,
   );
-  const useTools = useAppSelector((state) => state.ui.useTools);
   const editModeState = useAppSelector((state) => state.editModeState);
 
   const filteredSlashCommands = props.isEditMode ? [] : availableSlashCommands;
@@ -112,9 +111,7 @@ function ContinueInputBox(props: ContinueInputBoxProps) {
           hideTools: true,
           enterText: editModeState.editStatus === "accepting" ? "Retry" : "Edit",
         }
-      : {
-          // hideTools: false,
-        };
+      : {};
 
   return (
diff --git a/gui/src/context/IdeMessenger.ts b/gui/src/context/IdeMessenger.ts
index 3ae73c2e1a..1955f4f4c4 100644
--- a/gui/src/context/IdeMessenger.ts
+++ b/gui/src/context/IdeMessenger.ts
@@ -186,7 +186,6 @@ export class IdeMessenger implements IIdeMessenger {
       }) => {
         if (event.data.messageId === messageId) {
           const responseData = event.data.data;
-
           if (responseData.done) {
             window.removeEventListener("message", handler);
             done = true;

From 8293beb986123ba2e74c8c0c677e480b5c614f57 Mon Sep 17 00:00:00 2001
From: Dallin Romney
Date: Fri, 10 Jan 2025 20:26:56 -0800
Subject: [PATCH 4/4] json schema types

---
 core/llm/llms/Ollama.ts | 6 +++---
 core/package-lock.json  | 7 +++++++
 core/package.json       | 1 +
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/core/llm/llms/Ollama.ts b/core/llm/llms/Ollama.ts
index 69db6179b8..16f71cf15b 100644
--- a/core/llm/llms/Ollama.ts
+++ b/core/llm/llms/Ollama.ts
@@ -1,4 +1,4 @@
-import { JSONSchema, JSONSchemaType } from "openai/lib/jsonschema.mjs";
+import { JSONSchema7Object, JSONSchema7Type } from "json-schema";
 
 import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
 import { renderChatMessage } from "../../util/messageContent.js";
@@ -12,7 +12,7 @@ type OllamaChatMessage = {
   tool_calls?: {
     function: {
       name: string;
-      arguments: Record<string, any>;
+      arguments: Record<string, JSONSchema7Type>;
     };
   }[];
 };
@@ -119,7 +119,7 @@ interface OllamaTool {
   function: {
     name: string;
     description?: string;
-    parameters?: JSONSchema;
+    parameters?: JSONSchema7Object;
   };
 }
diff --git a/core/package-lock.json b/core/package-lock.json
index 4117d99cf2..18a318596a 100644
--- a/core/package-lock.json
+++ b/core/package-lock.json
@@ -83,6 +83,7 @@
         "@types/jest": "^29.5.12",
         "@types/jquery": "^3.5.29",
         "@types/jsdom": "^21.1.6",
+        "@types/json-schema": "^7.0.15",
         "@types/mozilla-readability": "^0.2.1",
         "@types/mustache": "^4.2.5",
         "@types/node-fetch": "^2.6.11",
@@ -5311,6 +5312,12 @@
         "parse5": "^7.0.0"
       }
     },
+    "node_modules/@types/json-schema": {
+      "version": "7.0.15",
+      "resolved": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz",
+      "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==",
+      "dev": true
+    },
     "node_modules/@types/json5": {
       "version": "0.0.29",
      "resolved": "https://registry.npmjs.org/@types/json5/-/json5-0.0.29.tgz",
diff --git a/core/package.json b/core/package.json
index eb805363b9..e865e0ddbf 100644
--- a/core/package.json
+++ b/core/package.json
@@ -22,6 +22,7 @@
     "@types/jest": "^29.5.12",
     "@types/jquery": "^3.5.29",
    "@types/jsdom": "^21.1.6",
+    "@types/json-schema": "^7.0.15",
     "@types/mozilla-readability": "^0.2.1",
     "@types/mustache": "^4.2.5",
     "@types/node-fetch": "^2.6.11",
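---

A quick way to exercise the new code path by hand, for anyone reviewing this series locally: the non-streamed tool-call round trip that `_streamChat` now performs can be reproduced against a bare Ollama server. This is an illustrative sketch only, not part of the patches; the `get_weather` tool is hypothetical, and any model from https://ollama.com/search?c=tools can stand in for `llama3.1:8b`.

```ts
// Minimal sketch of the request/response shape handled in core/llm/llms/Ollama.ts.
// Assumes a local Ollama server on the default port; the tool definition below
// is made up for illustration.
const response = await fetch("http://localhost:11434/api/chat", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    model: "llama3.1:8b",
    stream: false, // tool calls require a single response object, as in _streamChat
    messages: [{ role: "user", content: "What is the weather in Berlin?" }],
    tools: [
      {
        type: "function",
        function: {
          name: "get_weather",
          description: "Get the current weather for a city",
          parameters: {
            type: "object",
            properties: { city: { type: "string" } },
            required: ["city"],
          },
        },
      },
    ],
  }),
});

// message.tool_calls is what convertChatMessage() maps onto Continue's
// ChatMessage.toolCalls; note that arguments arrive as an object, which is
// why the patch JSON.stringifies them before handing them to the GUI's
// tool-call delta logic.
const json = await response.json();
console.log(json.message?.tool_calls);
```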