Skip to content

Commit

Permalink
Merge pull request #55 from frankroeder/chore/responsehandler
Browse files Browse the repository at this point in the history
chore: add response handler class
  • Loading branch information
frankroeder authored Sep 5, 2024
2 parents be975ee + 6218cff commit 2a15415
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 119 deletions.
26 changes: 11 additions & 15 deletions lua/parrot/chat_handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ local init_provider = require("parrot.provider").init_provider
local Spinner = require("parrot.spinner")
local Job = require("plenary.job")
local pft = require("plenary.filetype")
local ResponseHandler = require("parrot.response_handler")

local ChatHandler = {}

Expand Down Expand Up @@ -731,15 +732,9 @@ function ChatHandler:_chat_respond(params)
buf,
query_prov,
utils.prepare_payload(messages, model_obj.name, self.providers[query_prov.name].params["chat"]),
chatutils.create_handler(
self.queries,
buf,
win,
utils.last_content_line(buf),
true,
"",
not self.options.chat_free_cursor
),
ResponseHandler
:new(self.queries, buf, win, utils.last_content_line(buf), true, "", not self.options.chat_free_cursor)
:create_handler(),
vim.schedule_wrap(function(qid)
if self.options.enable_spinner and spinner then
spinner:stop()
Expand Down Expand Up @@ -783,7 +778,8 @@ function ChatHandler:_chat_respond(params)

-- prepare invisible buffer for the model to write to
local topic_buf = vim.api.nvim_create_buf(false, true)
local topic_handler = chatutils.create_handler(self.queries, topic_buf, nil, 0, false, "", false)
local topic_resp_handler = ResponseHandler:new(self.queries, topic_buf, nil, 0, false, "", false)
local topic_handler = topic_resp_handler:create_handler()
topic_prov:set_model(self.providers[topic_prov.name].topic.model)

local topic_spinner = nil
Expand Down Expand Up @@ -1257,21 +1253,21 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
-- delete selection
vim.api.nvim_buf_set_lines(buf, start_line - 1, end_line - 1, false, {})
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, start_line - 1, true, prefix, cursor)
handler = ResponseHandler:new(self.queries, buf, win, start_line - 1, true, prefix, cursor):create_handler()
elseif target == ui.Target.append then
-- move cursor to the end of the selection
vim.api.nvim_win_set_cursor(0, { end_line, 0 })
-- put newline after selection
vim.api.nvim_put({ "" }, "l", true, true)
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, end_line, true, prefix, cursor)
handler = ResponseHandler:new(self.queries, buf, win, end_line, true, prefix, cursor):create_handler()
elseif target == ui.Target.prepend then
-- move cursor to the start of the selection
vim.api.nvim_win_set_cursor(0, { start_line, 0 })
-- put newline before selection
vim.api.nvim_put({ "" }, "l", false, true)
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, start_line - 1, true, prefix, cursor)
handler = ResponseHandler:new(self.queries, buf, win, start_line - 1, true, prefix, cursor):create_handler()
elseif target == ui.Target.popup then
self:toggle_close(self._toggle_kind.popup)
-- create a new buffer
Expand Down Expand Up @@ -1299,7 +1295,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
-- better text wrapping
vim.api.nvim_command("setlocal wrap linebreak")
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, 0, false, "", false)
handler = ResponseHandler:new(self.queries, buf, win, 0, false, "", false):create_handler()
self:toggle_add(self._toggle_kind.popup, { win = win, buf = buf, close = popup_close })
elseif type(target) == "table" then
if target.type == ui.Target.new().type then
Expand Down Expand Up @@ -1331,7 +1327,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
local ft = target.filetype or filetype
vim.api.nvim_set_option_value("filetype", ft, { buf = buf })

handler = chatutils.create_handler(self.queries, buf, win, 0, false, "", cursor)
handler = ResponseHandler:new(self.queries, buf, win, 0, false, "", cursor):create_handler()
end

