diff --git a/lua/prettier/cache.lua b/lua/prettier/cache.lua new file mode 100644 index 0000000..934f5b0 --- /dev/null +++ b/lua/prettier/cache.lua @@ -0,0 +1,24 @@ +local store = {} + +local cache = {} + +---@generic P : any[] +---@generic R : any +---@param fn fun(...: P): R +---@param make_key fun(...: P): string +---@return fun(...: P): R +function cache.wrap(fn, make_key) + store[fn] = {} + + return function(...) + local key = make_key(...) + + if store[fn][key] == nil then + store[fn][key] = fn(...) or false + end + + return store[fn][key] + end +end + +return cache diff --git a/lua/prettier/init.lua b/lua/prettier/init.lua index cd39837..3900d91 100644 --- a/lua/prettier/init.lua +++ b/lua/prettier/init.lua @@ -1,15 +1,53 @@ local options = require("prettier.options") local null_ls = require("prettier.null-ls") -local M = {} +local M = { + __ = {}, +} function M.setup(user_options) options.setup(user_options) - null_ls.setup() + vim.schedule(function() + null_ls.setup() + end) end function M.format(method) null_ls.format(method) end +function M.create_formatter(opts) + local command = opts.command + + M.__[command] = { + _fn = function(method) + if M.__[command].fn then + return M.__[command].fn(method) + end + end, + cmd = function(range) + if range > 0 then + M.__[command]._fn("textDocument/rangeFormatting") + else + M.__[command]._fn("textDocument/formatting") + end + end, + } + + vim.schedule(function() + local format = null_ls.create_formatter({ + bin = opts.bin, + bin_preference = opts.bin_preference, + cli_options = opts.cli_options, + ["null-ls"] = opts["null-ls"], + }) + + M.__[command].fn = format + + vim.cmd(string.format([[command! -range=%% %s :lua require("prettier").__["%s"].cmd()]], command, command)) + end) + + return M.__[command]._fn +end + return M diff --git a/lua/prettier/null-ls.lua b/lua/prettier/null-ls.lua index 5b82166..fab1386 100644 --- a/lua/prettier/null-ls.lua +++ b/lua/prettier/null-ls.lua @@ -9,23 +9,14 @@ local M = { _generator = nil, } -local function get_generator() - if not ok then - return - end - - if M._generator_initialized then - return M._generator - end - - M._generator_initialized = true +local function noop() end - if vim.tbl_count(options.get("filetypes")) == 0 then - return - end +local function create_generator(opts) + local bin = opts.bin + local cli_options = opts.cli_options or {} + local null_ls_options = opts["null-ls"] or {} - local bin = options.get("bin") --[[@as string]] - local command = utils.resolve_bin(bin) + local command = utils.resolve_bin(bin, opts.bin_preference) if not command then return @@ -34,8 +25,6 @@ local function get_generator() local format_cli_args = cli.get_base_args(bin) local range_format_cli_args = cli.get_base_args(bin) if cli.args.supports_options(bin) then - local cli_options = options.get("cli_options") - for _, arg in ipairs(cli.args.from_options(cli_options)) do table.insert(format_cli_args, arg) end @@ -50,7 +39,7 @@ local function get_generator() end end - M._generator = null_ls.formatter({ + return null_ls.formatter({ command = command, args = function(params) if params.lsp_method == "textDocument/formatting" then @@ -77,93 +66,110 @@ local function get_generator() return args end, to_stdin = true, - runtime_condition = options.get("null-ls.runtime_condition"), - timeout = options.get("null-ls.timeout"), + runtime_condition = null_ls_options["runtime_condition"], + timeout = null_ls_options["timeout"], }) - - return M._generator end -function M.format(method) +function M.create_formatter(opts) if not ok then - return + return noop end - method = method or "textDocument/formatting" - - local generator = get_generator() + local generator = create_generator({ + bin = opts.bin, + bin_preference = opts.bin_preference, + cli_options = opts.cli_options, + ["null-ls"] = opts["null-ls"], + }) if not generator then - return + return noop end - if not M._format then - local u = require("null-ls.utils") + local u = require("null-ls.utils") - M._format = function(original_params) - local bufnr = original_params.bufnr + local function format(original_params) + local method = original_params.range and "textDocument/rangeFormatting" or "textDocument/formatting" + local bufnr = original_params.bufnr - local temp_bufnr = vim.api.nvim_create_buf(false, true) - vim.api.nvim_buf_set_option(temp_bufnr, "eol", vim.api.nvim_buf_get_option(bufnr, "eol")) - vim.api.nvim_buf_set_option(temp_bufnr, "fileformat", vim.api.nvim_buf_get_option(bufnr, "fileformat")) - vim.api.nvim_buf_set_lines(temp_bufnr, 0, -1, false, vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)) + local temp_bufnr = vim.api.nvim_create_buf(false, true) + vim.api.nvim_buf_set_option(temp_bufnr, "eol", vim.api.nvim_buf_get_option(bufnr, "eol")) + vim.api.nvim_buf_set_option(temp_bufnr, "fileformat", vim.api.nvim_buf_get_option(bufnr, "fileformat")) + vim.api.nvim_buf_set_lines(temp_bufnr, 0, -1, false, vim.api.nvim_buf_get_lines(bufnr, 0, -1, false)) - local function callback() - local edits = require("null-ls.diff").compute_diff( - u.buf.content(bufnr), - u.buf.content(temp_bufnr), - u.get_line_ending(bufnr) - ) + local function callback() + local edits = require("null-ls.diff").compute_diff( + u.buf.content(bufnr), + u.buf.content(temp_bufnr), + u.get_line_ending(bufnr) + ) - vim.schedule(function() - vim.api.nvim_buf_delete(temp_bufnr, { force = true }) - end) + vim.schedule(function() + vim.api.nvim_buf_delete(temp_bufnr, { force = true }) + end) - local is_actual_edit = not (edits.newText == "" and edits.rangeLength == 0) + local is_actual_edit = not (edits.newText == "" and edits.rangeLength == 0) - if is_actual_edit then - vim.lsp.util.apply_text_edits({ edits }, bufnr, require("null-ls.client").get_offset_encoding()) - end + if is_actual_edit then + vim.lsp.util.apply_text_edits({ edits }, bufnr, require("null-ls.client").get_offset_encoding()) end - - require("null-ls.generators").run( - { generator }, - u.make_params(original_params, require("null-ls.methods").map[method]), - { - sequential = true, - postprocess = function(edit, params) - edit.row = edit.row or 1 - edit.col = edit.col or 1 - edit.end_row = edit.end_row or #params.content + 1 - edit.end_col = edit.end_col or 1 - - edit.range = u.range.to_lsp(edit) - edit.newText = edit.text - end, - after_each = function(edits) - vim.lsp.util.apply_text_edits(edits, temp_bufnr, require("null-ls.client").get_offset_encoding()) - end, - }, - callback - ) end + + require("null-ls.generators").run( + { generator }, + u.make_params(original_params, require("null-ls.methods").map[method]), + { + sequential = true, + postprocess = function(edit, params) + edit.row = edit.row or 1 + edit.col = edit.col or 1 + edit.end_row = edit.end_row or #params.content + 1 + edit.end_col = edit.end_col or 1 + + edit.range = u.range.to_lsp(edit) + edit.newText = edit.text + end, + after_each = function(edits) + vim.lsp.util.apply_text_edits(edits, temp_bufnr, require("null-ls.client").get_offset_encoding()) + end, + }, + callback + ) end - local bufnr = vim.api.nvim_get_current_buf() + local filetypes = opts.filetypes and utils.list_to_map(opts.filetypes) - local params = { - bufnr = bufnr, - method = method, - } + return function(method) + local bufnr = vim.api.nvim_get_current_buf() - if method == "textDocument/rangeFormatting" then - params.range = vim.lsp.util.make_given_range_params().range - end + if filetypes and not filetypes[vim.api.nvim_buf_get_option(bufnr, "filetype")] then + return + end + + local params = { + bufnr = bufnr, + method = method, + } + + if method == "textDocument/rangeFormatting" then + params.range = vim.lsp.util.make_given_range_params().range + end - M._format(params) + format(params) + end end function M.setup() + M.format = M.create_formatter({ + bin = options.get("bin"), + cli_options = options.get("cli_options"), + ["null-ls"] = { + runtime_condition = options.get("null-ls.runtime_condition"), + timeout = options.get("null-ls.timeout"), + }, + }) + if not ok then return end @@ -178,7 +184,18 @@ function M.setup() return end - local generator = get_generator() + if vim.tbl_count(options.get("filetypes")) == 0 then + return + end + + local generator = create_generator({ + bin = options.get("bin"), + cli_options = options.get("cli_options"), + ["null-ls"] = { + runtime_condition = options.get("null-ls.runtime_condition"), + timeout = options.get("null-ls.timeout"), + }, + }) if not generator then return diff --git a/lua/prettier/utils.lua b/lua/prettier/utils.lua index aeeab68..75008bc 100644 --- a/lua/prettier/utils.lua +++ b/lua/prettier/utils.lua @@ -1,3 +1,4 @@ +local cache = require("prettier.cache") local find_git_ancestor = require("lspconfig.util").find_git_ancestor local find_package_json_ancestor = require("lspconfig.util").find_package_json_ancestor local path_join = require("lspconfig.util").path.join @@ -26,23 +27,73 @@ function M.prettier_enabled() return config_file_exists() end ----@param cmd string ----@return nil|string -function M.resolve_bin(cmd) - local project_root = get_working_directory() +---@param _cwd string +---@param scope '"global"'|'"local"' +---@return string|false bin_dir +local function _get_bin_dir(_cwd, scope) + local cmd = "npm bin" + if scope == "global" then + cmd = cmd .. " --global" + end - if project_root then - local local_bin = path_join(project_root, "/node_modules/.bin", cmd) - if vim.fn.executable(local_bin) == 1 then - return local_bin - end + local result = vim.fn.systemlist(cmd) + if vim.fn.isdirectory(result[1]) == 1 then + return result[1] + end + + return false +end + +---@type fun(cwd: string, scope: '"global"'|'"local"'): string|false +local get_bin_dir = cache.wrap(_get_bin_dir, function(cwd, scope) + return scope .. "::" .. cwd +end) + +---@param name string +---@param scope '"global"'|'"local"' +---@return string|false bin +local function _get_bin_path(cwd, name, scope) + local bin_dir = get_bin_dir(cwd, scope) + if not bin_dir then + return false + end + + local bin = path_join(bin_dir, name) + if vim.fn.executable(bin) == 1 then + return bin + end + + if scope == "global" and vim.fn.executable(name) == 1 then + return vim.fn.exepath(name) end - if vim.fn.executable(cmd) == 1 then - return cmd + return false +end + +---@type fun(cwd: string, name: string, scope: '"global"'|'"local"'): string|false +local get_bin_path = cache.wrap(_get_bin_path, function(cwd, name, scope) + return scope .. "::" .. name .. "::" .. cwd +end) + +---@param name string +---@param preference? '"global"'|'"local"'|'"prefer-local"' +---@return string|false +function M.resolve_bin(name, preference) + local cwd = vim.fn.getcwd() + + preference = preference or "prefer-local" + + if preference == "global" then + return get_bin_path(cwd, name, "global") end - return nil + local bin = get_bin_path(cwd, name, "local") + + if bin or preference == "local" then + return bin + end + + return get_bin_path(cwd, name, "global") end function M.tbl_flatten(tbl, should_flatten, result, prefix, depth) @@ -64,4 +115,12 @@ function M.tbl_flatten(tbl, should_flatten, result, prefix, depth) return result end +function M.list_to_map(list) + local map = {} + for _, key in ipairs(list) do + map[key] = true + end + return map +end + return M