Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 96 additions & 35 deletions plugin/claude.vim
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ if !exists('g:claude_bedrock_model_id')
let g:claude_bedrock_model_id = 'anthropic.claude-3-5-sonnet-20241022-v2:0'
endif

" Categories shown in the token usage summary. Each entry is
"   [display label, key looked up in the usage dictionary, price per token]
" The price is multiplied by the token count and printed as a dollar amount.
" Define g:claude_token_usage_categories before this point to override it.
if !exists('g:claude_token_usage_categories')
  let g:claude_token_usage_categories = []
  call add(g:claude_token_usage_categories,
        \ ['Input Tokens', 'input_tokens', 3.0/1.0E6])
  call add(g:claude_token_usage_categories,
        \ ['Output Tokens', 'output_tokens', 15.0/1.0E6])
  call add(g:claude_token_usage_categories,
        \ ['Cache Create Input Tokens', 'cache_creation_input_tokens', 3.75/1.0E6])
  call add(g:claude_token_usage_categories,
        \ ['Cache Read Input Tokens', 'cache_read_input_tokens', 0.3/1.0E6])
endif
if !exists('g:claude_aws_profile')
let g:claude_aws_profile = ''
endif
Expand Down Expand Up @@ -91,18 +99,76 @@ endif
" Claude API
" ============================================================================

function! s:ClaudeQueryInternal(messages, system_prompt, tools, stream_callback, final_callback)
" {"role": "user", "content":"Hi"} -> {"role":"user", "content":[{"type":"text", "text": "hi"}]
function! s:InflateMessageContent(msg)
  " Return a copy of a:msg whose content is in content-block list form:
  "   {"role": "user", "content": "Hi"}
  "     -> {"role": "user", "content": [{"type": "text", "text": "Hi"}]}
  " Content that is not a string (already a list of blocks, or absent) is
  " kept as it is.
  "
  " a:msg itself is never modified. The copy is deep on purpose: with a
  " shallow copy() a list-valued content and the block dictionaries in it
  " would still be shared with the caller's message, so editing a returned
  " block in place (for example attaching a cache_control entry) would
  " silently change the caller's data as well.
  let l:result = deepcopy(a:msg)
  if type(get(l:result, 'content', [])) == v:t_string
    let l:result.content = [{'type': 'text', 'text': l:result.content}]
  endif
  return l:result
endfunction

" returns modification timestamp or zero if no undo information exists
function! s:GetBufTimestamp(bufnr)
  " Look at the undo tree of buffer a:bufnr and report the "time" field of
  " the last item in its top-level entries list. A buffer with no undo
  " entries yields 0.
  let l:undo_entries = undotree(a:bufnr).entries
  return empty(l:undo_entries) ? 0 : l:undo_entries[-1].time
endfunction

function! s:DocumentContext(name,content)
  " Wrap one buffer in a <buffer> element: its name goes on a single
  " <name> line and its text sits on its own line(s) between <content> and
  " </content>. The pieces are joined with newlines; no trailing newline is
  " added and neither argument is escaped.
  let l:pieces = ['<buffer>']
  call add(l:pieces, '<name>' . a:name . '</name>')
  call extend(l:pieces, ['<content>', a:content, '</content>'])
  call add(l:pieces, '</buffer>')
  return join(l:pieces, "\n")
endfunction

function! s:ClaudeQueryInternal(messages, buffers, system_prompt, tools, stream_callback, final_callback) abort
" Prepare the API request
let l:data = {}
let l:headers = []
let l:url = ''

" Construct messages with content objects
let messages = map(a:messages, "s:InflateMessageContent(v:val)")

let quiet_buffers = []
let busy_buffers = []
let cutoff_timestamp = localtime() - 120
for b in a:buffers
let context = s:DocumentContext(b.name, b.contents)
if b.lastmodified == 0 || b.lastmodified > cutoff_timestamp
call add(quiet_buffers,context)
else
call add(busy_buffers,context)
endif
endfor

