From 42509553c609d237a5c82861da39779d8fd84f6c Mon Sep 17 00:00:00 2001
From: Stephen Ierodiaconou
Date: Wed, 19 Jun 2024 11:14:18 +0200
Subject: [PATCH] Add support for Anthropic API

---
 README.md                                     | 35 ++++----
 ai_refactor.gemspec                           |  3 +-
 .../ex1_convert_a_rspec_test_to_minitest.yml  |  2 +-
 exe/ai_refactor                               | 15 ++--
 lib/ai_refactor.rb                            |  1 +
 lib/ai_refactor/ai_client.rb                  | 86 +++++++++++++++++++
 lib/ai_refactor/cli.rb                        | 21 +++--
 lib/ai_refactor/file_processor.rb             | 40 +++------
 lib/ai_refactor/refactors/base_refactor.rb    | 14 +--
 lib/ai_refactor/run_configuration.rb          | 49 +++++++----
 test/lib/ai_refactor/context_test.rb          |  2 +-
 11 files changed, 184 insertions(+), 84 deletions(-)
 create mode 100644 lib/ai_refactor/ai_client.rb

diff --git a/README.md b/README.md
index b0712d9..2d2dda5 100644
--- a/README.md
+++ b/README.md
@@ -2,20 +2,25 @@
 
 __The goal for AIRefactor is to use LLMs to apply repetitive refactoring tasks to code.__
 
-First the human decides what refactoring is needed and builds up a prompt to describe the task, or uses one of AIRefactors provided prompts.
+## The workflow
 
-AIRefactor then helps to apply the refactoring to one or more files.
+1) the human decides what refactoring is needed
+2) the human selects an existing built-in refactoring command, and/or builds up a prompt to describe the task
+3) the human selects some source files to act as context (e.g. examples of the code post-refactor, related classes, etc.)
+4) the human runs the tool with the command, source files and context files
+5) the AI generates the refactored code and outputs it either to a file or stdout
+6) in some cases, the tool can then check the generated code by running tests and comparing test outputs
 
-In some cases, the tool can then check the generated code by running tests and comparing test outputs.
+AIRefactor can apply the refactoring to multiple files, allowing batch processing.
 
 #### Notes
 
 AI Refactor is an experimental tool and under active development as I explore the idea myself. It may not work as expected, or change in ways that break existing functionality.
 
-The focus of the tool is work with the Ruby programming language ecosystem, but it can be used with any language.
+The focus of the tool is to work with the **Ruby programming language ecosystem**, but it can be used with any language.
 
-AI Refactor currently uses [OpenAI's ChatGPT](https://platform.openai.com/).
+AI Refactor currently uses [OpenAI's ChatGPT](https://platform.openai.com/) or [Anthropic Claude](https://docs.anthropic.com/en/docs/about-claude/models) to generate code.
 
 ## Examples
 
@@ -64,7 +69,9 @@ Use a pre-built prompt:
 
 ### User supplied prompts, eg `custom`, `ruby/write_ruby` and `ruby/refactor_ruby`
 
-Applies the refactor specified by prompting the AI with the user supplied prompt. You must supply a prompt file with the `-p` option.
+You can use these commands in conjunction with a user-supplied prompt.
+
+You must supply a prompt file with the `-p` option.
 
 The output is written to `stdout`, or to a file with the `--output` option.
 
@@ -178,7 +185,7 @@ output_file_path: output file or directory
 output_template_path: output file template (see docs)
 prompt_file_path: path
 prompt: |
-  A custom prompt to send to ChatGPT if the command needs it (otherwise read from file)
+  A custom prompt to send to the AI if the command needs it (otherwise read from file)
 context_file_paths:
   - file1.rb
   - file2.rb
@@ -194,10 +201,10 @@ context_text: |
   Some extra info to prepend to the prompt
 diff: true/false (default false)
 ai_max_attempts: max times to generate more if AI does not complete generating (default 3)
-ai_model: ChatGPT model name (default gpt-4-turbo)
-ai_temperature: ChatGPT temperature (default 0.7)
-ai_max_tokens: ChatGPT max tokens (default 1500)
-ai_timeout: ChatGPT timeout (default 60)
+ai_model: AI model name, OpenAI GPT or Anthropic Claude (default gpt-4-turbo)
+ai_temperature: AI temperature (default 0.7)
+ai_max_tokens: AI max tokens (default 1500)
+ai_timeout: AI timeout (default 60)
 overwrite: y/n/a (default a)
 verbose: true/false (default false)
 debug: true/false (default false)
@@ -261,12 +268,6 @@ This file provides default CLI switches to add to any `ai_refactor` command.
 
 The tool keeps a history of commands run in the `.ai_refactor_history` file in the current working directory.
 
-## Note on performance and ChatGPT version
-
-_The quality of results depend very much on the version of ChatGPT being used._
-
-I have tested with both 3.5 and 4 and see **significantly** better performance with version 4.
-
 ## Development
 
 After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake test` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
diff --git a/ai_refactor.gemspec b/ai_refactor.gemspec
index 141b2da..bdec014 100644
--- a/ai_refactor.gemspec
+++ b/ai_refactor.gemspec
@@ -12,7 +12,7 @@ Gem::Specification.new do |spec|
   spec.description = "Use OpenAI's ChatGPT to automate converting Rails RSpec tests to minitest (ActiveSupport::TestCase)."
   spec.homepage = "https://github.com/stevegeek/ai_refactor"
   spec.license = "MIT"
-  spec.required_ruby_version = ">= 2.7.0"
+  spec.required_ruby_version = ">= 3.3.0"
 
   spec.metadata["homepage_uri"] = spec.homepage
   spec.metadata["source_code_uri"] = "https://github.com/stevegeek/ai_refactor"
@@ -32,5 +32,6 @@ Gem::Specification.new do |spec|
   spec.add_dependency "colorize", "< 2.0"
   spec.add_dependency "open3", "< 2.0"
   spec.add_dependency "ruby-openai", ">= 3.4.0", "< 6.0"
+  spec.add_dependency "anthropic", ">= 0.1.0", "< 1.0"
   spec.add_dependency "zeitwerk", "~> 2.6"
 end
diff --git a/examples/ex1_convert_a_rspec_test_to_minitest.yml b/examples/ex1_convert_a_rspec_test_to_minitest.yml
index bcaf06d..5f7539f 100644
--- a/examples/ex1_convert_a_rspec_test_to_minitest.yml
+++ b/examples/ex1_convert_a_rspec_test_to_minitest.yml
@@ -4,5 +4,5 @@ input_file_paths:
 # We need to add context here as otherwise to tell the AI to require our local test_helper.rb file so that we can run the tests after
 context_text: "In the output test use `require_relative '../test_helper'` to include 'test_helper'."
 # By default, ai_refactor runs "bundle exec rails test" but this isn't going to work here as we are not actually in a Rails app context in the examples
-minitest_run_command: ruby __FILE__
+minitest_run_command: bundle exec ruby __FILE__
 output_file_path: examples/outputs/ex1_input_test.rb
diff --git a/exe/ai_refactor b/exe/ai_refactor
index 5a29448..ed92cdd 100755
--- a/exe/ai_refactor
+++ b/exe/ai_refactor
@@ -3,6 +3,7 @@
 require "optparse"
 require "colorize"
 require "openai"
+require "anthropic"
 require "shellwords"
 
 require_relative "../lib/ai_refactor"
@@ -37,11 +38,11 @@ option_parser = OptionParser.new do |parser|
     run_config.context_text = c
   end
 
-  parser.on("-r", "--review-prompt", "Show the prompt that will be sent to ChatGPT but do not actually call ChatGPT or make changes to files.") do
+  parser.on("-r", "--review-prompt", "Show the prompt that will be sent to the AI but do not actually call the AI or make changes to files.") do
     run_config.review_prompt = true
   end
 
-  parser.on("-p", "--prompt PROMPT_FILE", String, "Specify path to a text file that contains the ChatGPT 'system' prompt.") do |f|
+  parser.on("-p", "--prompt PROMPT_FILE", String, "Specify path to a text file that contains the AI 'system' prompt.") do |f|
     run_config.prompt_file_path = f
   end
 
@@ -49,23 +50,23 @@ option_parser = OptionParser.new do |parser|
     run_config.diff = true
   end
 
-  parser.on("-C", "--continue [MAX_MESSAGES]", Integer, "If ChatGPT stops generating due to the maximum token count being reached, continue to generate more messages, until a stop condition or MAX_MESSAGES. MAX_MESSAGES defaults to 3") do |c|
+  parser.on("-C", "--continue [MAX_MESSAGES]", Integer, "If the AI stops generating due to the maximum token count being reached, continue to generate more messages, until a stop condition or MAX_MESSAGES. MAX_MESSAGES defaults to 3") do |c|
     run_config.ai_max_attempts = c
   end
 
-  parser.on("-m", "--model MODEL_NAME", String, "Specify a ChatGPT model to use (default gpt-4-turbo).") do |m|
+  parser.on("-m", "--model MODEL_NAME", String, "Specify an AI model to use (default 'gpt-4-turbo'). OpenAI and Anthropic models are supported (e.g. 'gpt-4o', 'claude-3-opus-20240229').") do |m|
     run_config.ai_model = m
   end
 
-  parser.on("--temperature TEMP", Float, "Specify the temperature parameter for ChatGPT (default 0.7).") do |p|
+  parser.on("--temperature TEMP", Float, "Specify the temperature parameter for generation (default 0.7).") do |p|
     run_config.ai_temperature = p
   end
 
-  parser.on("--max-tokens MAX_TOKENS", Integer, "Specify the max number of tokens of output ChatGPT can generate. Max will depend on the size of the prompt (default 1500)") do |m|
+  parser.on("--max-tokens MAX_TOKENS", Integer, "Specify the max number of tokens of output the AI can generate. Max will depend on the size of the prompt (default 1500)") do |m|
     run_config.ai_max_tokens = m
   end
 
-  parser.on("-t", "--timeout SECONDS", Integer, "Specify the max wait time for ChatGPT response.") do |m|
+  parser.on("-t", "--timeout SECONDS", Integer, "Specify the max wait time for an AI response.") do |m|
     run_config.ai_timeout = m
   end
diff --git a/lib/ai_refactor.rb b/lib/ai_refactor.rb
index 61305cd..94aaa6c 100644
--- a/lib/ai_refactor.rb
+++ b/lib/ai_refactor.rb
@@ -4,6 +4,7 @@
 loader = Zeitwerk::Loader.for_gem
 loader.inflector.inflect(
   "ai_refactor" => "AIRefactor",
+  "ai_client" => "AIClient",
   "rspec_runner" => "RSpecRunner"
 )
 loader.setup # ready!
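
The next file is the heart of the change: `AIRefactor::AIClient`, a small wrapper that hides the differences between the two vendor SDKs. As a reading aid, here is a minimal usage sketch (not part of the patch, and hedged accordingly) assuming the relevant API key environment variable is set. Messages use the OpenAI-style chat schema throughout; for Anthropic, the client itself extracts the `system` message and passes it as the SDK's top-level `system` parameter.

```ruby
require "openai"
require "anthropic" # the executable requires both SDKs itself
require "ai_refactor"

# Hypothetical standalone use of the wrapper introduced below.
client = AIRefactor::AIClient.new(
  platform: "anthropic",           # "openai" is the default
  model: "claude-3-opus-20240229",
  max_tokens: 1024                 # temperature and timeout keep their defaults
)

messages = [
  {role: "system", content: "You are a refactoring assistant."},
  {role: "user", content: "Convert this RSpec test to minitest: ..."}
]

client.generate!(messages) do |finished_reason, content, _response|
  # Anthropic's "max_tokens" stop reason is normalised to OpenAI's "length",
  # so callers can handle truncation uniformly across platforms.
  warn "output was truncated" if finished_reason == "length"
  puts content
end
```
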
diff --git a/lib/ai_refactor/ai_client.rb b/lib/ai_refactor/ai_client.rb new file mode 100644 index 0000000..6a54251 --- /dev/null +++ b/lib/ai_refactor/ai_client.rb @@ -0,0 +1,86 @@ +# frozen_string_literal: true + +module AIRefactor + class AIClient + def initialize(platform: "openai", model: "gpt-4-turbo", temperature: 0.7, max_tokens: 1500, timeout: 60, verbose: false) + @platform = platform + @model = model + @temperature = temperature + @max_tokens = max_tokens + @timeout = timeout + @verbose = verbose + @client = configure + end + + def generate!(messages) + finished_reason, content, response = case @platform + when "openai" + openai_parse_response( + @client.chat( + parameters: { + messages: messages, + model: @model, + temperature: @temperature, + max_tokens: @max_tokens + } + ) + ) + when "anthropic" + anthropic_parse_response( + @client.messages( + parameters: { + system: messages.find { |m| m[:role] == "system" }&.fetch(:content, nil), + messages: messages.select { |m| m[:role] != "system" }, + model: @model, + max_tokens: @max_tokens + } + ) + ) + else + raise "Invalid platform: #{@platform}" + end + yield finished_reason, content, response + end + + private + + def configure + case @platform + when "openai" + ::OpenAI::Client.new( + access_token: ENV.fetch("OPENAI_API_KEY"), + organization_id: ENV.fetch("OPENAI_ORGANIZATION_ID", nil), + request_timeout: @timeout, + log_errors: @verbose + ) + when "anthropic" + ::Anthropic::Client.new( + access_token: ENV.fetch("ANTHROPIC_API_KEY"), + request_timeout: @timeout + ) + else + raise "Invalid platform: #{@platform}" + end + end + + def openai_parse_response(response) + if response["error"] + raise StandardError.new("OpenAI error: #{response["error"]["type"]}: #{response["error"]["message"]} (#{response["error"]["code"]})") + end + + content = response.dig("choices", 0, "message", "content") + finished_reason = response.dig("choices", 0, "finish_reason") + [finished_reason, content, response] + end + + def anthropic_parse_response(response) + if response["error"] + raise StandardError.new("Anthropic error: #{response["error"]["type"]}: #{response["error"]["message"]}") + end + + content = response.dig("content", 0, "text") + finished_reason = response["stop_reason"] == "max_tokens" ? "length" : response["stop_reason"] + [finished_reason, content, response] + end + end +end diff --git a/lib/ai_refactor/cli.rb b/lib/ai_refactor/cli.rb index 162b679..0e3699d 100644 --- a/lib/ai_refactor/cli.rb +++ b/lib/ai_refactor/cli.rb @@ -63,6 +63,17 @@ def inputs configuration.input_file_paths end + def ai_client + @ai_client ||= AIRefactor::AIClient.new( + platform: configuration.ai_platform, + model: configuration.ai_model, + temperature: configuration.ai_temperature, + max_tokens: configuration.ai_max_tokens, + timeout: configuration.ai_timeout, + verbose: configuration.verbose + ) + end + def valid? return false unless refactorer inputs_valid = refactorer.takes_input_files? ? !(inputs.nil? || inputs.empty?) : true @@ -72,12 +83,6 @@ def valid? def run return false unless valid? - OpenAI.configure do |config| - config.access_token = ENV.fetch("OPENAI_API_KEY") - config.organization_id = ENV.fetch("OPENAI_ORGANIZATION_ID", nil) - config.request_timeout = configuration.ai_timeout || 240 - end - if refactorer.takes_input_files? expanded_inputs = inputs.map do |path| File.exist?(path) ? path : Dir.glob(path) @@ -92,7 +97,7 @@ def run return_values = expanded_inputs.map do |file| logger.info "Processing #{file}..." 
- refactor = refactorer.new(file, configuration, logger) + refactor = refactorer.new(ai_client, file, configuration, logger) refactor_returned = refactor.run failed = refactor_returned == false if failed @@ -118,7 +123,7 @@ def run name = refactorer.refactor_name logger.info "AI Refactor - #{name} refactor\n" logger.info "====================\n" - refactor = refactorer.new(nil, configuration, logger) + refactor = refactorer.new(ai_client, nil, configuration, logger) refactor_returned = refactor.run failed = refactor_returned == false if failed diff --git a/lib/ai_refactor/file_processor.rb b/lib/ai_refactor/file_processor.rb index cc496df..2694c8f 100644 --- a/lib/ai_refactor/file_processor.rb +++ b/lib/ai_refactor/file_processor.rb @@ -60,35 +60,21 @@ def generate_next_message(messages, options, attempts_left) logger.debug "Options: #{options.inspect}" logger.debug "Messages: #{messages.inspect}" - response = @ai_client.chat( - parameters: { - model: options[:ai_model] || "gpt-4-turbo", - messages: messages, - temperature: options[:ai_temperature] || 0.7, - max_tokens: options[:ai_max_tokens] || 1500 - } - ) - - if response["error"] - raise StandardError.new("OpenAI error: #{response["error"]["type"]}: #{response["error"]["message"]} (#{response["error"]["code"]})") - end - - content = response.dig("choices", 0, "message", "content") - finished_reason = response.dig("choices", 0, "finish_reason") - - if finished_reason == "length" && attempts_left > 0 - generate_next_message(messages + [ - {role: "assistant", content: content}, - {role: "user", content: "Continue"} - ], options, attempts_left - 1) - else - previous_messages = messages.filter { |m| m[:role] == "assistant" }.map { |m| m[:content] }.join - content = if previous_messages.length > 0 - content ? previous_messages + content : previous_messages + @ai_client.generate!(messages) do |finished_reason, content, response| + if finished_reason == "length" && attempts_left > 0 + generate_next_message(messages + [ + {role: "assistant", content: content}, + {role: "user", content: "Continue"} + ], options, attempts_left - 1) else - content + previous_messages = messages.filter { |m| m[:role] == "assistant" }.map { |m| m[:content] }.join + content = if previous_messages.length > 0 + content ? previous_messages + content : previous_messages + else + content + end + [content, finished_reason, response["usage"]] end - [content, finished_reason, response["usage"]] end end diff --git a/lib/ai_refactor/refactors/base_refactor.rb b/lib/ai_refactor/refactors/base_refactor.rb index 3fbc9f2..efdfb21 100644 --- a/lib/ai_refactor/refactors/base_refactor.rb +++ b/lib/ai_refactor/refactors/base_refactor.rb @@ -17,11 +17,12 @@ def self.takes_input_files? true end - attr_reader :input_file, :options, :logger + attr_reader :ai_client, :input_file, :options, :logger attr_accessor :input_content attr_writer :failed_message - def initialize(input_file, options, logger) + def initialize(ai_client, input_file, options, logger) + @ai_client = ai_client @input_file = input_file @options = options @logger = logger @@ -79,8 +80,11 @@ def process!(strip_ticks: true) output_content rescue => e logger.error "Request to AI failed: #{e.message}" + if e.respond_to?(:response) && e.response + logger.error "Response: #{e.response[:body]}" + end logger.warn "Skipping #{input_file}..." 
- self.failed_message = "Request to OpenAI failed" + self.failed_message = "Request to AI API failed" raise e end end @@ -175,10 +179,6 @@ def output_file_path_from_template path end - def ai_client - @ai_client ||= OpenAI::Client.new - end - def refactor_name self.class.refactor_name end diff --git a/lib/ai_refactor/run_configuration.rb b/lib/ai_refactor/run_configuration.rb index 7ef7bd7..d5659be 100644 --- a/lib/ai_refactor/run_configuration.rb +++ b/lib/ai_refactor/run_configuration.rb @@ -18,11 +18,6 @@ def self.add_new_option(key) :review_prompt, :prompt, :prompt_file_path, - :ai_max_attempts, - :ai_model, - :ai_temperature, - :ai_max_tokens, - :ai_timeout, :overwrite, :diff, :verbose, @@ -97,30 +92,54 @@ def minitest_run_command attr_writer :rspec_run_command attr_writer :minitest_run_command + def ai_max_attempts + @ai_max_attempts || 3 + end + def ai_max_attempts=(value) - @ai_max_attempts = value || 3 + @ai_max_attempts = value + end + + def ai_model + @ai_model || "gpt-4-turbo" end def ai_model=(value) - @ai_model = value || "gpt-4-turbo" + @ai_model = value end - def ai_temperature=(value) - @ai_temperature = value || 0.7 + def ai_platform + if ai_model&.start_with?("claude") + "anthropic" + else + "openai" + end end - def ai_max_tokens=(value) - @ai_max_tokens = value || 1500 + def ai_temperature + @ai_temperature || 0.7 end - def ai_timeout=(value) - @ai_timeout = value || 60 + attr_writer :ai_temperature + + def ai_max_tokens + @ai_max_tokens || 1500 end - def overwrite=(value) - @overwrite = value || "a" + attr_writer :ai_max_tokens + + def ai_timeout + @ai_timeout || 60 end + attr_writer :ai_timeout + + def overwrite + @overwrite || "a" + end + + attr_writer :overwrite + attr_writer :diff attr_writer :verbose diff --git a/test/lib/ai_refactor/context_test.rb b/test/lib/ai_refactor/context_test.rb index ba71c35..3b6b9df 100644 --- a/test/lib/ai_refactor/context_test.rb +++ b/test/lib/ai_refactor/context_test.rb @@ -33,7 +33,7 @@ def test_prepare_context_with_mixed_files @context = Context.new(files: ["file1"], text: "Hi!", logger: @logger) File.stub :exist?, lambda { |file| file == "file1" } do File.stub :read, "content" do - expected_output = "Also note: Hi!\n\nHere is some related files:\n\n#---\n# File 'file1':\n\n```content```\n" + expected_output = "\nHere is some related files:\n\n#---\n# File 'file1':\n\n```content```\n\n\nHi!\n" assert_equal expected_output, @context.prepare_context @logger.verify end
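
A closing note on the `RunConfiguration` changes above: defaults have moved from the writers into the readers, so a `nil` arriving from the CLI or a config file no longer clobbers a default, and the platform is now inferred from the model name. A rough illustration of the intended behaviour (assuming the class can be instantiated directly with no arguments, which the patch itself does not show):

```ruby
config = AIRefactor::RunConfiguration.new

config.ai_model       # => "gpt-4-turbo" (default supplied by the reader)
config.ai_platform    # => "openai"

# Writers now store whatever they are given, including nil...
config.ai_temperature = nil
config.ai_temperature # => 0.7 (...because the reader falls back to the default)

# Any model name starting with "claude" routes requests to Anthropic
config.ai_model = "claude-3-opus-20240229"
config.ai_platform    # => "anthropic"
```
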