Commit: Add support for Anthropic API

stevegeek committed Jun 19, 2024
1 parent 07831b6 commit 4250955
Showing 11 changed files with 184 additions and 84 deletions.
35 changes: 18 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
@@ -2,20 +2,25 @@

__The goal for AIRefactor is to use LLMs to apply repetitive refactoring tasks to code.__

First the human decides what refactoring is needed and builds up a prompt to describe the task, or uses one of AIRefactor's provided prompts.
## The workflow

AIRefactor then helps to apply the refactoring to one or more files.
1) The human decides what refactoring is needed.
2) The human selects an existing built-in refactoring command, and/or builds up a prompt to describe the task.
3) The human selects some source files to act as context (e.g. examples of the post-refactor code, or related classes).
4) The human runs the tool with the command, source files and context files.
5) The AI generates the refactored code and outputs it either to a file or to stdout.
6) In some cases, the tool can then check the generated code by running tests and comparing test outputs.

In some cases, the tool can then check the generated code by running tests and comparing test outputs.
AIRefactor can apply the refactoring to multiple files, allowing batch processing.
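The workflow above can be sketched as a single hypothetical invocation (the prompt and file paths here are made up for illustration; the `custom` command and the `-p`, `-m` and `--output` options are described below):

```shell
# Hypothetical run: apply a user-supplied prompt to one input file,
# using an Anthropic model, and write the result to a new file.
# (prompts/extract_service.md and the lib/ paths are illustrative only.)
ai_refactor custom lib/original.rb \
  -p prompts/extract_service.md \
  -m claude-3-opus-20240229 \
  --output lib/refactored.rb
```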

#### Notes

AI Refactor is an experimental tool under active development as I explore the idea myself. It may not work as expected, or may change in ways that break existing functionality.

The focus of the tool is work with the Ruby programming language ecosystem, but it can be used with any language.
The focus of the tool is working with the **Ruby programming language ecosystem**, but it can be used with any language.

AI Refactor currently uses [OpenAI's ChatGPT](https://platform.openai.com/).
AI Refactor currently uses [OpenAI's ChatGPT](https://platform.openai.com/) or [Anthropic Claude](https://docs.anthropic.com/en/docs/about-claude/models) to generate code.

## Examples

@@ -64,7 +69,9 @@ Use a pre-built prompt:

### User supplied prompts, eg `custom`, `ruby/write_ruby` and `ruby/refactor_ruby`

Applies the refactor specified by prompting the AI with the user supplied prompt. You must supply a prompt file with the `-p` option.
You can use these commands in conjunction with a user-supplied prompt.

You must supply a prompt file with the `-p` option.

The output is written to `stdout`, or to a file with the `--output` option.

@@ -178,7 +185,7 @@ output_file_path: output file or directory
output_template_path: output file template (see docs)
prompt_file_path: path
prompt: |
A custom prompt to send to ChatGPT if the command needs it (otherwise read from file)
A custom prompt to send to the AI if the command needs it (otherwise read from file)
context_file_paths:
- file1.rb
- file2.rb
@@ -194,10 +201,10 @@ context_text: |
Some extra info to prepend to the prompt
diff: true/false (default false)
ai_max_attempts: max times to generate more if AI does not complete generating (default 3)
ai_model: ChatGPT model name (default gpt-4-turbo)
ai_temperature: ChatGPT temperature (default 0.7)
ai_max_tokens: ChatGPT max tokens (default 1500)
ai_timeout: ChatGPT timeout (default 60)
ai_model: AI model name, OpenAI GPT or Anthropic Claude (default gpt-4-turbo)
ai_temperature: AI temperature (default 0.7)
ai_max_tokens: AI max tokens (default 1500)
ai_timeout: AI timeout (default 60)
overwrite: y/n/a (default a)
verbose: true/false (default false)
debug: true/false (default false)
@@ -261,12 +268,6 @@ This file provides default CLI switches to add to any `ai_refactor` command.

The tool keeps a history of commands run in the `.ai_refactor_history` file in the current working directory.

## Note on performance and ChatGPT version

_The quality of results depend very much on the version of ChatGPT being used._

I have tested with both 3.5 and 4 and see **significantly** better performance with version 4.

## Development

After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
3 changes: 2 additions & 1 deletion ai_refactor.gemspec
@@ -12,7 +12,7 @@ Gem::Specification.new do |spec|
spec.description = "Use OpenAI's ChatGPT to automate converting Rails RSpec tests to minitest (ActiveSupport::TestCase)."
spec.homepage = "https://github.com/stevegeek/ai_refactor"
spec.license = "MIT"
spec.required_ruby_version = ">= 2.7.0"
spec.required_ruby_version = ">= 3.3.0"

spec.metadata["homepage_uri"] = spec.homepage
spec.metadata["source_code_uri"] = "https://github.com/stevegeek/ai_refactor"
@@ -32,5 +32,6 @@ Gem::Specification.new do |spec|
spec.add_dependency "colorize", "< 2.0"
spec.add_dependency "open3", "< 2.0"
spec.add_dependency "ruby-openai", ">= 3.4.0", "< 6.0"
spec.add_dependency "anthropic", ">= 0.1.0", "< 1.0"
spec.add_dependency "zeitwerk", "~> 2.6"
end
2 changes: 1 addition & 1 deletion examples/ex1_convert_a_rspec_test_to_minitest.yml
@@ -4,5 +4,5 @@ input_file_paths:
# We need to add context here to tell the AI to require our local test_helper.rb file so that we can run the tests afterwards
context_text: "In the output test use `require_relative '../test_helper'` to include 'test_helper'."
# By default, ai_refactor runs "bundle exec rails test" but this isn't going to work here as we are not actually in a Rails app context in the examples
minitest_run_command: ruby __FILE__
minitest_run_command: bundle exec ruby __FILE__
output_file_path: examples/outputs/ex1_input_test.rb
15 changes: 8 additions & 7 deletions exe/ai_refactor
@@ -3,6 +3,7 @@
require "optparse"
require "colorize"
require "openai"
require "anthropic"
require "shellwords"
require_relative "../lib/ai_refactor"

@@ -37,35 +38,35 @@ option_parser = OptionParser.new do |parser|
run_config.context_text = c
end

parser.on("-r", "--review-prompt", "Show the prompt that will be sent to ChatGPT but do not actually call ChatGPT or make changes to files.") do
parser.on("-r", "--review-prompt", "Show the prompt that will be sent to the AI but do not actually call the AI or make changes to files.") do
run_config.review_prompt = true
end

parser.on("-p", "--prompt PROMPT_FILE", String, "Specify path to a text file that contains the ChatGPT 'system' prompt.") do |f|
parser.on("-p", "--prompt PROMPT_FILE", String, "Specify path to a text file that contains the AI 'system' prompt.") do |f|
run_config.prompt_file_path = f
end

parser.on("-f", "--diffs", "Request AI generate diffs of changes rather than writing out the whole file.") do
run_config.diff = true
end

parser.on("-C", "--continue [MAX_MESSAGES]", Integer, "If ChatGPT stops generating due to the maximum token count being reached, continue to generate more messages, until a stop condition or MAX_MESSAGES. MAX_MESSAGES defaults to 3") do |c|
parser.on("-C", "--continue [MAX_MESSAGES]", Integer, "If the AI stops generating due to the maximum token count being reached, continue to generate more messages, until a stop condition or MAX_MESSAGES is reached. MAX_MESSAGES defaults to 3.") do |c|
run_config.ai_max_attempts = c
end

parser.on("-m", "--model MODEL_NAME", String, "Specify a ChatGPT model to use (default gpt-4-turbo).") do |m|
parser.on("-m", "--model MODEL_NAME", String, "Specify an AI model to use (default 'gpt-4-turbo'). OpenAI and Anthropic models are supported (e.g. 'gpt-4o', 'claude-3-opus-20240229').") do |m|
run_config.ai_model = m
end

parser.on("--temperature TEMP", Float, "Specify the temperature parameter for ChatGPT (default 0.7).") do |p|
parser.on("--temperature TEMP", Float, "Specify the temperature parameter for generation (default 0.7).") do |p|
run_config.ai_temperature = p
end

parser.on("--max-tokens MAX_TOKENS", Integer, "Specify the max number of tokens of output ChatGPT can generate. Max will depend on the size of the prompt (default 1500)") do |m|
parser.on("--max-tokens MAX_TOKENS", Integer, "Specify the max number of tokens of output the AI can generate. Max will depend on the size of the prompt (default 1500)") do |m|
run_config.ai_max_tokens = m
end

parser.on("-t", "--timeout SECONDS", Integer, "Specify the max wait time for ChatGPT response.") do |m|
parser.on("-t", "--timeout SECONDS", Integer, "Specify the max wait time for an AI response.") do |m|
run_config.ai_timeout = m
end

1 change: 1 addition & 0 deletions lib/ai_refactor.rb
@@ -4,6 +4,7 @@
loader = Zeitwerk::Loader.for_gem
loader.inflector.inflect(
"ai_refactor" => "AIRefactor",
"ai_client" => "AIClient",
"rspec_runner" => "RSpecRunner"
)
loader.setup # ready!
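The `ai_client` entry is needed because Zeitwerk's default inflector would camelize `ai_client.rb` to `AiClient`, not `AIClient`. A minimal pure-Ruby sketch of that behaviour (the `constant_for` helper is hypothetical, mimicking default segment-by-segment camelization plus an override table):

```ruby
# Hypothetical sketch: why the inflection overrides above are needed.
# Default camelization capitalizes each underscore-separated segment,
# which would turn ai_client.rb into "AiClient" - not "AIClient".

OVERRIDES = {
  "ai_refactor" => "AIRefactor",
  "ai_client" => "AIClient",
  "rspec_runner" => "RSpecRunner"
}.freeze

def constant_for(basename)
  OVERRIDES.fetch(basename) do
    # Default-style camelization: "file_processor" -> "FileProcessor"
    basename.split("_").map(&:capitalize).join
  end
end
```

`constant_for("ai_client")` uses the override to give `"AIClient"`, while a name like `file_processor` falls through to the default and still resolves correctly.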
86 changes: 86 additions & 0 deletions lib/ai_refactor/ai_client.rb
@@ -0,0 +1,86 @@
# frozen_string_literal: true

module AIRefactor
class AIClient
def initialize(platform: "openai", model: "gpt-4-turbo", temperature: 0.7, max_tokens: 1500, timeout: 60, verbose: false)
@platform = platform
@model = model
@temperature = temperature
@max_tokens = max_tokens
@timeout = timeout
@verbose = verbose
@client = configure
end

def generate!(messages)
finished_reason, content, response = case @platform
when "openai"
openai_parse_response(
@client.chat(
parameters: {
messages: messages,
model: @model,
temperature: @temperature,
max_tokens: @max_tokens
}
)
)
when "anthropic"
anthropic_parse_response(
@client.messages(
parameters: {
system: messages.find { |m| m[:role] == "system" }&.fetch(:content, nil),
messages: messages.select { |m| m[:role] != "system" },
model: @model,
max_tokens: @max_tokens
}
)
)
else
raise "Invalid platform: #{@platform}"
end
yield finished_reason, content, response
end

private

def configure
case @platform
when "openai"
::OpenAI::Client.new(
access_token: ENV.fetch("OPENAI_API_KEY"),
organization_id: ENV.fetch("OPENAI_ORGANIZATION_ID", nil),
request_timeout: @timeout,
log_errors: @verbose
)
when "anthropic"
::Anthropic::Client.new(
access_token: ENV.fetch("ANTHROPIC_API_KEY"),
request_timeout: @timeout
)
else
raise "Invalid platform: #{@platform}"
end
end

def openai_parse_response(response)
if response["error"]
raise StandardError.new("OpenAI error: #{response["error"]["type"]}: #{response["error"]["message"]} (#{response["error"]["code"]})")
end

content = response.dig("choices", 0, "message", "content")
finished_reason = response.dig("choices", 0, "finish_reason")
[finished_reason, content, response]
end

def anthropic_parse_response(response)
if response["error"]
raise StandardError.new("Anthropic error: #{response["error"]["type"]}: #{response["error"]["message"]}")
end

content = response.dig("content", 0, "text")
finished_reason = response["stop_reason"] == "max_tokens" ? "length" : response["stop_reason"]
[finished_reason, content, response]
end
end
end
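The two `*_parse_response` helpers above normalise each provider's payload into the same `[finished_reason, content, response]` shape. A minimal sketch of that mapping, using hand-written sample payloads (the hashes below are illustrative, not captured API responses):

```ruby
# Illustrative payloads showing the normalisation performed by
# openai_parse_response and anthropic_parse_response above.

openai_response = {
  "choices" => [
    {"message" => {"content" => "def foo; end"}, "finish_reason" => "stop"}
  ]
}
anthropic_response = {
  "content" => [{"text" => "def foo; end"}],
  "stop_reason" => "max_tokens"
}

# OpenAI: content and finish reason both live under choices[0].
openai_content = openai_response.dig("choices", 0, "message", "content")
openai_reason = openai_response.dig("choices", 0, "finish_reason")

# Anthropic: content is a list of blocks, and the provider-specific
# "max_tokens" stop reason is mapped to OpenAI's "length".
anthropic_content = anthropic_response.dig("content", 0, "text")
stop_reason = anthropic_response["stop_reason"]
anthropic_reason = stop_reason == "max_tokens" ? "length" : stop_reason
```

Normalising on the OpenAI vocabulary means callers only ever have to check for `finished_reason == "length"`, whichever platform produced the response.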
21 changes: 13 additions & 8 deletions lib/ai_refactor/cli.rb
@@ -63,6 +63,17 @@ def inputs
configuration.input_file_paths
end

def ai_client
@ai_client ||= AIRefactor::AIClient.new(
platform: configuration.ai_platform,
model: configuration.ai_model,
temperature: configuration.ai_temperature,
max_tokens: configuration.ai_max_tokens,
timeout: configuration.ai_timeout,
verbose: configuration.verbose
)
end

def valid?
return false unless refactorer
inputs_valid = refactorer.takes_input_files? ? !(inputs.nil? || inputs.empty?) : true
@@ -72,12 +83,6 @@ def valid?
def run
return false unless valid?

OpenAI.configure do |config|
config.access_token = ENV.fetch("OPENAI_API_KEY")
config.organization_id = ENV.fetch("OPENAI_ORGANIZATION_ID", nil)
config.request_timeout = configuration.ai_timeout || 240
end

if refactorer.takes_input_files?
expanded_inputs = inputs.map do |path|
File.exist?(path) ? path : Dir.glob(path)
@@ -92,7 +97,7 @@ def run
return_values = expanded_inputs.map do |file|
logger.info "Processing #{file}..."

refactor = refactorer.new(file, configuration, logger)
refactor = refactorer.new(ai_client, file, configuration, logger)
refactor_returned = refactor.run
failed = refactor_returned == false
if failed
@@ -118,7 +123,7 @@ def run
name = refactorer.refactor_name
logger.info "AI Refactor - #{name} refactor\n"
logger.info "====================\n"
refactor = refactorer.new(nil, configuration, logger)
refactor = refactorer.new(ai_client, nil, configuration, logger)
refactor_returned = refactor.run
failed = refactor_returned == false
if failed
40 changes: 13 additions & 27 deletions lib/ai_refactor/file_processor.rb
@@ -60,35 +60,21 @@ def generate_next_message(messages, options, attempts_left)
logger.debug "Options: #{options.inspect}"
logger.debug "Messages: #{messages.inspect}"

response = @ai_client.chat(
parameters: {
model: options[:ai_model] || "gpt-4-turbo",
messages: messages,
temperature: options[:ai_temperature] || 0.7,
max_tokens: options[:ai_max_tokens] || 1500
}
)

if response["error"]
raise StandardError.new("OpenAI error: #{response["error"]["type"]}: #{response["error"]["message"]} (#{response["error"]["code"]})")
end

content = response.dig("choices", 0, "message", "content")
finished_reason = response.dig("choices", 0, "finish_reason")

if finished_reason == "length" && attempts_left > 0
generate_next_message(messages + [
{role: "assistant", content: content},
{role: "user", content: "Continue"}
], options, attempts_left - 1)
else
previous_messages = messages.filter { |m| m[:role] == "assistant" }.map { |m| m[:content] }.join
content = if previous_messages.length > 0
content ? previous_messages + content : previous_messages
@ai_client.generate!(messages) do |finished_reason, content, response|
if finished_reason == "length" && attempts_left > 0
generate_next_message(messages + [
{role: "assistant", content: content},
{role: "user", content: "Continue"}
], options, attempts_left - 1)
else
content
previous_messages = messages.filter { |m| m[:role] == "assistant" }.map { |m| m[:content] }.join
content = if previous_messages.length > 0
content ? previous_messages + content : previous_messages
else
content
end
[content, finished_reason, response["usage"]]
end
[content, finished_reason, response["usage"]]
end
end
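The rework above moves the provider call behind `AIClient#generate!` but keeps the continuation strategy: when generation stops with `"length"`, the partial content is replayed as an assistant turn followed by a `"Continue"` user turn, and the assistant fragments are stitched together at the end. A simplified, self-contained sketch of that recursion (a scripted stub stands in for the AI client):

```ruby
# Simplified sketch of the continuation logic in generate_next_message:
# recurse while the stop reason is "length" and attempts remain, then
# join all partial assistant replies onto the final chunk.

def generate_with_continuation(messages, attempts_left, &generator)
  finished_reason, content = generator.call(messages)
  if finished_reason == "length" && attempts_left > 0
    generate_with_continuation(messages + [
      {role: "assistant", content: content},
      {role: "user", content: "Continue"}
    ], attempts_left - 1, &generator)
  else
    # Stitch earlier partial assistant replies onto the final chunk.
    previous = messages.filter { |m| m[:role] == "assistant" }
      .map { |m| m[:content] }.join
    previous.empty? ? content : previous + (content || "")
  end
end

# Scripted stub: two truncated chunks, then a clean stop.
chunks = [["length", "def foo"], ["length", "; 42"], ["stop", "; end"]]
result = generate_with_continuation(
  [{role: "user", content: "Write foo"}], 3
) { |_msgs| chunks.shift }
# result == "def foo; 42; end"
```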

14 changes: 7 additions & 7 deletions lib/ai_refactor/refactors/base_refactor.rb
@@ -17,11 +17,12 @@ def self.takes_input_files?
true
end

attr_reader :input_file, :options, :logger
attr_reader :ai_client, :input_file, :options, :logger
attr_accessor :input_content
attr_writer :failed_message

def initialize(input_file, options, logger)
def initialize(ai_client, input_file, options, logger)
@ai_client = ai_client
@input_file = input_file
@options = options
@logger = logger
@@ -79,8 +80,11 @@ def process!(strip_ticks: true)
output_content
rescue => e
logger.error "Request to AI failed: #{e.message}"
if e.respond_to?(:response) && e.response
logger.error "Response: #{e.response[:body]}"
end
logger.warn "Skipping #{input_file}..."
self.failed_message = "Request to OpenAI failed"
self.failed_message = "Request to AI API failed"
raise e
end
end
@@ -175,10 +179,6 @@ def output_file_path_from_template
path
end

def ai_client
@ai_client ||= OpenAI::Client.new
end

def refactor_name
self.class.refactor_name
end