Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle limited time key, Lua function for obtaining keys #407

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ to store the credential in clear-text in a configuration file.

As an alternative to providing the API key via the `OPENAI_API_KEY` environment
variable, the user is encouraged to use the `api_key_cmd` configuration option.


#### api_key_cmd using script
The `api_key_cmd` configuration option takes a string, which is executed at
startup, and whose output is used as the API key.

Expand All @@ -180,6 +183,71 @@ Note that the `api_key_cmd` arguments are split by whitespace. If you need
whitespace inside an argument (for example to reference a path with spaces),
you can wrap it in a separate script.

#### api_key_cmd using lua function

Here is another way to provide the key by using Lua function

```lua
local get_api_key = function (callback)
local job = require("plenary.job")
local url = "https://my-enterprise.com/key/management.url"
local value = "some-world-readable-client-id"

job:new({
command = "curl",
args = {
url,
"--silent", "-X", "POST",
"-H", "Accept: */*",
"-H", "Content-Type: application/x-www-form-urlencoded",
"-H", "Authorization: Basic " .. value,
"-d", "grant_type=client_credentials",
},
on_exit = vim.schedule_wrap(function(response, exit_code)
vim.notify("Key job: exitcode " .. vim.inspect(exit_code) .. ", Key: " .. vim.inspect(response:result()), vim.log.levels.INFO)

if exit_code ~= 0 then
-- curl failed
vim.notify("Key: failed to obtain key" .. vim.inspect(response), vim.log.levels.ERROR)
return
end

-- Get stdout which is a json string
local result = table.concat(response:result(), "\n")

local ok, json = pcall(vim.json.decode, result)
if not ok or not json then
vim.notify("Key: error decoding response " .. vim.inspect(result), vim.log.levels.ERROR)
return
end

if json and json["access_token"] then
-- Notify the callback with the key and the valid duration
if callback then
callback(json["access_token"], json["expires_in"] or nil)
else
vim.env.OPENAI_API_KEY = json["access_token"]
if json["expires_in"] then
vim.env.OPENAI_API_KEY_EXPIRES = json["expires_in"]
end
end
return json["access_token"], json["expires_in"] or nil
end
end),
})
:start()
end

require("chatgpt").setup({
api_key_cmd = 'get_api_key'
})
```

### limited time key support
When "api_key_cmd" is run to obtain the key, it could optionally provide the key validity time.
ChatGPT plugin will automatically re-request "api_key_cmd" when time exceeds the validity time
when the next time ChatGPT is invoked. See the lua function section above for example.

## Usage

Plugin exposes following commands:
Expand Down
215 changes: 134 additions & 81 deletions lua/chatgpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,125 @@ local logger = require("chatgpt.common.logger")

local Api = {}

local key_expiry_timestamp = nil

local function updateAuthenticationKey(key, timeout_in_secs)
if not key then
logger.warn("OPENAI_API_KEY callback is nil")
return
end

if timeout_in_secs then
key_expiry_timestamp = os.time() + timeout_in_secs
end

Api.OPENAI_API_KEY = key
if Api["OPENAI_API_TYPE"] == "azure" then
Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OPENAI_API_KEY
else
Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OPENAI_API_KEY
end
end

local splitCommandIntoTable = function(command)
local cmd = {}
for word in command:gmatch("%S+") do
table.insert(cmd, word)
end
return cmd
end

local function loadConfigFromCommand(command, optionName, callback, defaultValue)
if (type(command) == "function") then
return command(callback) -- or callback(defaultValue)
else
local cmd = splitCommandIntoTable(command)
job
:new({
command = cmd[1],
args = vim.list_slice(cmd, 2, #cmd),
on_exit = function(j, exit_code)
if exit_code ~= 0 then
logger.warn("Config '" .. optionName .. "' did not return a value when executed")
return
end
local value = j:result()[1]:gsub("%s+$", "")
if value ~= nil and value ~= "" then
callback(value)
elseif defaultValue ~= nil and defaultValue ~= "" then
callback(defaultValue)
end
end,
})
:start()
return
end
end

local function loadConfigFromEnv(envName, configName, callback)
local variable = os.getenv(envName)
if not variable then
return
end
local value = variable:gsub("%s+$", "")
Api[configName] = value
if callback then
callback(value)
end
end

local function loadOptionalConfig(envName, configName, optionName, callback, defaultValue)
loadConfigFromEnv(envName, configName)
if Api[configName] then
callback(Api[configName])
elseif Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then
loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue)
else
callback(defaultValue)
end
end

local function loadRequiredConfig(envName, configName, optionName, callback, defaultValue)
loadConfigFromEnv(envName, configName, callback)
if not Api[configName] then
if Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then
loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue)
else
logger.warn(configName .. " variable not set")
return
end
end
end

-- Check if the key is valid and start the job
local function startJobWithKeyValidation(cb)
if not Api["OPENAI_API_KEY"] or (key_expiry_timestamp and os.time() > key_expiry_timestamp) then
if key_expiry_timestamp then
logger.info("startJobWithKeyValidation: Key expired at " .. key_expiry_timestamp)
end
Api["OPENAI_API_KEY"] = nil
loadRequiredConfig("OPENAI_API_KEY", "OPENAI_API_KEY", "api_key_cmd", vim.schedule_wrap(function(value, timeout)
updateAuthenticationKey(value, timeout)
cb()
end))
else
cb()
end
return 0
end