-- call the model and write the response
Expand Down
89 changes: 0 additions & 89 deletions lua/parrot/chat_utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,95 +26,6 @@ M.resolve_buf_target = function(params)
end
end

-- response handler
---@param buf number | nil # buffer to insert response into
---@param win number | nil # window to insert response into
---@param line number | nil # line to insert response into
---@param first_undojoin boolean | nil # whether to skip first undojoin
---@param prefix string | nil # prefix to insert before each response line
---@param cursor boolean # whether to move cursor to the end of the response
M.create_handler = function(queries, buf, win, line, first_undojoin, prefix, cursor)
buf = buf or vim.api.nvim_get_current_buf()
prefix = prefix or ""
local first_line = line or vim.api.nvim_win_get_cursor(win)[1] - 1
local finished_lines = 0
local skip_first_undojoin = not first_undojoin

local hl_handler_group = "PrtHandlerStandout"
vim.cmd("highlight default link " .. hl_handler_group .. " CursorLine")

local ns_id = vim.api.nvim_create_namespace("PrtHandler_" .. utils.uuid())

local ex_id = vim.api.nvim_buf_set_extmark(buf, ns_id, first_line, 0, {
strict = false,
right_gravity = false,
})

local response = ""
return vim.schedule_wrap(function(qid, chunk)
local qt = queries:get(qid)
if not qt then
return
end
-- if buf is not valid, stop
if not vim.api.nvim_buf_is_valid(buf) then
return
end
-- undojoin takes previous change into account, so skip it for the first chunk
if skip_first_undojoin then
skip_first_undojoin = false
else
utils.undojoin(buf)
end

if not qt.ns_id then
qt.ns_id = ns_id
end

if not qt.ex_id then
qt.ex_id = ex_id
end

first_line = vim.api.nvim_buf_get_extmark_by_id(buf, ns_id, ex_id, {})[1]

-- clean previous response
local line_count = #vim.split(response, "\n")
vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + line_count, false, {})

-- append new response
response = response .. chunk
utils.undojoin(buf)

-- prepend prefix to each line
local lines = vim.split(response, "\n")
for i, l in ipairs(lines) do
lines[i] = prefix .. l
end

local unfinished_lines = {}
for i = finished_lines + 1, #lines do
table.insert(unfinished_lines, lines[i])
end

vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + finished_lines, false, unfinished_lines)

local new_finished_lines = math.max(0, #lines - 1)
for i = finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
end
finished_lines = new_finished_lines

local end_line = first_line + #vim.split(response, "\n")
qt.first_line = first_line
qt.last_line = end_line - 1

-- move cursor to the end of the response
if cursor then
utils.cursor_to_line(end_line, buf, win)
end
end)
end

---@param buf number | nil
M.prep_md = function(buf)
vim.api.nvim_set_option_value("swapfile", false, { buf = buf })
Expand Down
139 changes: 139 additions & 0 deletions lua/parrot/response_handler.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
local utils = require("parrot.utils")

---@class ResponseHandler
---@field buffer number
---@field window number
---@field ns_id number
---@field ex_id number
---@field first_line number
---@field finished_lines number
---@field response string
---@field prefix string
---@field cursor boolean
---@field hl_handler_group string
local ResponseHandler = {}
ResponseHandler.__index = ResponseHandler

---Creates a new ResponseHandler
---@param queries table
---@param buffer number|nil
---@param window number|nil
---@param line number|nil
---@param first_undojoin boolean|nil
---@param prefix string|nil
---@param cursor boolean
---@return ResponseHandler
function ResponseHandler:new(queries, buffer, window, line, first_undojoin, prefix, cursor)
local self = setmetatable({}, ResponseHandler)
self.buffer = buffer or vim.api.nvim_get_current_buf()
self.window = window or vim.api.nvim_get_current_win()
self.prefix = prefix or ""
self.cursor = cursor or false
self.first_line = line or (self.window and vim.api.nvim_win_get_cursor(self.window)[1] - 1 or 0)
self.finished_lines = 0
self.response = ""
self.queries = queries
self.skip_first_undojoin = not first_undojoin

