Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add response handler class #55

Merged
merged 5 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions lua/parrot/chat_handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ local init_provider = require("parrot.provider").init_provider
local Spinner = require("parrot.spinner")
local Job = require("plenary.job")
local pft = require("plenary.filetype")
local ResponseHandler = require("parrot.response_handler")

local ChatHandler = {}

Expand Down Expand Up @@ -731,15 +732,9 @@ function ChatHandler:_chat_respond(params)
buf,
query_prov,
utils.prepare_payload(messages, model_obj.name, self.providers[query_prov.name].params["chat"]),
chatutils.create_handler(
self.queries,
buf,
win,
utils.last_content_line(buf),
true,
"",
not self.options.chat_free_cursor
),
ResponseHandler
:new(self.queries, buf, win, utils.last_content_line(buf), true, "", not self.options.chat_free_cursor)
:create_handler(),
vim.schedule_wrap(function(qid)
if self.options.enable_spinner and spinner then
spinner:stop()
Expand Down Expand Up @@ -783,7 +778,8 @@ function ChatHandler:_chat_respond(params)

-- prepare invisible buffer for the model to write to
local topic_buf = vim.api.nvim_create_buf(false, true)
local topic_handler = chatutils.create_handler(self.queries, topic_buf, nil, 0, false, "", false)
local topic_resp_handler = ResponseHandler:new(self.queries, topic_buf, nil, 0, false, "", false)
local topic_handler = topic_resp_handler:create_handler()
topic_prov:set_model(self.providers[topic_prov.name].topic.model)

local topic_spinner = nil
Expand Down Expand Up @@ -1257,21 +1253,21 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
-- delete selection
vim.api.nvim_buf_set_lines(buf, start_line - 1, end_line - 1, false, {})
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, start_line - 1, true, prefix, cursor)
handler = ResponseHandler:new(self.queries, buf, win, start_line - 1, true, prefix, cursor):create_handler()
elseif target == ui.Target.append then
-- move cursor to the end of the selection
vim.api.nvim_win_set_cursor(0, { end_line, 0 })
-- put newline after selection
vim.api.nvim_put({ "" }, "l", true, true)
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, end_line, true, prefix, cursor)
handler = ResponseHandler:new(self.queries, buf, win, end_line, true, prefix, cursor):create_handler()
elseif target == ui.Target.prepend then
-- move cursor to the start of the selection
vim.api.nvim_win_set_cursor(0, { start_line, 0 })
-- put newline before selection
vim.api.nvim_put({ "" }, "l", false, true)
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, start_line - 1, true, prefix, cursor)
handler = ResponseHandler:new(self.queries, buf, win, start_line - 1, true, prefix, cursor):create_handler()
elseif target == ui.Target.popup then
self:toggle_close(self._toggle_kind.popup)
-- create a new buffer
Expand Down Expand Up @@ -1299,7 +1295,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
-- better text wrapping
vim.api.nvim_command("setlocal wrap linebreak")
-- prepare handler
handler = chatutils.create_handler(self.queries, buf, win, 0, false, "", false)
handler = ResponseHandler:new(self.queries, buf, win, 0, false, "", false):create_handler()
self:toggle_add(self._toggle_kind.popup, { win = win, buf = buf, close = popup_close })
elseif type(target) == "table" then
if target.type == ui.Target.new().type then
Expand Down Expand Up @@ -1331,7 +1327,7 @@ function ChatHandler:prompt(params, target, model_obj, prompt, template)
local ft = target.filetype or filetype
vim.api.nvim_set_option_value("filetype", ft, { buf = buf })

handler = chatutils.create_handler(self.queries, buf, win, 0, false, "", cursor)
handler = ResponseHandler:new(self.queries, buf, win, 0, false, "", cursor):create_handler()
end

-- call the model and write the response
Expand Down
89 changes: 0 additions & 89 deletions lua/parrot/chat_utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,95 +26,6 @@ M.resolve_buf_target = function(params)
end
end

