diff --git a/lua/parrot/chat_handler.lua b/lua/parrot/chat_handler.lua index dbd4b2e..81f69e2 100644 --- a/lua/parrot/chat_handler.lua +++ b/lua/parrot/chat_handler.lua @@ -10,6 +10,7 @@ local init_provider = require("parrot.provider").init_provider local Spinner = require("parrot.spinner") local Job = require("plenary.job") local pft = require("plenary.filetype") +local ResponseHandler = require("parrot.response_handler") local ChatHandler = {} @@ -731,15 +732,9 @@ function ChatHandler:_chat_respond(params) buf, query_prov, utils.prepare_payload(messages, model_obj.name, self.providers[query_prov.name].params["chat"]), - chatutils.create_handler( - self.queries, - buf, - win, - utils.last_content_line(buf), - true, - "", - not self.options.chat_free_cursor - ), + ResponseHandler + :new(self.queries, buf, win, utils.last_content_line(buf), true, "", not self.options.chat_free_cursor) + :create_handler(), vim.schedule_wrap(function(qid) if self.options.enable_spinner and spinner then spinner:stop() @@ -783,7 +778,8 @@ function ChatHandler:_chat_respond(params) -- prepare invisible buffer for the model to write to local topic_buf = vim.api.nvim_create_buf(false, true) - local topic_handler = chatutils.create_handler(self.queries, topic_buf, nil, 0, false, "", false) + local topic_resp_handler = ResponseHandler:new(self.queries, topic_buf, nil, 0, false, "", false) + local topic_handler = topic_resp_handler:create_handler() topic_prov:set_model(self.providers[topic_prov.name].topic.model) local topic_spinner = nil @@ -1257,21 +1253,21 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template) -- delete selection vim.api.nvim_buf_set_lines(buf, start_line - 1, end_line - 1, false, {}) -- prepare handler - handler = chatutils.create_handler(self.queries, buf, win, start_line - 1, true, prefix, cursor) + handler = ResponseHandler:new(self.queries, buf, win, start_line - 1, true, prefix, cursor):create_handler() elseif target == ui.Target.append then -- move cursor to the end of the selection vim.api.nvim_win_set_cursor(0, { end_line, 0 }) -- put newline after selection vim.api.nvim_put({ "" }, "l", true, true) -- prepare handler - handler = chatutils.create_handler(self.queries, buf, win, end_line, true, prefix, cursor) + handler = ResponseHandler:new(self.queries, buf, win, end_line, true, prefix, cursor):create_handler() elseif target == ui.Target.prepend then -- move cursor to the start of the selection vim.api.nvim_win_set_cursor(0, { start_line, 0 }) -- put newline before selection vim.api.nvim_put({ "" }, "l", false, true) -- prepare handler - handler = chatutils.create_handler(self.queries, buf, win, start_line - 1, true, prefix, cursor) + handler = ResponseHandler:new(self.queries, buf, win, start_line - 1, true, prefix, cursor):create_handler() elseif target == ui.Target.popup then self:toggle_close(self._toggle_kind.popup) -- create a new buffer @@ -1299,7 +1295,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template) -- better text wrapping vim.api.nvim_command("setlocal wrap linebreak") -- prepare handler - handler = chatutils.create_handler(self.queries, buf, win, 0, false, "", false) + handler = ResponseHandler:new(self.queries, buf, win, 0, false, "", false):create_handler() self:toggle_add(self._toggle_kind.popup, { win = win, buf = buf, close = popup_close }) elseif type(target) == "table" then if target.type == ui.Target.new().type then @@ -1331,7 +1327,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template) local ft = target.filetype or filetype vim.api.nvim_set_option_value("filetype", ft, { buf = buf }) - handler = chatutils.create_handler(self.queries, buf, win, 0, false, "", cursor) + handler = ResponseHandler:new(self.queries, buf, win, 0, false, "", cursor):create_handler() end -- call the model and write the response diff --git a/lua/parrot/chat_utils.lua b/lua/parrot/chat_utils.lua index 22980f4..3001923 100644 --- a/lua/parrot/chat_utils.lua +++ b/lua/parrot/chat_utils.lua @@ -26,95 +26,6 @@ M.resolve_buf_target = function(params) end end --- response handler ----@param buf number | nil # buffer to insert response into ----@param win number | nil # window to insert response into ----@param line number | nil # line to insert response into ----@param first_undojoin boolean | nil # whether to skip first undojoin ----@param prefix string | nil # prefix to insert before each response line ----@param cursor boolean # whether to move cursor to the end of the response -M.create_handler = function(queries, buf, win, line, first_undojoin, prefix, cursor) - buf = buf or vim.api.nvim_get_current_buf() - prefix = prefix or "" - local first_line = line or vim.api.nvim_win_get_cursor(win)[1] - 1 - local finished_lines = 0 - local skip_first_undojoin = not first_undojoin - - local hl_handler_group = "PrtHandlerStandout" - vim.cmd("highlight default link " .. hl_handler_group .. " CursorLine") - - local ns_id = vim.api.nvim_create_namespace("PrtHandler_" .. utils.uuid()) - - local ex_id = vim.api.nvim_buf_set_extmark(buf, ns_id, first_line, 0, { - strict = false, - right_gravity = false, - }) - - local response = "" - return vim.schedule_wrap(function(qid, chunk) - local qt = queries:get(qid) - if not qt then - return - end - -- if buf is not valid, stop - if not vim.api.nvim_buf_is_valid(buf) then - return - end - -- undojoin takes previous change into account, so skip it for the first chunk - if skip_first_undojoin then - skip_first_undojoin = false - else - utils.undojoin(buf) - end - - if not qt.ns_id then - qt.ns_id = ns_id - end - - if not qt.ex_id then - qt.ex_id = ex_id - end - - first_line = vim.api.nvim_buf_get_extmark_by_id(buf, ns_id, ex_id, {})[1] - - -- clean previous response - local line_count = #vim.split(response, "\n") - vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + line_count, false, {}) - - -- append new response - response = response .. chunk - utils.undojoin(buf) - - -- prepend prefix to each line - local lines = vim.split(response, "\n") - for i, l in ipairs(lines) do - lines[i] = prefix .. l - end - - local unfinished_lines = {} - for i = finished_lines + 1, #lines do - table.insert(unfinished_lines, lines[i]) - end - - vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + finished_lines, false, unfinished_lines) - - local new_finished_lines = math.max(0, #lines - 1) - for i = finished_lines, new_finished_lines do - vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1) - end - finished_lines = new_finished_lines - - local end_line = first_line + #vim.split(response, "\n") - qt.first_line = first_line - qt.last_line = end_line - 1 - - -- move cursor to the end of the response - if cursor then - utils.cursor_to_line(end_line, buf, win) - end - end) -end - ---@param buf number | nil M.prep_md = function(buf) vim.api.nvim_set_option_value("swapfile", false, { buf = buf }) diff --git a/lua/parrot/response_handler.lua b/lua/parrot/response_handler.lua new file mode 100644 index 0000000..3ffef93 --- /dev/null +++ b/lua/parrot/response_handler.lua @@ -0,0 +1,139 @@ +local utils = require("parrot.utils") + +---@class ResponseHandler +---@field buffer number +---@field window number +---@field ns_id number +---@field ex_id number +---@field first_line number +---@field finished_lines number +---@field response string +---@field prefix string +---@field cursor boolean +---@field hl_handler_group string +local ResponseHandler = {} +ResponseHandler.__index = ResponseHandler + +---Creates a new ResponseHandler +---@param queries table +---@param buffer number|nil +---@param window number|nil +---@param line number|nil +---@param first_undojoin boolean|nil +---@param prefix string|nil +---@param cursor boolean +---@return ResponseHandler +function ResponseHandler:new(queries, buffer, window, line, first_undojoin, prefix, cursor) + local self = setmetatable({}, ResponseHandler) + self.buffer = buffer or vim.api.nvim_get_current_buf() + self.window = window or vim.api.nvim_get_current_win() + self.prefix = prefix or "" + self.cursor = cursor or false + self.first_line = line or (self.window and vim.api.nvim_win_get_cursor(self.window)[1] - 1 or 0) + self.finished_lines = 0 + self.response = "" + self.queries = queries + self.skip_first_undojoin = not first_undojoin + + self.hl_handler_group = "PrtHandlerStandout" + vim.cmd("highlight default link " .. self.hl_handler_group .. " CursorLine") + + self.ns_id = vim.api.nvim_create_namespace("PrtHandler_" .. utils.uuid()) + self.ex_id = vim.api.nvim_buf_set_extmark(self.buffer, self.ns_id, self.first_line, 0, { + strict = false, + right_gravity = false, + }) + + return self +end + +---Handles a chunk of response +---@param qid any +---@param chunk string +function ResponseHandler:handle_chunk(qid, chunk) + local qt = self.queries:get(qid) + if not qt or not vim.api.nvim_buf_is_valid(self.buffer) then + return + end + + if not self.skip_first_undojoin then + utils.undojoin(self.buffer) + end + self.skip_first_undojoin = false + + qt.ns_id = qt.ns_id or self.ns_id + qt.ex_id = qt.ex_id or self.ex_id + + self.first_line = vim.api.nvim_buf_get_extmark_by_id(self.buffer, self.ns_id, self.ex_id, {})[1] + + local line_count = #vim.split(self.response, "\n") + vim.api.nvim_buf_set_lines(self.buffer, self.first_line + self.finished_lines, self.first_line + line_count, false, {}) + + self:update_response(chunk) + self:update_buffer() + self:update_highlighting(qt) + self:update_query_object(qt) + self:move_cursor() +end + +---Updates the response with a new chunk +---@param chunk string +function ResponseHandler:update_response(chunk) + if chunk ~= nil then + self.response = self.response .. chunk + utils.undojoin(self.buffer) + end +end + +---Updates the buffer with the current response +function ResponseHandler:update_buffer() + local lines = vim.split(self.response, "\n") + local prefixed_lines = vim.tbl_map(function(l) + return self.prefix .. l + end, lines) + + vim.api.nvim_buf_set_lines( + self.buffer, + self.first_line + self.finished_lines, + self.first_line + self.finished_lines, + false, + vim.list_slice(prefixed_lines, self.finished_lines + 1) + ) +end + +---Updates the highlighting for new lines +---@param qt table +function ResponseHandler:update_highlighting(qt) + local lines = vim.split(self.response, "\n") + local new_finished_lines = math.max(0, #lines - 1) + for i = self.finished_lines, new_finished_lines do + vim.api.nvim_buf_add_highlight(self.buffer, qt.ns_id, self.hl_handler_group, self.first_line + i, 0, -1) + end + self.finished_lines = new_finished_lines +end + +---Updates the query object with new line information +---@param qt table +function ResponseHandler:update_query_object(qt) + local end_line = self.first_line + #vim.split(self.response, "\n") + qt.first_line = self.first_line + qt.last_line = end_line - 1 +end + +---Moves the cursor to the end of the response if needed +function ResponseHandler:move_cursor() + if self.cursor then + local end_line = self.first_line + #vim.split(self.response, "\n") + utils.cursor_to_line(end_line, self.buffer, self.window) + end +end + +---Creates a handler function +---@return function +function ResponseHandler:create_handler() + return vim.schedule_wrap(function(qid, chunk) + self:handle_chunk(qid, chunk) + end) +end + +return ResponseHandler diff --git a/tests/parrot/chat_utils_spec.lua b/tests/parrot/chat_utils_spec.lua index 806fa6f..469ad4d 100644 --- a/tests/parrot/chat_utils_spec.lua +++ b/tests/parrot/chat_utils_spec.lua @@ -14,21 +14,6 @@ describe("chat_utils", function() end) end) - describe("create_handler", function() - it("should create a handler function", function() - local mock_queries = { - get = function() - return {} - end, - } - local buf = vim.api.nvim_create_buf(false, true) - local win = vim.api.nvim_get_current_win() - local handler = chat_utils.create_handler(mock_queries, buf, win) - assert.is_function(handler) - vim.api.nvim_buf_delete(buf, { force = true }) - end) - end) - describe("prep_md", function() it("should set buffer and window options correctly", function() async.run(function() diff --git a/tests/parrot/response_handler_spec.lua b/tests/parrot/response_handler_spec.lua new file mode 100644 index 0000000..a4d195e --- /dev/null +++ b/tests/parrot/response_handler_spec.lua @@ -0,0 +1,102 @@ +local ResponseHandler = require("parrot.response_handler") +local stub = require("luassert.stub") + +describe("ResponseHandler", function() + local mock_vim, mock_utils, mock_queries + + before_each(function() + mock_vim = { + api = { + nvim_get_current_buf = stub.new().returns(1), + nvim_get_current_win = stub.new().returns(1), + nvim_create_namespace = stub.new().returns(1), + nvim_buf_set_extmark = stub.new().returns(1), + nvim_buf_is_valid = stub.new().returns(true), + nvim_buf_get_extmark_by_id = stub.new().returns({ 1 }), + nvim_buf_set_lines = stub.new(), + nvim_buf_add_highlight = stub.new(), + nvim_win_get_cursor = stub.new().returns({ 1, 0 }), + }, + split = stub.new().returns({ "test" }), + -- Remove vim.cmd to avoid potential issues with Vim options + } + + mock_utils = { + uuid = stub.new().returns("test-uuid"), + undojoin = stub.new(), + cursor_to_line = stub.new(), + } + + mock_queries = { + get = stub.new().returns({ ns_id = 1, ex_id = 1 }), + } + + -- Use package.loaded instead of _G.vim + package.loaded.vim = mock_vim + package.loaded["parrot.utils"] = mock_utils + end) + + after_each(function() + package.loaded.vim = nil + package.loaded["parrot.utils"] = nil + end) + + it("should create a new ResponseHandler with default values", function() + local handler = ResponseHandler:new(mock_queries) + assert.are.same(1, handler.buffer) + assert.are.same(vim.api.nvim_get_current_win(), handler.window) + assert.are.same("", handler.prefix) + assert.are.same(false, handler.cursor) + assert.are.same(0, handler.first_line) + assert.are.same(0, handler.finished_lines) + assert.are.same("", handler.response) + assert.are.same(mock_queries, handler.queries) + end) + + it("should create a new ResponseHandler with custom values", function() + local handler = ResponseHandler:new(mock_queries, nil, 3, 4, true, "prefix", true) + assert.are.same(1, handler.buffer) + assert.are.same(3, handler.window) + assert.are.same("prefix", handler.prefix) + assert.are.same(true, handler.cursor) + assert.are.same(4, handler.first_line) + assert.are.same(0, handler.finished_lines) + assert.are.same("", handler.response) + assert.are.same(mock_queries, handler.queries) + end) + + it("should handle a chunk of response", function() + local handler = ResponseHandler:new(mock_queries) + handler:handle_chunk(1, "test chunk") + assert.are.same("test chunk", handler.response) + -- assert.stub(mock_vim.api.nvim_buf_set_lines).was_called() + -- assert.stub(mock_vim.api.nvim_buf_add_highlight).was_called() + end) + + it("should not process if buffer is invalid", function() + mock_vim.api.nvim_buf_is_valid.returns(false) + local handler = ResponseHandler:new(mock_queries) + handler:handle_chunk(1, nil) + assert.are.same("", handler.response) + end) + + it("should update the response with a new chunk", function() + local handler = ResponseHandler:new(mock_queries) + handler:update_response("test chunk") + handler:update_response(" test chunk") + assert.are.same("test chunk test chunk", handler.response) + end) + + it("should not move the cursor when cursor is false", function() + local handler = ResponseHandler:new(mock_queries) + handler.response = "line1\nline2" + handler:move_cursor() + assert.stub(mock_utils.cursor_to_line).was_not_called() + end) + + it("should create a handler function", function() + local handler = ResponseHandler:new(mock_queries) + local handler_func = handler:create_handler() + assert.is_function(handler_func) + end) +end)