self.hl_handler_group = "PrtHandlerStandout"
vim.cmd("highlight default link " .. self.hl_handler_group .. " CursorLine")

self.ns_id = vim.api.nvim_create_namespace("PrtHandler_" .. utils.uuid())
self.ex_id = vim.api.nvim_buf_set_extmark(self.buffer, self.ns_id, self.first_line, 0, {
strict = false,
right_gravity = false,
})

return self
end

---Handles a chunk of response
---@param qid any
---@param chunk string
function ResponseHandler:handle_chunk(qid, chunk)
local qt = self.queries:get(qid)
if not qt or not vim.api.nvim_buf_is_valid(self.buffer) then
return
end

if not self.skip_first_undojoin then
utils.undojoin(self.buffer)
end
self.skip_first_undojoin = false

qt.ns_id = qt.ns_id or self.ns_id
qt.ex_id = qt.ex_id or self.ex_id

self.first_line = vim.api.nvim_buf_get_extmark_by_id(self.buffer, self.ns_id, self.ex_id, {})[1]

local line_count = #vim.split(self.response, "\n")
vim.api.nvim_buf_set_lines(self.buffer, self.first_line + self.finished_lines, self.first_line + line_count, false, {})

self:update_response(chunk)
self:update_buffer()
self:update_highlighting(qt)
self:update_query_object(qt)
self:move_cursor()
end

---Updates the response with a new chunk
---@param chunk string
function ResponseHandler:update_response(chunk)
if chunk ~= nil then
self.response = self.response .. chunk
utils.undojoin(self.buffer)
end
end

---Updates the buffer with the current response
function ResponseHandler:update_buffer()
local lines = vim.split(self.response, "\n")
local prefixed_lines = vim.tbl_map(function(l)
return self.prefix .. l
end, lines)

vim.api.nvim_buf_set_lines(
self.buffer,
self.first_line + self.finished_lines,
self.first_line + self.finished_lines,
false,
vim.list_slice(prefixed_lines, self.finished_lines + 1)
)
end

---Updates the highlighting for new lines
---@param qt table
function ResponseHandler:update_highlighting(qt)
local lines = vim.split(self.response, "\n")
local new_finished_lines = math.max(0, #lines - 1)
for i = self.finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(self.buffer, qt.ns_id, self.hl_handler_group, self.first_line + i, 0, -1)
end
self.finished_lines = new_finished_lines
end

---Updates the query object with new line information
---@param qt table
function ResponseHandler:update_query_object(qt)
local end_line = self.first_line + #vim.split(self.response, "\n")
qt.first_line = self.first_line
qt.last_line = end_line - 1
end

---Moves the cursor to the end of the response if needed
function ResponseHandler:move_cursor()
if self.cursor then
local end_line = self.first_line + #vim.split(self.response, "\n")
utils.cursor_to_line(end_line, self.buffer, self.window)
end
end

---Creates a handler function
---@return function
function ResponseHandler:create_handler()
return vim.schedule_wrap(function(qid, chunk)
self:handle_chunk(qid, chunk)
end)
end

return ResponseHandler
15 changes: 0 additions & 15 deletions tests/parrot/chat_utils_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,6 @@ describe("chat_utils", function()
end)
end)

describe("create_handler", function()
it("should create a handler function", function()
local mock_queries = {
get = function()
return {}
end,
}
local buf = vim.api.nvim_create_buf(false, true)
local win = vim.api.nvim_get_current_win()
local handler = chat_utils.create_handler(mock_queries, buf, win)
assert.is_function(handler)
vim.api.nvim_buf_delete(buf, { force = true })
end)
end)

describe("prep_md", function()
it("should set buffer and window options correctly", function()
async.run(function()
Expand Down
Loading

0 comments on commit 2a15415

Please sign in to comment.