let context_messages = []
if !empty(quiet_buffers)
let msg = {"type" : "text", "text": join(["<documents>"] + quiet_buffers + ["</documents>"],"\n"), "cache_control": {"type": "ephemeral"}}
call add(context_messages, msg)
endif

if !empty(busy_buffers)
let msg = {"type" : "text", "text": join(["<documents>"] + busy_buffers + ["</documents>"],"\n"), "cache_control": {"type": "ephemeral"}}
call add(context_messages, msg)
endif

" And prepend content objects for the buffers
let messages[0].content = context_messages + messages[0].content
let messages[-1].content[-1].cache_control = {"type": "ephemeral"}


"FIXME - test bedrock using new context
if g:claude_use_bedrock
let l:python_script = s:plugin_dir . '/claude_bedrock_helper.py'
let l:cmd = ['python3', l:python_script,
\ '--region', g:claude_bedrock_region,
\ '--model-id', g:claude_bedrock_model_id,
\ '--messages', json_encode(a:messages),
\ '--messages', json_encode(messages),
\ '--system-prompt', a:system_prompt]

if !empty(g:claude_aws_profile)
Expand All @@ -117,7 +183,7 @@ function! s:ClaudeQueryInternal(messages, system_prompt, tools, stream_callback,
let l:data = {
\ 'model': g:claude_model,
\ 'max_tokens': 2048,
\ 'messages': a:messages,
\ 'messages': messages,
\ 'stream': v:true
\ }
if !empty(a:system_prompt)
Expand All @@ -129,6 +195,7 @@ function! s:ClaudeQueryInternal(messages, system_prompt, tools, stream_callback,
call extend(l:headers, ['-H', 'Content-Type: application/json'])
call extend(l:headers, ['-H', 'x-api-key: ' . g:claude_api_key])
call extend(l:headers, ['-H', 'anthropic-version: 2023-06-01'])
call extend(l:headers, ['-H', "anthropic-beta: prompt-caching-2024-07-31"])

" Convert data to JSON
let l:json_data = json_encode(l:data)
Expand All @@ -155,24 +222,16 @@ function! s:ClaudeQueryInternal(messages, system_prompt, tools, stream_callback,
return l:job
endfunction

function! s:DisplayTokenUsageAndCost(json_data)
let l:data = json_decode(a:json_data)
if has_key(l:data, 'usage')
let l:usage = l:data.usage
let l:input_tokens = exists('s:stored_input_tokens') ? s:stored_input_tokens : get(l:usage, 'input_tokens', 0)
let l:output_tokens = get(l:usage, 'output_tokens', 0)

let l:input_cost = (l:input_tokens / 1000000.0) * 3.0
let l:output_cost = (l:output_tokens / 1000000.0) * 15.0

echom printf("Token usage - Input: %d ($%.4f), Output: %d ($%.4f)", l:input_tokens, l:input_cost, l:output_tokens, l:output_cost)

if exists('s:stored_input_tokens')
unlet s:stored_input_tokens
endif
else
echom "Error: Invalid JSON data format"
endif
function! s:DisplayTokenUsageAndCost(input_usage, final_usage)
  " Echo a one-line summary of token counts and their cost.
  "
  " a:input_usage  usage dictionary captured at the start of the response
  " a:final_usage  usage dictionary reported at the end of the response
  "
  " One "label: count ($cost)" item is printed per entry of
  " g:claude_token_usage_categories ([label, usage key, price per token]).
  "
  " The two reports are not added together: the counts in the final report
  " are cumulative, so a key that appears in both (output_tokens in
  " particular) would be counted twice. The later report wins whenever it
  " carries a numeric value for the key; otherwise the earlier value is
  " used, and 0 if neither has one.
  let l:parts = ['Token usage -']
  for l:category in g:claude_token_usage_categories
    let [l:label, l:key, l:price] = l:category[0:2]
    let l:tokens = 0
    for l:report in [a:input_usage, a:final_usage]
      " Skip absent keys and non-numeric values such as JSON null.
      if type(get(l:report, l:key, v:null)) == v:t_number
        let l:tokens = l:report[l:key]
      endif
    endfor
    call add(l:parts, printf('%s: %d ($%.4f)', l:label, l:tokens, l:tokens * l:price))
  endfor
  echom join(l:parts, ' ')
endfunction

function! s:HandleStreamOutput(stream_callback, final_callback, channel, msg)
Expand Down Expand Up @@ -206,9 +265,13 @@ function! s:HandleStreamOutput(stream_callback, final_callback, channel, msg)
let l:delta = l:response.delta.text
call a:stream_callback(l:delta)
elseif l:response.type == 'message_start' && has_key(l:response, 'message') && has_key(l:response.message, 'usage')
let s:stored_input_tokens = get(l:response.message.usage, 'input_tokens', 0)
let s:stored_input_usage = copy(l:response.message.usage)
elseif l:response.type == 'message_delta' && has_key(l:response, 'usage')
call s:DisplayTokenUsageAndCost(l:json_str)
if !exists('s:stored_input_usage')
let s:stored_input_usage = {}
endif
call s:DisplayTokenUsageAndCost(s:stored_input_usage, l:response.usage)
unlet s:stored_input_usage
elseif l:response.type != 'message_stop' && l:response.type != 'message_start' && l:response.type != 'content_block_start' && l:response.type != 'ping'
call a:stream_callback('Unknown Claude protocol output: "' . l:line . "\"\n")
endif
Expand Down Expand Up @@ -580,12 +643,13 @@ function! s:ClaudeImplement(line1, line2, instruction) range
let l:winid = win_getid()

" Prepare the prompt for code implementation
let l:prompt = "<code>\n" . l:selected_code . "\n</code>\n\n"
let l:prompt .= join(g:claude_implement_prompt, "\n")
let l:code_context = [{"name": "selected_code", "contents": l:selected_code, "lastmodified": 0}]
let l:system_prompt = join(g:claude_implement_prompt, "\n")
let l:tools = []

" Query Claude
let l:messages = [{'role': 'user', 'content': a:instruction}]
call s:ClaudeQueryInternal(l:messages, l:prompt, [],
call s:ClaudeQueryInternal(l:messages, l:code_context, l:system_prompt, l:tools,
\ function('s:StreamingImplementResponse'),
\ function('s:FinalImplementResponse', [a:line1, a:line2, l:bufnr, l:bufname, l:winid, a:instruction]))
endfunction
Expand Down Expand Up @@ -906,7 +970,11 @@ function! s:GetBuffersContent()
if buflisted(bufnr) && bufname(bufnr) != 'Claude Chat' && !empty(win_findbuf(bufnr))
let l:bufname = bufname(bufnr)
let l:contents = join(getbufline(bufnr, 1, '$'), "\n")
call add(l:buffers, {'name': l:bufname, 'contents': l:contents})
call add(l:buffers, {
\ 'name': l:bufname,
\ 'lastmodified': s:GetBufTimestamp(bufnr),
\ 'contents': l:contents
\ })
endif
endfor
return l:buffers
Expand All @@ -925,17 +993,10 @@ function! s:SendChatMessage(prefix)
endif

let l:buffer_contents = s:GetBuffersContent()
let l:content_prompt = "# Contents of open buffers\n\n"
for buffer in l:buffer_contents
let l:content_prompt .= "Buffer: " . buffer.name . "\n"
let l:content_prompt .= "<content>\n" . buffer.contents . "</content>\n\n"
let l:content_prompt .= "============================\n\n"
endfor

call append('$', a:prefix . " ")
normal! G

let l:job = s:ClaudeQueryInternal(l:messages, l:content_prompt . l:system_prompt, g:claude_tools, function('s:StreamingChatResponse'), function('s:FinalChatResponse'))
let l:job = s:ClaudeQueryInternal(l:messages, l:buffer_contents, l:system_prompt, g:claude_tools, function('s:StreamingChatResponse'), function('s:FinalChatResponse'))

" Store the job ID or channel for potential cancellation
if has('nvim')
Expand Down