-- response handler
---@param buf number | nil # buffer to insert response into
---@param win number | nil # window to insert response into
---@param line number | nil # line to insert response into
---@param first_undojoin boolean | nil # whether to skip first undojoin
---@param prefix string | nil # prefix to insert before each response line
---@param cursor boolean # whether to move cursor to the end of the response
M.create_handler = function(queries, buf, win, line, first_undojoin, prefix, cursor)
buf = buf or vim.api.nvim_get_current_buf()
prefix = prefix or ""
local first_line = line or vim.api.nvim_win_get_cursor(win)[1] - 1
local finished_lines = 0
local skip_first_undojoin = not first_undojoin

local hl_handler_group = "PrtHandlerStandout"
vim.cmd("highlight default link " .. hl_handler_group .. " CursorLine")

local ns_id = vim.api.nvim_create_namespace("PrtHandler_" .. utils.uuid())

local ex_id = vim.api.nvim_buf_set_extmark(buf, ns_id, first_line, 0, {
strict = false,
right_gravity = false,
})

local response = ""
return vim.schedule_wrap(function(qid, chunk)
local qt = queries:get(qid)
if not qt then
return
end
-- if buf is not valid, stop
if not vim.api.nvim_buf_is_valid(buf) then
return
end
-- undojoin takes previous change into account, so skip it for the first chunk
if skip_first_undojoin then
skip_first_undojoin = false
else
utils.undojoin(buf)
end

if not qt.ns_id then
qt.ns_id = ns_id
end

if not qt.ex_id then
qt.ex_id = ex_id
end

first_line = vim.api.nvim_buf_get_extmark_by_id(buf, ns_id, ex_id, {})[1]

-- clean previous response
local line_count = #vim.split(response, "\n")
vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + line_count, false, {})

-- append new response
response = response .. chunk
utils.undojoin(buf)

-- prepend prefix to each line
local lines = vim.split(response, "\n")
for i, l in ipairs(lines) do
lines[i] = prefix .. l
end

local unfinished_lines = {}
for i = finished_lines + 1, #lines do
table.insert(unfinished_lines, lines[i])
end

vim.api.nvim_buf_set_lines(buf, first_line + finished_lines, first_line + finished_lines, false, unfinished_lines)

local new_finished_lines = math.max(0, #lines - 1)
for i = finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(buf, qt.ns_id, hl_handler_group, first_line + i, 0, -1)
end
finished_lines = new_finished_lines

local end_line = first_line + #vim.split(response, "\n")
qt.first_line = first_line
qt.last_line = end_line - 1

-- move cursor to the end of the response
if cursor then
utils.cursor_to_line(end_line, buf, win)
end
end)
end

---@param buf number | nil
M.prep_md = function(buf)
vim.api.nvim_set_option_value("swapfile", false, { buf = buf })
Expand Down
139 changes: 139 additions & 0 deletions lua/parrot/response_handler.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
local utils = require("parrot.utils")

---@class ResponseHandler
---@field buffer number
---@field window number
---@field ns_id number
---@field ex_id number
---@field first_line number
---@field finished_lines number
---@field response string
---@field prefix string
---@field cursor boolean
---@field hl_handler_group string
local ResponseHandler = {}
ResponseHandler.__index = ResponseHandler

---Creates a new ResponseHandler
---@param queries table
---@param buffer number|nil
---@param window number|nil
---@param line number|nil
---@param first_undojoin boolean|nil
---@param prefix string|nil
---@param cursor boolean
---@return ResponseHandler
function ResponseHandler:new(queries, buffer, window, line, first_undojoin, prefix, cursor)
local self = setmetatable({}, ResponseHandler)
self.buffer = buffer or vim.api.nvim_get_current_buf()
self.window = window or vim.api.nvim_get_current_win()
self.prefix = prefix or ""
self.cursor = cursor or false
self.first_line = line or (self.window and vim.api.nvim_win_get_cursor(self.window)[1] - 1 or 0)
self.finished_lines = 0
self.response = ""
self.queries = queries
self.skip_first_undojoin = not first_undojoin

