diff --git a/dingllm.lua b/dingllm.lua new file mode 100644 index 0000000..c3efe16 --- /dev/null +++ b/dingllm.lua @@ -0,0 +1,221 @@ + -- This line initializes a module table to hold all functions and variables. +local M = {} +-- Here we're requiring the Job module from plenary for managing asynchronous jobs. +local Job = require 'plenary.job' +-- Creates a unique namespace for Neovim buffer extmarks, used for tracking changes. +local ns_id = vim.api.nvim_create_namespace 'dingllm' + +-- This function retrieves an API key from the environment variables. +local function get_api_key(name) + return os.getenv(name) +end + +-- Retrieves all lines from the start of the buffer up to the current cursor position. +function M.get_lines_until_cursor() + local current_buffer = vim.api.nvim_get_current_buf() + local current_window = vim.api.nvim_get_current_win() + local cursor_position = vim.api.nvim_win_get_cursor(current_window) + local row = cursor_position[1] + + local lines = vim.api.nvim_buf_get_lines(current_buffer, 0, row, true) + + return table.concat(lines, '\n') +end + +-- This function captures the currently selected text in visual mode. +function M.get_visual_selection() + local _, srow, scol = unpack(vim.fn.getpos 'v') + local _, erow, ecol = unpack(vim.fn.getpos '.') + + if vim.fn.mode() == 'V' then + if srow > erow then + return vim.api.nvim_buf_get_lines(0, erow - 1, srow, true) + else + return vim.api.nvim_buf_get_lines(0, srow - 1, erow, true) + end + end + + if vim.fn.mode() == 'v' then + if srow < erow or (srow == erow and scol <= ecol) then + return vim.api.nvim_buf_get_text(0, srow - 1, scol - 1, erow - 1, ecol, {}) + else + return vim.api.nvim_buf_get_text(0, erow - 1, ecol - 1, srow - 1, scol, {}) + end + end + + if vim.fn.mode() == '\22' then + local lines = {} + if srow > erow then + srow, erow = erow, srow + end + if scol > ecol then + scol, ecol = ecol, scol + end + for i = srow, erow do + table.insert(lines, vim.api.nvim_buf_get_text(0, i - 1, math.min(scol - 1, ecol), i - 1, math.max(scol - 1, ecol), {})[1]) + end + return lines + end +end + +-- Constructs curl arguments for the Anthropic API, including API key and data formatting. +function M.make_anthropic_spec_curl_args(opts, prompt, system_prompt) + local url = opts.url + local api_key = opts.api_key_name and get_api_key(opts.api_key_name) + local data = { + system = system_prompt, + messages = { { role = 'user', content = prompt } }, + model = opts.model, + stream = true, + max_tokens = 4096, + } + local args = { '-N', '-X', 'POST', '-H', 'Content-Type: application/json', '-d', vim.json.encode(data) } + if api_key then + table.insert(args, '-H') + table.insert(args, 'x-api-key: ' .. api_key) + table.insert(args, '-H') + table.insert(args, 'anthropic-version: 2023-06-01') + end + table.insert(args, url) + return args +end + + -- add tools call make grok have system prompt like this " you are AARVIS AN ABSOLUTE REAL VERY INTELIGENT SYSTEM , ou assist me on my work. + +function M.make_openai_spec_curl_args(opts, prompt, system_prompt) + local url = opts.url + local api_key = opts.api_key_name and get_api_key(opts.api_key_name) + local data = { + messages = { { role = 'system', content = system_prompt }, { role = 'user', content = prompt } }, + model = opts.model, + temperature = 0.7, + stream = true, + + } + local args = { '-N', '-X', 'POST', '-H', 'Content-Type: application/json', '-d', vim.json.encode(data) } + if api_key then + table.insert(args, '-H') + table.insert(args, 'Authorization: Bearer ' .. api_key) + end + table.insert(args, url) + return args +end + +-- This function writes a string at the position of an extmark in the buffer. +function M.write_string_at_extmark(str, extmark_id) + vim.schedule(function() + local extmark = vim.api.nvim_buf_get_extmark_by_id(0, ns_id, extmark_id, { details = false }) + local row, col = extmark[1], extmark[2] + + vim.cmd 'undojoin' + local lines = vim.split(str, '\n') + vim.api.nvim_buf_set_text(0, row, col, row, col, lines) + end) +end + +-- Determines the prompt to send to the LLM, either from visual selection or up to cursor. +local function get_prompt(opts) + local replace = opts.replace + local visual_lines = M.get_visual_selection() + local prompt = '' + + if visual_lines then + prompt = table.concat(visual_lines, '\n') + if replace then + vim.api.nvim_command 'normal! c ' + else + vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes('', false, true, true), 'nx', false) + end + else + prompt = M.get_lines_until_cursor() + end + + return prompt +end + +-- Processes streaming data from Anthropic API, writing text to the buffer. +function M.handle_anthropic_spec_data(data_stream, extmark_id, event_state) + if event_state == 'content_block_delta' then + local json = vim.json.decode(data_stream) + if json.delta and json.delta.text then + M.write_string_at_extmark(json.delta.text, extmark_id) + end + end +end + +-- Processes streaming data from OpenAI API, similar to the Anthropic function. +function M.handle_openai_spec_data(data_stream, extmark_id) + if data_stream:match '"delta":' then + local json = vim.json.decode(data_stream) + if json.choices and json.choices[1] and json.choices[1].delta then + local content = json.choices[1].delta.content + if content then + M.write_string_at_extmark(content, extmark_id) + end + end + end +end + +-- Sets up an autocommand group for managing LLM interactions. +local group = vim.api.nvim_create_augroup('DING_LLM_AutoGroup', { clear = true }) +local active_job = nil + +-- Main function to invoke an LLM and stream the response into the editor. +function M.invoke_llm_and_stream_into_editor(opts, make_curl_args_fn, handle_data_fn) + vim.api.nvim_clear_autocmds { group = group } + local prompt = get_prompt(opts) + local system_prompt = opts.system_prompt or 'You are a tsundere uwu anime. Yell at me for not setting my configuration for my llm plugin correctly' + local args = make_curl_args_fn(opts, prompt, system_prompt) + local curr_event_state = nil + local crow, _ = unpack(vim.api.nvim_win_get_cursor(0)) + local stream_end_extmark_id = vim.api.nvim_buf_set_extmark(0, ns_id, crow - 1, -1, {}) + + local function parse_and_call(line) + local event = line:match '^event: (.+)$' + if event then + curr_event_state = event + return + end + local data_match = line:match '^data: (.+)$' + if data_match then + handle_data_fn(data_match, stream_end_extmark_id, curr_event_state) + end + end + + if active_job then + active_job:shutdown() + active_job = nil + end + + active_job = Job:new { + command = 'curl', + args = args, + on_stdout = function(_, out) + parse_and_call(out) + end, + on_stderr = function(_, _) end, + on_exit = function() + active_job = nil + end, + } + + active_job:start() + + vim.api.nvim_create_autocmd('User', { + group = group, + pattern = 'DING_LLM_Escape', + callback = function() + if active_job then + active_job:shutdown() + print 'LLM streaming cancelled' + active_job = nil + end + end, + }) + + vim.api.nvim_set_keymap('n', '', ':doautocmd User DING_LLM_Escape', { noremap = true, silent = true }) + return active_job +end + +-- Returns the module table containing all the defined functions. +return M diff --git a/lua/dingllm.lua b/lua/dingllm.lua index 6e35bd9..74606a5 100644 --- a/lua/dingllm.lua +++ b/lua/dingllm.lua @@ -1,5 +1,6 @@ local M = {} local Job = require 'plenary.job' +local ns_id = vim.api.nvim_create_namespace 'dingllm' local function get_api_key(name) return os.getenv(name) @@ -90,20 +91,14 @@ function M.make_openai_spec_curl_args(opts, prompt, system_prompt) return args end -function M.write_string_at_cursor(str) +function M.write_string_at_extmark(str, extmark_id) vim.schedule(function() - local current_window = vim.api.nvim_get_current_win() - local cursor_position = vim.api.nvim_win_get_cursor(current_window) - local row, col = cursor_position[1], cursor_position[2] + local extmark = vim.api.nvim_buf_get_extmark_by_id(0, ns_id, extmark_id, { details = false }) + local row, col = extmark[1], extmark[2] + vim.cmd 'undojoin' local lines = vim.split(str, '\n') - - vim.cmd("undojoin") - vim.api.nvim_put(lines, 'c', true, true) - - local num_lines = #lines - local last_line_length = #lines[num_lines] - vim.api.nvim_win_set_cursor(current_window, { row + num_lines - 1, col + last_line_length }) + vim.api.nvim_buf_set_text(0, row, col, row, col, lines) end) end @@ -115,8 +110,7 @@ local function get_prompt(opts) if visual_lines then prompt = table.concat(visual_lines, '\n') if replace then - vim.api.nvim_command 'normal! d' - vim.api.nvim_command 'normal! k' + vim.api.nvim_command 'normal! c ' else vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes('', false, true, true), 'nx', false) end @@ -127,22 +121,22 @@ local function get_prompt(opts) return prompt end -function M.handle_anthropic_spec_data(data_stream, event_state) +function M.handle_anthropic_spec_data(data_stream, extmark_id, event_state) if event_state == 'content_block_delta' then local json = vim.json.decode(data_stream) if json.delta and json.delta.text then - M.write_string_at_cursor(json.delta.text) + M.write_string_at_extmark(json.delta.text, extmark_id) end end end -function M.handle_openai_spec_data(data_stream) +function M.handle_openai_spec_data(data_stream, extmark_id) if data_stream:match '"delta":' then local json = vim.json.decode(data_stream) if json.choices and json.choices[1] and json.choices[1].delta then local content = json.choices[1].delta.content if content then - M.write_string_at_cursor(content) + M.write_string_at_extmark(content, extmark_id) end end end @@ -157,6 +151,8 @@ function M.invoke_llm_and_stream_into_editor(opts, make_curl_args_fn, handle_dat local system_prompt = opts.system_prompt or 'You are a tsundere uwu anime. Yell at me for not setting my configuration for my llm plugin correctly' local args = make_curl_args_fn(opts, prompt, system_prompt) local curr_event_state = nil + local crow, _ = unpack(vim.api.nvim_win_get_cursor(0)) + local stream_end_extmark_id = vim.api.nvim_buf_set_extmark(0, ns_id, crow - 1, -1, {}) local function parse_and_call(line) local event = line:match '^event: (.+)$' @@ -166,7 +162,7 @@ function M.invoke_llm_and_stream_into_editor(opts, make_curl_args_fn, handle_dat end local data_match = line:match '^data: (.+)$' if data_match then - handle_data_fn(data_match, curr_event_state) + handle_data_fn(data_match, stream_end_extmark_id, curr_event_state) end end