function Api.completions(custom_params, cb)
local params = vim.tbl_extend("keep", custom_params, Config.options.openai_params)
Api.make_call(Api.COMPLETIONS_URL, params, cb)
end

function Api.chat_completions(custom_params, cb, should_stop)
local function doChatCompletions(custom_params, cb, should_stop)
local params = vim.tbl_extend("keep", custom_params, Config.options.openai_params)
local stream = params.stream or false
if stream then
local raw_chunks = ""
local state = "START"
local prev_chunk -- store incomplete line from previous chunk

cb = vim.schedule_wrap(cb)

Expand Down Expand Up @@ -42,6 +150,10 @@ function Api.chat_completions(custom_params, cb, should_stop)
"curl",
args,
function(chunk)
if prev_chunk ~= nil then
chunk = prev_chunk .. chunk
prev_chunk = nil
end
local ok, json = pcall(vim.json.decode, chunk)
if ok and json ~= nil then
if json.error ~= nil then
Expand Down Expand Up @@ -72,6 +184,8 @@ function Api.chat_completions(custom_params, cb, should_stop)
raw_chunks = raw_chunks .. json.choices[1].delta.content
state = "CONTINUE"
end
else
prev_chunk = line
end
end
end
Expand All @@ -89,6 +203,12 @@ function Api.chat_completions(custom_params, cb, should_stop)
end
end

function Api.chat_completions(custom_params, cb, should_stop)
startJobWithKeyValidation(function()
doChatCompletions(custom_params, cb, should_stop)
end)
end

function Api.edits(custom_params, cb)
local params = vim.tbl_extend("keep", custom_params, Config.options.openai_edit_params)
if params.model == "text-davinci-edit-001" or params.model == "code-davinci-edit-001" then
Expand Down Expand Up @@ -127,15 +247,17 @@ function Api.make_call(url, params, cb)
end
end

Api.job = job
:new({
command = "curl",
args = args,
on_exit = vim.schedule_wrap(function(response, exit_code)
Api.handle_response(response, exit_code, cb)
end),
})
:start()
startJobWithKeyValidation(function()
Api.job = job
:new({
command = "curl",
args = args,
on_exit = vim.schedule_wrap(function(response, exit_code)
Api.handle_response(response, exit_code, cb)
end),
})
:start()
end)
end

Api.handle_response = vim.schedule_wrap(function(response, exit_code, cb)
Expand Down Expand Up @@ -183,71 +305,6 @@ function Api.close()
end
end

local splitCommandIntoTable = function(command)
local cmd = {}
for word in command:gmatch("%S+") do
table.insert(cmd, word)
end
return cmd
end

local function loadConfigFromCommand(command, optionName, callback, defaultValue)
local cmd = splitCommandIntoTable(command)
job
:new({
command = cmd[1],
args = vim.list_slice(cmd, 2, #cmd),
on_exit = function(j, exit_code)
if exit_code ~= 0 then
logger.warn("Config '" .. optionName .. "' did not return a value when executed")
return
end
local value = j:result()[1]:gsub("%s+$", "")
if value ~= nil and value ~= "" then
callback(value)
elseif defaultValue ~= nil and defaultValue ~= "" then
callback(defaultValue)
end
end,
})
:start()
end

local function loadConfigFromEnv(envName, configName, callback)
local variable = os.getenv(envName)
if not variable then
return
end
local value = variable:gsub("%s+$", "")
Api[configName] = value
if callback then
callback(value)
end
end

local function loadOptionalConfig(envName, configName, optionName, callback, defaultValue)
loadConfigFromEnv(envName, configName)
if Api[configName] then
callback(Api[configName])
elseif Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then
loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue)
else
callback(defaultValue)
end
end

local function loadRequiredConfig(envName, configName, optionName, callback, defaultValue)
loadConfigFromEnv(envName, configName, callback)
if not Api[configName] then
if Config.options[optionName] ~= nil and Config.options[optionName] ~= "" then
loadConfigFromCommand(Config.options[optionName], optionName, callback, defaultValue)
else
logger.warn(configName .. " variable not set")
return
end
end
end

local function loadAzureConfigs()
loadRequiredConfig("OPENAI_API_BASE", "OPENAI_API_BASE", "azure_api_base_cmd", function(base)
Api.OPENAI_API_BASE = base
Expand Down Expand Up @@ -301,16 +358,12 @@ function Api.setup()
Api.EDITS_URL = ensureUrlProtocol(Api.OPENAI_API_HOST .. "/v1/edits")
end, "api.openai.com")

loadRequiredConfig("OPENAI_API_KEY", "OPENAI_API_KEY", "api_key_cmd", function(key)
Api.OPENAI_API_KEY = key

loadRequiredConfig("OPENAI_API_KEY", "OPENAI_API_KEY", "api_key_cmd", function(key, timeout)
loadOptionalConfig("OPENAI_API_TYPE", "OPENAI_API_TYPE", "api_type_cmd", function(type)
if type == "azure" then
loadAzureConfigs()
Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OPENAI_API_KEY
else
Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OPENAI_API_KEY
end
updateAuthenticationKey(key, timeout)
end, "")
end)
end
Expand Down
Loading