diff --git a/README.md b/README.md index d6d8b24..99f6e45 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Trying to keep things as native as possible - reusing and integrating well with - properly working undo (response can be undone with a single `u`) - **Infinitely extensible** via hook functions specified as part of the config - hooks have access to everything in the plugin and are automatically registered as commands - - see [4. Configuration](#4-configuration) and [Extend functionality](#extend-functionality) sections for details + - see [5. Configuration](#5-configuration) and [Extend functionality](#extend-functionality) sections for details - **Minimum dependencies** (`neovim`, `curl`, `grep` and optionally `sox`) - zero dependencies on other lua plugins to minimize chance of breakage - **ChatGPT like sessions** @@ -121,7 +121,79 @@ The OpenAI API key can be passed to the plugin in multiple ways: If `openai_api_key` is a table, Gp runs it asynchronously to avoid blocking Neovim (password managers can take a second or two). -## 3. Dependencies +## 3. Multiple providers +The following LLM providers are currently supported besides OpenAI: + +- [Ollama](https://github.com/ollama/ollama) for local/offline open-source models. The plugin assumes you have the Ollama service up and running with configured models available (the default Ollama agent uses Llama3). +- [GitHub Copilot](https://github.com/settings/copilot) with a Copilot license ([zbirenbaum/copilot.lua](https://github.com/zbirenbaum/copilot.lua) or [github/copilot.vim](https://github.com/github/copilot.vim) for autocomplete). You can access the underlying GPT-4 model without paying anything extra (essentially unlimited GPT-4 access). +- [Perplexity.ai](https://www.perplexity.ai/pro) Pro users have $5/month free API credits available (the default PPLX agent uses Mixtral-8x7b). 
+- [Anthropic](https://www.anthropic.com/api) to access Claude models, which currently outperform GPT-4 in some benchmarks. +- [Google Gemini](https://ai.google.dev/) with a quite generous free range but some geo-restrictions (EU). +- Any other "OpenAI chat/completions" compatible endpoint (Azure, LM Studio, etc.) + +Below is an example of the relevant configuration part enabling some of these. The `secret` field has the same capabilities as `openai_api_key` (which is still supported for compatibility). + +```lua + providers = { + openai = { + endpoint = "https://api.openai.com/v1/chat/completions", + secret = os.getenv("OPENAI_API_KEY"), + }, + + -- azure = {...}, + + copilot = { + endpoint = "https://api.githubcopilot.com/chat/completions", + secret = { + "bash", + "-c", + "cat ~/.config/github-copilot/hosts.json | sed -e 's/.*oauth_token...//;s/\".*//'", + }, + }, + + pplx = { + endpoint = "https://api.perplexity.ai/chat/completions", + secret = os.getenv("PPLX_API_KEY"), + }, + + ollama = { + endpoint = "http://localhost:11434/v1/chat/completions", + }, + + googleai = { + endpoint = "https://generativelanguage.googleapis.com/v1beta/models/{{model}}:streamGenerateContent?key={{secret}}", + secret = os.getenv("GOOGLEAI_API_KEY"), + }, + + anthropic = { + endpoint = "https://api.anthropic.com/v1/messages", + secret = os.getenv("ANTHROPIC_API_KEY"), + }, + }, +``` + +Each of these providers has some agents preconfigured. Below is an example of how to disable predefined ChatGPT3-5 agent and create a custom one. If the `provider` field is missing, OpenAI is assumed for backward compatibility. + +```lua + agents = { + { + name = "ChatGPT3-5", + disable = true, + }, + { + name = "MyCustomAgent", + provider = "copilot", + chat = true, + command = true, + model = { model = "gpt-4-turbo" }, + system_prompt = "Answer any query with just: Sure thing..", + }, + }, + +``` + + +## 4. 
Dependencies The core plugin only needs `curl` installed to make calls to OpenAI API and `grep` for ChatFinder. So Linux, BSD and Mac OS should be covered. @@ -133,7 +205,7 @@ Voice commands (`:GpWhisper*`) depend on `SoX` (Sound eXchange) to handle audio - Redhat/CentOS: `yum install sox` - NixOS: `nix-env -i sox` -## 4. Configuration +## 5. Configuration Bellow is a linked snippet with the default values, but I suggest starting with minimal config possible (just `openai_api_key` if you don't have `OPENAI_API_KEY` env set up). Defaults change over time to improve things, options might get deprecated and so on - it's better to change only things where the default doesn't fit your needs. diff --git a/lua/gp/config.lua b/lua/gp/config.lua index 4ff72a4..013d6a6 100644 --- a/lua/gp/config.lua +++ b/lua/gp/config.lua @@ -5,6 +5,20 @@ -- Default config -------------------------------------------------------------------------------- +local default_chat_system_prompt = "You are a general AI assistant.\n\n" + .. "The user provided the additional info about how they would like you to respond:\n\n" + .. "- If you're unsure don't guess and say you don't know instead.\n" + .. "- Ask question if you need clarification to provide better answer.\n" + .. "- Think deeply and carefully from first principles step by step.\n" + .. "- Zoom out first to see the big picture and then zoom in to details.\n" + .. "- Use Socratic method to improve your thinking and coding skills.\n" + .. "- Don't elide any code from your output if the answer requires coding.\n" + .. "- Take a deep breath; You've got this!\n" + +local default_code_system_prompt = "You are an AI working as a code editor.\n\n" + .. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n" + .. "START AND END YOUR ANSWER WITH:\n\n```" + local config = { -- Please start with minimal config possible. -- Just openai_api_key if you don't have OPENAI_API_KEY env set up. 
@@ -17,9 +31,51 @@ local config = { -- openai_api_key: "sk-...", -- openai_api_key = os.getenv("env_name.."), openai_api_key = os.getenv("OPENAI_API_KEY"), - -- api endpoint (you can change this to azure endpoint) - openai_api_endpoint = "https://api.openai.com/v1/chat/completions", - -- openai_api_endpoint = "https://$URL.openai.azure.com/openai/deployments/{{model}}/chat/completions?api-version=2023-03-15-preview", + + -- at least one working provider is required + -- to disable a provider set it to empty table like openai = {} + providers = { + -- secrets can be strings or tables with command and arguments + -- secret = { "cat", "path_to/openai_api_key" }, + -- secret = { "bw", "get", "password", "OPENAI_API_KEY" }, + -- secret : "sk-...", + -- secret = os.getenv("env_name.."), + openai = { + endpoint = "https://api.openai.com/v1/chat/completions", + -- secret = os.getenv("OPENAI_API_KEY"), + }, + azure = { + -- endpoint = "https://$URL.openai.azure.com/openai/deployments/{{model}}/chat/completions", + -- secret = os.getenv("AZURE_API_KEY"), + }, + copilot = { + -- endpoint = "https://api.githubcopilot.com/chat/completions", + -- secret = { + -- "bash", + -- "-c", + -- "cat ~/.config/github-copilot/hosts.json | sed -e 's/.*oauth_token...//;s/\".*//'", + -- }, + }, + ollama = { + -- endpoint = "http://localhost:11434/v1/chat/completions", + }, + lmstudio = { + -- endpoint = "http://localhost:1234/v1/chat/completions", + }, + googleai = { + -- endpoint = "https://generativelanguage.googleapis.com/v1beta/models/{{model}}:streamGenerateContent?key={{secret}}", + -- secret = os.getenv("GOOGLEAI_API_KEY"), + }, + pplx = { + -- endpoint = "https://api.perplexity.ai/chat/completions", + -- secret = os.getenv("PPLX_API_KEY"), + }, + anthropic = { + -- endpoint = "https://api.anthropic.com/v1/messages", + -- secret = os.getenv("ANTHROPIC_API_KEY"), + }, + }, + -- prefix for all commands cmd_prefix = "Gp", -- optional curl parameters (for proxy, etc.) 
@@ -36,58 +92,164 @@ local config = { -- agents = { { name = "ChatGPT4" }, ... }, agents = { { - name = "ChatGPT4", + name = "ChatGPT4o", chat = true, command = false, -- string with model name or table with model name and parameters model = { model = "gpt-4o", temperature = 1.1, top_p = 1 }, -- system prompt (use this to specify the persona/role of the AI) - system_prompt = "You are a general AI assistant.\n\n" - .. "The user provided the additional info about how they would like you to respond:\n\n" - .. "- If you're unsure don't guess and say you don't know instead.\n" - .. "- Ask question if you need clarification to provide better answer.\n" - .. "- Think deeply and carefully from first principles step by step.\n" - .. "- Zoom out first to see the big picture and then zoom in to details.\n" - .. "- Use Socratic method to improve your thinking and coding skills.\n" - .. "- Don't elide any code from your output if the answer requires coding.\n" - .. "- Take a deep breath; You've got this!\n", + system_prompt = default_chat_system_prompt, }, { + provider = "openai", name = "ChatGPT3-5", chat = true, command = false, -- string with model name or table with model name and parameters model = { model = "gpt-3.5-turbo", temperature = 1.1, top_p = 1 }, -- system prompt (use this to specify the persona/role of the AI) - system_prompt = "You are a general AI assistant.\n\n" - .. "The user provided the additional info about how they would like you to respond:\n\n" - .. "- If you're unsure don't guess and say you don't know instead.\n" - .. "- Ask question if you need clarification to provide better answer.\n" - .. "- Think deeply and carefully from first principles step by step.\n" - .. "- Zoom out first to see the big picture and then zoom in to details.\n" - .. "- Use Socratic method to improve your thinking and coding skills.\n" - .. "- Don't elide any code from your output if the answer requires coding.\n" - .. 
"- Take a deep breath; You've got this!\n", + system_prompt = default_chat_system_prompt, + }, + { + provider = "copilot", + name = "ChatCopilot", + chat = true, + command = false, + -- string with model name or table with model name and parameters + model = { model = "gpt-4", temperature = 1.1, top_p = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = default_chat_system_prompt, + }, + { + provider = "googleai", + name = "ChatGemini", + chat = true, + command = false, + -- string with model name or table with model name and parameters + model = { model = "gemini-pro", temperature = 1.1, top_p = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = default_chat_system_prompt, + }, + { + provider = "pplx", + name = "ChatPerplexityMixtral", + chat = true, + command = false, + -- string with model name or table with model name and parameters + model = { model = "mixtral-8x7b-instruct", temperature = 1.1, top_p = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = default_chat_system_prompt, + }, + { + provider = "anthropic", + name = "ChatClaude-3-Haiku", + chat = true, + command = false, + -- string with model name or table with model name and parameters + model = { model = "claude-3-haiku-20240307", temperature = 0.8, top_p = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = default_chat_system_prompt, + }, + { + provider = "ollama", + name = "ChatOllamaLlama3", + chat = true, + command = false, + -- string with model name or table with model name and parameters + model = { + model = "llama3", + num_ctx = 8192, + }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = "You are a general AI assistant.", }, { - name = "CodeGPT4", + provider = "lmstudio", + name = "ChatLMStudio", + chat = true, + command = false, + -- string with model name or table with model name and parameters + 
model = { + model = "dummy", + temperature = 0.97, + top_p = 1, + num_ctx = 8192, + }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = "You are a general AI assistant.", + }, + { + provider = "openai", + name = "CodeGPT4o", chat = false, command = true, -- string with model name or table with model name and parameters model = { model = "gpt-4o", temperature = 0.8, top_p = 1 }, -- system prompt (use this to specify the persona/role of the AI) - system_prompt = "You are an AI working as a code editor.\n\n" - .. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n" - .. "START AND END YOUR ANSWER WITH:\n\n```", + system_prompt = default_code_system_prompt, }, { + provider = "openai", name = "CodeGPT3-5", chat = false, command = true, -- string with model name or table with model name and parameters model = { model = "gpt-3.5-turbo", temperature = 0.8, top_p = 1 }, -- system prompt (use this to specify the persona/role of the AI) - system_prompt = "You are an AI working as a code editor.\n\n" + system_prompt = default_code_system_prompt, + }, + { + provider = "copilot", + name = "CodeCopilot", + chat = false, + command = true, + -- string with the Copilot engine name or table with engine name and parameters if applicable + model = { model = "gpt-4", temperature = 0.8, top_p = 1, n = 1 }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = default_code_system_prompt, + }, + { + provider = "googleai", + name = "CodeGemini", + chat = false, + command = true, + -- string with model name or table with model name and parameters + model = { model = "gemini-pro", temperature = 0.8, top_p = 1 }, + system_prompt = default_code_system_prompt, + }, + { + provider = "pplx", + name = "CodePerplexityMixtral", + chat = false, + command = true, + -- string with model name or table with model name and parameters + model = { model = "mixtral-8x7b-instruct", temperature = 0.8, top_p = 1 }, + system_prompt = 
default_code_system_prompt, + }, + { + provider = "anthropic", + name = "CodeClaude-3-Haiku", + chat = false, + command = true, + -- string with model name or table with model name and parameters + model = { model = "claude-3-haiku-20240307", temperature = 0.8, top_p = 1 }, + system_prompt = default_code_system_prompt, + }, + { + provider = "ollama", + name = "CodeOllamaLlama3", + chat = false, + command = true, + -- string with the Copilot engine name or table with engine name and parameters if applicable + model = { + model = "llama3", + temperature = 1.9, + top_p = 1, + num_ctx = 8192, + }, + -- system prompt (use this to specify the persona/role of the AI) + system_prompt = "You are an AI working as a code editor providing answers.\n\n" + .. "Use 4 SPACES FOR INDENTATION.\n" .. "Please AVOID COMMENTARY OUTSIDE OF THE SNIPPET RESPONSE.\n" .. "START AND END YOUR ANSWER WITH:\n\n```", }, @@ -96,7 +258,7 @@ local config = { -- directory for storing chat files chat_dir = vim.fn.stdpath("data"):gsub("/$", "") .. "/gp/chats", -- chat user prompt prefix - chat_user_prefix = "🗨:", + chat_user_prefix = "💬:", -- chat assistant prompt prefix (static string or a table {static, template}) -- first string has to be static, second string can contain template {{agent}} -- just a static string is legacy and the [{{agent}}] element is added automatically @@ -106,7 +268,6 @@ local config = { chat_topic_gen_prompt = "Summarize the topic of our conversation above" .. " in two or three words. Respond only with those words.", -- chat topic model (string with model name or table with model name and parameters) - chat_topic_gen_model = "gpt-3.5-turbo-16k", -- explicitly confirm deletion of a chat file chat_confirm_delete = true, -- conceal model parameters in chat @@ -301,6 +462,12 @@ local config = { local copy = vim.deepcopy(plugin) local key = copy.config.openai_api_key copy.config.openai_api_key = key:sub(1, 3) .. string.rep("*", #key - 6) .. 
key:sub(-3) + for provider, _ in pairs(copy.providers) do + local s = copy.providers[provider].secret + if s and type(s) == "string" then + copy.providers[provider].secret = s:sub(1, 3) .. string.rep("*", #s - 6) .. s:sub(-3) + end + end local plugin_info = string.format("Plugin structure:\n%s", vim.inspect(copy)) local params_info = string.format("Command params:\n%s", vim.inspect(params)) local lines = vim.split(plugin_info .. "\n" .. params_info, "\n") diff --git a/lua/gp/health.lua b/lua/gp/health.lua index 719ec3b..6d66a2b 100644 --- a/lua/gp/health.lua +++ b/lua/gp/health.lua @@ -15,6 +15,7 @@ function M.check() vim.health.error("require('gp').setup() has not been called") end + --TODO: obsolete ---@diagnostic disable-next-line: undefined-field local api_key = gp.config.openai_api_key diff --git a/lua/gp/init.lua b/lua/gp/init.lua index 4cc668c..c7ae41b 100644 --- a/lua/gp/init.lua +++ b/lua/gp/init.lua @@ -18,6 +18,24 @@ local deprecated = { command_prompt_prefix = "`command_prompt_prefix`\nPlease use `command_prompt_prefix_template`" .. " with support for \n`{{agent}}` variable so you know which agent is currently active", whisper_max_time = "`whisper_max_time`\nPlease use fully customizable `whisper_rec_cmd`", + + openai_api_endpoint = "`openai_api_endpoint`\n\n" + .. "********************************************************************************\n" + .. "********************************************************************************\n" + .. "Gp.nvim finally supports multiple LLM providers; sorry it took so long.\n" + .. "I've dreaded merging this, because I hate breaking people's setups.\n" + .. "But this change is necessary for future improvements.\n\n" + .. "Migration hints are below; for more help, try the readme docs or open an issue.\n" + .. "********************************************************************************\n" + .. "********************************************************************************\n\n" + .. 
"If you're using the `https://api.openai.com/v1/chat/completions` endpoint,\n" + .. "just drop `openai_api_endpoint` in your config and you're done." + .. "\n\nOtherwise sorry for probably breaking your setup, " + .. "please use `endpoint` and `secret` fields in:\n\nproviders " + .. "= {\n openai = {\n endpoint = '...',\n secret = '...'\n }," + .. "\n -- azure = {...},\n -- copilot = {...},\n -- ollama = {...},\n -- googleai= {...},\n -- pplx = {...},\n -- anthropic = {...},\n},\n" + .. "\nThe `openai_api_key` is still supported for backwards compatibility,\n" + .. "and automatically converted to `providers.openai.secret` if the new config is not set.", } -------------------------------------------------------------------------------- @@ -657,6 +675,55 @@ M.append_selection = function(params, origin_buf, target_buf) vim.api.nvim_buf_set_lines(target_buf, last_content_line, -1, false, lines) end +function M.refresh_copilot_bearer() + if not M.providers.copilot or not M.providers.copilot.secret then + return + end + local secret = M.providers.copilot.secret + + if type(secret) == "table" then + return + end + + local bearer = M._state.copilot_bearer or {} + if bearer.token and bearer.expires_at and bearer.expires_at > os.time() then + return + end + + local curl_params = vim.deepcopy(M.config.curl_params or {}) + local args = { + "-s", + "-v", + "https://api.github.com/copilot_internal/v2/token", + "-H", + "Content-Type: application/json", + "-H", + "accept: */*", + "-H", + "authorization: token " .. 
secret, + "-H", + "editor-version: vscode/1.90.2", + "-H", + "editor-plugin-version: copilot-chat/0.17.2024062801", + "-H", + "user-agent: GitHubCopilotChat/0.17.2024062801", + } + + for _, arg in ipairs(args) do + table.insert(curl_params, arg) + end + + M._H.process(nil, "curl", curl_params, function(code, signal, stdout, stderr) + if code ~= 0 then + M.error(string.format("Copilot bearer resolve exited: %d, %d, %s", code, signal, tostring(stderr))) + return + end + + M._state.copilot_bearer = vim.json.decode(stdout) + M.refresh_state() + end, nil, nil) +end + -- setup function M._setup_called = false ---@param opts table | nil # table with options @@ -676,12 +743,12 @@ M.setup = function(opts) M.config = vim.deepcopy(config) -- merge nested tables - local mergeTables = { "hooks", "agents", "image_agents" } + local mergeTables = { "hooks", "agents", "image_agents", "providers" } for _, tbl in ipairs(mergeTables) do M[tbl] = M[tbl] or {} ---@diagnostic disable-next-line: param-type-mismatch for k, v in pairs(M.config[tbl]) do - if tbl == "hooks" then + if tbl == "hooks" or tbl == "providers" then M[tbl][k] = v elseif tbl == "agents" or tbl == "image_agents" then M[tbl][v.name] = v @@ -693,6 +760,14 @@ M.setup = function(opts) for k, v in pairs(opts[tbl]) do if tbl == "hooks" then M[tbl][k] = v + elseif tbl == "providers" then + M[tbl][k] = M[tbl][k] or {} + for pk, pv in pairs(v) do + M[tbl][k][pk] = pv + end + if next(v) == nil then + M[tbl][k] = nil + end elseif tbl == "agents" or tbl == "image_agents" then M[tbl][v.name] = v end @@ -753,15 +828,30 @@ M.setup = function(opts) end end + -- remove invalid providers + for name, provider in pairs(M.providers) do + if type(provider) ~= "table" or not provider.endpoint then + M.providers[name] = nil + end + end + -- prepare agent completions M._chat_agents = {} M._command_agents = {} for name, agent in pairs(M.agents) do - if agent.command then - table.insert(M._command_agents, name) + if not M.agents[name].provider then + 
M.agents[name].provider = "openai" end - if agent.chat then - table.insert(M._chat_agents, name) + + if M.providers[M.agents[name].provider] then + if agent.command then + table.insert(M._command_agents, name) + end + if agent.chat then + table.insert(M._chat_agents, name) + end + else + M.agents[name] = nil end end table.sort(M._chat_agents) @@ -828,9 +918,48 @@ M.setup = function(opts) M.error("curl is not installed, run :checkhealth gp") end - if type(M.config.openai_api_key) == "table" then + for name, _ in pairs(M.providers) do + M.resolve_secret(name) + end + if not M.providers.openai then + M.providers.openai = {} + M.resolve_secret("openai", function() + M.providers.openai = nil + end) + end +end + +---@param provider string # provider name +function M.resolve_secret(provider, callback) + local post_process = function() + local p = M.providers[provider] + if p.secret and type(p.secret) == "string" then + p.secret = p.secret:gsub("^%s*(.-)%s*$", "%1") + end + + if provider == "copilot" then + M.refresh_copilot_bearer() + end + + -- backwards compatibility + if provider == "openai" then + M.config.openai_api_key = M.providers[provider].secret + end + + if callback then + callback() + end + end + + -- backwards compatibility + if provider == "openai" then + M.providers[provider].secret = M.providers[provider].secret or M.config.openai_api_key + end + + local secret = M.providers[provider].secret + if secret and type(secret) == "table" then ---@diagnostic disable-next-line: param-type-mismatch - local copy = vim.deepcopy(M.config.openai_api_key) + local copy = vim.deepcopy(secret) ---@diagnostic disable-next-line: param-type-mismatch local cmd = table.remove(copy, 1) local args = copy @@ -840,18 +969,23 @@ M.setup = function(opts) local content = stdout_data:match("^%s*(.-)%s*$") if not string.match(content, "%S") then M.warning( - "response from the config.openai_api_key command " - .. vim.inspect(M.config.openai_api_key) + "response from the config.providers." + .. 
provider + .. ".secret command " + .. vim.inspect(secret) .. " is empty" ) return end - M.config.openai_api_key = content + M.providers[provider].secret = content + post_process() else M.warning( - "config.openai_api_key command " - .. vim.inspect(M.config.openai_api_key) - .. " to retrieve openai_api_key failed:\ncode: " + "config.providers." + .. provider + .. ".secret command " + .. vim.inspect(secret) + .. " to retrieve the secret failed:\ncode: " .. code .. ", signal: " .. signal @@ -863,10 +997,11 @@ M.setup = function(opts) end end) else - M.valid_api_key() + post_process() end end +-- TODO: obsolete M.valid_api_key = function() local api_key = M.config.openai_api_key @@ -906,9 +1041,20 @@ M.refresh_state = function() M._state.image_agent = M._image_agents[1] end + local bearer = M._state.copilot_bearer or state.copilot_bearer or nil + if bearer and bearer.expires_at and bearer.expires_at < os.time() then + bearer = nil + M.refresh_copilot_bearer() + end + M._state.copilot_bearer = bearer + M.table_to_file(M._state, state_file) M.prepare_commands() + + local buf = vim.api.nvim_get_current_buf() + local file_name = vim.api.nvim_buf_get_name(buf) + M.display_chat_agent(buf, file_name) end M.Target = { @@ -974,7 +1120,16 @@ M.prepare_commands = function() template = M.config.template_prepend end end - M.Prompt(params, target, agent.cmd_prefix, agent.model, template, agent.system_prompt, whisper) + M.Prompt( + params, + target, + agent.cmd_prefix, + agent.model, + template, + agent.system_prompt, + whisper, + agent.provider + ) end M.cmd[command] = function(params) @@ -1002,10 +1157,10 @@ end ---@param messages table ---@param model string | table | nil ---@param default_model string | table -M.prepare_payload = function(messages, model, default_model) +---@param provider string | nil +M.prepare_payload = function(messages, model, default_model, provider) model = model or default_model - -- if model is a string if type(model) == "string" then return { model = 
model, @@ -1014,7 +1169,89 @@ M.prepare_payload = function(messages, model, default_model) } end - -- if model is a table + if provider == "googleai" then + for i, message in ipairs(messages) do + if message.role == "system" then + messages[i].role = "user" + end + if message.role == "assistant" then + messages[i].role = "model" + end + if message.content then + messages[i].parts = { + { + text = message.content, + }, + } + messages[i].content = nil + end + end + local i = 1 + while i < #messages do + if messages[i].role == messages[i + 1].role then + table.insert(messages[i].parts, { + text = messages[i + 1].parts[1].text, + }) + table.remove(messages, i + 1) + else + i = i + 1 + end + end + local payload = { + contents = messages, + safetySettings = { + { + category = "HARM_CATEGORY_HARASSMENT", + threshold = "BLOCK_NONE", + }, + { + category = "HARM_CATEGORY_HATE_SPEECH", + threshold = "BLOCK_NONE", + }, + { + category = "HARM_CATEGORY_SEXUALLY_EXPLICIT", + threshold = "BLOCK_NONE", + }, + { + category = "HARM_CATEGORY_DANGEROUS_CONTENT", + threshold = "BLOCK_NONE", + }, + }, + generationConfig = { + temperature = math.max(0, math.min(2, model.temperature or 1)), + maxOutputTokens = model.max_tokens or 8192, + topP = math.max(0, math.min(1, model.top_p or 1)), + topK = model.top_k or 100, + }, + model = model.model, + } + return payload + end + + if provider == "anthropic" then + local system = "" + local i = 1 + while i < #messages do + if messages[i].role == "system" then + system = system .. messages[i].content .. 
"\n" + table.remove(messages, i) + else + i = i + 1 + end + end + + local payload = { + model = model.model, + stream = true, + messages = messages, + system = system, + max_tokens = model.max_tokens or 4096, + temperature = math.max(0, math.min(2, model.temperature or 1)), + top_p = math.max(0, math.min(1, model.top_p or 1)), + } + return payload + end + return { model = model.model, stream = true, @@ -1057,10 +1294,11 @@ end -- gpt query ---@param buf number | nil # buffer number ----@param payload table # payload for openai api +---@param provider string # provider name +---@param payload table # payload for api ---@param handler function # response handler ---@param on_exit function | nil # optional on_exit handler -M.query = function(buf, payload, handler, on_exit) +M.query = function(buf, provider, payload, handler, on_exit) -- make sure handler is a function if type(handler) ~= "function" then M.error( @@ -1077,6 +1315,7 @@ M.query = function(buf, payload, handler, on_exit) M._queries[qid] = { timestamp = os.time(), buf = buf, + provider = provider, payload = payload, handler = handler, on_exit = on_exit, @@ -1106,14 +1345,36 @@ M.query = function(buf, payload, handler, on_exit) qt.raw_response = qt.raw_response .. line .. "\n" end line = line:gsub("^data: ", "") - if line:match("chat%.completion%.chunk") then + local content = "" + if line:match("choices") and line:match("delta") and line:match("content") then line = vim.json.decode(line) - local content = line.choices[1].delta.content - if content ~= nil then - qt.response = qt.response .. 
content - handler(qid, content) + if line.choices[1] and line.choices[1].delta and line.choices[1].delta.content then + content = line.choices[1].delta.content + end + end + + if qt.provider == "anthropic" and line:match('"text":') then + if line:match("content_block_start") or line:match("content_block_delta") then + line = vim.json.decode(line) + if line.delta and line.delta.text then + content = line.delta.text + end + if line.content_block and line.content_block.text then + content = line.content_block.text + end + end + end + + if qt.provider == "googleai" then + if line:match('"text":') then + content = vim.json.decode("{" .. line .. "}").text end end + + if content and type(content) == "string" then + qt.response = qt.response .. content + handler(qid, content) + end end end @@ -1125,7 +1386,7 @@ M.query = function(buf, payload, handler, on_exit) end if err then - M.error("OpenAI query stdout error: " .. vim.inspect(err)) + M.error(qt.provider .. " query stdout error: " .. vim.inspect(err)) elseif chunk then -- add the incoming chunk to the buffer buffer = buffer .. chunk @@ -1145,7 +1406,7 @@ M.query = function(buf, payload, handler, on_exit) end if qt.response == "" then - M.error("OpenAI query response is empty: \n" .. vim.inspect(qt.raw_response)) + M.error(qt.provider .. " response is empty: \n" .. 
vim.inspect(qt.raw_response)) end -- optional on_exit handler @@ -1161,8 +1422,65 @@ M.query = function(buf, payload, handler, on_exit) end end - -- try to replace model in endpoint (for azure) - local endpoint = M._H.template_replace(M.config.openai_api_endpoint, "{{model}}", payload.model) + ---TODO: this could be moved to a separate function returning endpoint and headers + local endpoint = M.providers[provider].endpoint + local bearer = M.providers[provider].secret + local headers = {} + + if provider == "copilot" then + M.refresh_copilot_bearer() + ---@diagnostic disable-next-line: undefined-field + bearer = (M._state.copilot_bearer or {}).token or "" -- bearer is nil until the async refresh has completed + headers = { + "-H", + "editor-version: vscode/1.85.1", + "-H", + "Authorization: Bearer " .. bearer, + } + end + + if provider == "openai" then + headers = { + "-H", + "Authorization: Bearer " .. bearer, + -- backwards compatibility + "-H", + "api-key: " .. bearer, + } + end + + if provider == "pplx" then + headers = { + "-H", + "Authorization: Bearer " .. bearer, + } + end + + if provider == "googleai" then + headers = {} + endpoint = M._H.template_replace(endpoint, "{{secret}}", bearer) + endpoint = M._H.template_replace(endpoint, "{{model}}", payload.model) + payload.model = nil + end + + if provider == "anthropic" then + headers = { + "-H", + "x-api-key: " .. bearer, + "-H", + "anthropic-version: 2023-06-01", + "-H", + "anthropic-beta: messages-2023-12-15", + } + end + + if provider == "azure" then + headers = { + "-H", + "api-key: " .. bearer, + } + endpoint = M._H.template_replace(endpoint, "{{model}}", payload.model) + end local curl_params = vim.deepcopy(M.config.curl_params or {}) local args = { @@ -1171,11 +1489,6 @@ M.query = function(buf, payload, handler, on_exit) endpoint, "-H", "Content-Type: application/json", - -- api-key is for azure, authorization is for openai - "-H", - "Authorization: Bearer " .. M.config.openai_api_key, - "-H", - "api-key: " .. 
M.config.openai_api_key, "-d", vim.json.encode(payload), --[[ "--doesnt_exist" ]] @@ -1185,6 +1498,10 @@ M.query = function(buf, payload, handler, on_exit) table.insert(curl_params, arg) end + for _, header in ipairs(headers) do + table.insert(curl_params, header) + end + M._H.process(buf, "curl", curl_params, nil, out_reader(), nil) end @@ -1397,6 +1714,29 @@ M.not_chat = function(buf, file_name) return nil end +M.display_chat_agent = function(buf, file_name) + if M.not_chat(buf, file_name) then + return + end + + if buf ~= vim.api.nvim_get_current_buf() then + return + end + + local ns_id = vim.api.nvim_create_namespace("GpChatExt_" .. file_name) + vim.api.nvim_buf_clear_namespace(buf, ns_id, 0, -1) + + vim.api.nvim_buf_set_extmark(buf, ns_id, 0, 0, { + strict = false, + right_gravity = true, + virt_text_pos = "right_align", + virt_text = { + { "Current Agent: [" .. M._state.chat_agent .. "]", "DiagnosticHint" }, + }, + hl_mode = "combine", + }) +end + M.prep_chat = function(buf, file_name) if M.not_chat(buf, file_name) then return @@ -1495,8 +1835,21 @@ M.buf_handler = function() local file_name = vim.api.nvim_buf_get_name(buf) M.prep_chat(buf, file_name) + M.display_chat_agent(buf, file_name) M.prep_context(buf, file_name) end, gid) + + _H.autocmd({ "WinEnter" }, nil, function(event) + local buf = event.buf + + if not vim.api.nvim_buf_is_valid(buf) then + return + end + + local file_name = vim.api.nvim_buf_get_name(buf) + + M.display_chat_agent(buf, file_name) + end, gid) end M.BufTarget = { @@ -1922,12 +2275,17 @@ M.chat_respond = function(params) ---@diagnostic disable-next-line: cast-local-type agent_suffix = M._H.template_render(agent_suffix, { ["{{agent}}"] = agent_name }) + local old_default_user_prefix = "🗨:" for index = start_index, end_index do local line = lines[index] if line:sub(1, #M.config.chat_user_prefix) == M.config.chat_user_prefix then table.insert(messages, { role = role, content = content }) role = "user" content = 
line:sub(#M.config.chat_user_prefix + 1) + elseif line:sub(1, #old_default_user_prefix) == old_default_user_prefix then + table.insert(messages, { role = role, content = content }) + role = "user" + content = line:sub(#old_default_user_prefix + 1) elseif line:sub(1, #agent_prefix) == agent_prefix then table.insert(messages, { role = role, content = content }) role = "assistant" @@ -1970,7 +2328,8 @@ M.chat_respond = function(params) -- call the model and write response M.query( buf, - M.prepare_payload(messages, headers.model, agent.model), + agent.provider, + M.prepare_payload(messages, headers.model, agent.model, agent.provider), M.create_handler(buf, win, M._H.last_content_line(buf), true, "", not M.config.chat_free_cursor), vim.schedule_wrap(function(qid) local qt = M.get_query(qid) @@ -2012,7 +2371,8 @@ M.chat_respond = function(params) -- call the model M.query( nil, - M.prepare_payload(messages, nil, M.config.chat_topic_gen_model), + agent.provider, + M.prepare_payload(messages, nil, agent.model, agent.provider), topic_handler, vim.schedule_wrap(function() -- get topic from invisible buffer @@ -2418,20 +2778,24 @@ M.cmd.NextAgent = function() agent_list = M._command_agents end + local set_agent = function(agent_name) + if is_chat then + M._state.chat_agent = agent_name + M.info("Chat agent: " .. agent_name) + else + M._state.command_agent = agent_name + M.info("Command agent: " .. agent_name) + end + M.refresh_state() + end + for i, agent_name in ipairs(agent_list) do if agent_name == current_agent then - local next_agent = agent_list[i % #agent_list + 1] - if is_chat then - M._state.chat_agent = next_agent - M.info("Chat agent: " .. next_agent) - else - M._state.command_agent = next_agent - M.info("Command agent: " .. 
next_agent) - end - M.refresh_state() + set_agent(agent_list[i % #agent_list + 1]) return end end + set_agent(agent_list[1]) end ---@return table # { cmd_prefix, name, model, system_prompt } @@ -2441,7 +2805,14 @@ M.get_command_agent = function() local name = M._state.command_agent local model = M.agents[name].model local system_prompt = M.agents[name].system_prompt - return { cmd_prefix = cmd_prefix, name = name, model = model, system_prompt = system_prompt } + local provider = M.agents[name].provider + return { + cmd_prefix = cmd_prefix, + name = name, + model = model, + system_prompt = system_prompt, + provider = provider, + } end ---@return table # { cmd_prefix, name, model, system_prompt } @@ -2451,7 +2822,14 @@ M.get_chat_agent = function() local name = M._state.chat_agent local model = M.agents[name].model local system_prompt = M.agents[name].system_prompt - return { cmd_prefix = cmd_prefix, name = name, model = model, system_prompt = system_prompt } + local provider = M.agents[name].provider + return { + cmd_prefix = cmd_prefix, + name = name, + model = model, + system_prompt = system_prompt, + provider = provider, + } end M.cmd.Context = function(params) @@ -2496,7 +2874,7 @@ M.cmd.Context = function(params) M._H.feedkeys("G", "xn") end -M.Prompt = function(params, target, prompt, model, template, system_template, whisper) +M.Prompt = function(params, target, prompt, model, template, system_template, whisper, provider) -- enew, new, vnew, tabnew should be resolved into table if type(target) == "function" then target = target() @@ -2764,7 +3142,8 @@ M.Prompt = function(params, target, prompt, model, template, system_template, wh local agent = M.get_command_agent() M.query( buf, - M.prepare_payload(messages, model, agent.model), + provider, + M.prepare_payload(messages, model, agent.model, agent.provider), handler, vim.schedule_wrap(function(qid) on_exit(qid)