self.hl_handler_group = "PrtHandlerStandout"
vim.cmd("highlight default link " .. self.hl_handler_group .. " CursorLine")

self.ns_id = vim.api.nvim_create_namespace("PrtHandler_" .. utils.uuid())
self.ex_id = vim.api.nvim_buf_set_extmark(self.buffer, self.ns_id, self.first_line, 0, {
strict = false,
right_gravity = false,
})

return self
end

---Handles a chunk of response
---@param qid any
---@param chunk string
function ResponseHandler:handle_chunk(qid, chunk)
local qt = self.queries:get(qid)
if not qt or not vim.api.nvim_buf_is_valid(self.buffer) then
return
end

if not self.skip_first_undojoin then
utils.undojoin(self.buffer)
end
self.skip_first_undojoin = false

qt.ns_id = qt.ns_id or self.ns_id
qt.ex_id = qt.ex_id or self.ex_id

self.first_line = vim.api.nvim_buf_get_extmark_by_id(self.buffer, self.ns_id, self.ex_id, {})[1]

local line_count = #vim.split(self.response, "\n")
vim.api.nvim_buf_set_lines(self.buffer, self.first_line + self.finished_lines, self.first_line + line_count, false, {})

self:update_response(chunk)
self:update_buffer()
self:update_highlighting(qt)
self:update_query_object(qt)
self:move_cursor()
end

---Updates the response with a new chunk
---@param chunk string
function ResponseHandler:update_response(chunk)
if chunk ~= nil then
self.response = self.response .. chunk
utils.undojoin(self.buffer)
end
end

---Updates the buffer with the current response
function ResponseHandler:update_buffer()
local lines = vim.split(self.response, "\n")
local prefixed_lines = vim.tbl_map(function(l)
return self.prefix .. l
end, lines)

vim.api.nvim_buf_set_lines(
self.buffer,
self.first_line + self.finished_lines,
self.first_line + self.finished_lines,
false,
vim.list_slice(prefixed_lines, self.finished_lines + 1)
)
end

---Updates the highlighting for new lines
---@param qt table
function ResponseHandler:update_highlighting(qt)
local lines = vim.split(self.response, "\n")
local new_finished_lines = math.max(0, #lines - 1)
for i = self.finished_lines, new_finished_lines do
vim.api.nvim_buf_add_highlight(self.buffer, qt.ns_id, self.hl_handler_group, self.first_line + i, 0, -1)
end
self.finished_lines = new_finished_lines
end

---Updates the query object with new line information
---@param qt table
function ResponseHandler:update_query_object(qt)
local end_line = self.first_line + #vim.split(self.response, "\n")
qt.first_line = self.first_line
qt.last_line = end_line - 1
end

---Moves the cursor to the end of the response if needed
function ResponseHandler:move_cursor()
if self.cursor then
local end_line = self.first_line + #vim.split(self.response, "\n")
utils.cursor_to_line(end_line, self.buffer, self.window)
end
end

---Creates a handler function
---@return function
function ResponseHandler:create_handler()
return vim.schedule_wrap(function(qid, chunk)
self:handle_chunk(qid, chunk)
end)
end

return ResponseHandler
15 changes: 0 additions & 15 deletions tests/parrot/chat_utils_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,6 @@ describe("chat_utils", function()
end)
end)

describe("create_handler", function()
it("should create a handler function", function()
local mock_queries = {
get = function()
return {}
end,
}
local buf = vim.api.nvim_create_buf(false, true)
local win = vim.api.nvim_get_current_win()
local handler = chat_utils.create_handler(mock_queries, buf, win)
assert.is_function(handler)
vim.api.nvim_buf_delete(buf, { force = true })
end)
end)

describe("prep_md", function()
it("should set buffer and window options correctly", function()
async.run(function()
Expand Down
Loading
Loading