Skip to content

Commit 50259fb

Browse files
committed
Adding xAI provider for grok models
1 parent d8fb8f0 commit 50259fb

File tree

2 files changed

+371
-0
lines changed

2 files changed

+371
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# frozen_string_literal: true

begin
  gem "ruby-openai", ">= 8.1.0"
  require "openai"
rescue LoadError
  raise LoadError, "The 'ruby-openai >= 8.1.0' gem is required for XAIProvider. Please add it to your Gemfile and run `bundle install`."
end

require "active_agent/action_prompt/action"
require_relative "base"
require_relative "response"
require_relative "stream_processing"
require_relative "message_formatting"
require_relative "tool_management"

module ActiveAgent
  module GenerationProvider
    # XAI (Grok) Generation Provider.
    #
    # xAI exposes an OpenAI-compatible REST API, so this provider drives the
    # ruby-openai client against xAI's endpoint instead of OpenAI's.
    class XAIProvider < Base
      include StreamProcessing
      include MessageFormatting
      include ToolManagement

      # Default API endpoint; override per-instance via config["host"].
      XAI_API_HOST = "https://api.x.ai"

      # @param config [Hash] provider configuration (string keys):
      #   "api_key"/"access_token" - credential (falls back to the XAI_API_KEY
      #     or GROK_API_KEY environment variables)
      #   "host"   - optional custom API endpoint
      #   "model"  - optional model name (defaults to "grok-2-latest")
      #   "stream" - optional flag enabling streaming responses
      # @raise [ArgumentError] when no API key can be resolved
      def initialize(config)
        super
        # Support both api_key and access_token for backwards compatibility
        @access_token = config["api_key"] || config["access_token"] || ENV["XAI_API_KEY"] || ENV["GROK_API_KEY"]

        unless @access_token
          raise ArgumentError, "XAI API key is required. Set it in config as 'api_key', 'access_token', or via XAI_API_KEY/GROK_API_KEY environment variable."
        end

        # xAI uses OpenAI-compatible client with custom endpoint.
        # Guard the Rails constant so the provider also loads outside a Rails
        # process (the original unconditional Rails.env raised NameError there).
        @client = OpenAI::Client.new(
          access_token: @access_token,
          uri_base: config["host"] || XAI_API_HOST,
          log_errors: defined?(Rails) && Rails.env.development?
        )

        # Default to grok-2-latest but allow configuration
        @model_name = config["model"] || "grok-2-latest"
      end

      # Runs a chat completion for the given prompt and returns the built
      # Response (via chat_response).
      def generate(prompt)
        @prompt = prompt

        with_error_handling do
          chat_prompt(parameters: prompt_parameters)
        end
      end

      # xAI doesn't currently provide embedding models.
      # @raise [NotImplementedError] always
      def embed(prompt)
        raise NotImplementedError, "xAI does not currently support embeddings. Use a different provider for embedding tasks."
      end

      protected

      # Override from StreamProcessing module - consumes OpenAI-format SSE
      # chunks. Appends delta content to +message+ and notifies +agent_stream+,
      # or promotes a tool-call delta into a new message on the prompt;
      # finalizes the stream when a finish_reason arrives.
      def process_stream_chunk(chunk, message, agent_stream)
        new_content = chunk.dig("choices", 0, "delta", "content")
        if new_content.present?
          message.generation_id = chunk.dig("id")
          message.content += new_content
          agent_stream&.call(message, new_content, false, prompt.action_name)
        elsif chunk.dig("choices", 0, "delta", "tool_calls") && chunk.dig("choices", 0, "delta", "role")
          message = handle_message(chunk.dig("choices", 0, "delta"))
          prompt.messages << message
          @response = ActiveAgent::GenerationProvider::Response.new(
            prompt:,
            message:,
            raw_response: chunk,
            raw_request: @streaming_request_params
          )
        end

        if chunk.dig("choices", 0, "finish_reason")
          finalize_stream(message, agent_stream)
        end
      end

      # Override from MessageFormatting module to handle image format (if xAI adds vision support)
      def format_image_content(message)
        [ {
          type: "image_url",
          image_url: { url: message.content }
        } ]
      end

      private

      # Override from ParameterBuilder to add xAI-specific parameters if needed.
      # For now, xAI follows OpenAI's format closely, so there are none.
      def build_provider_parameters
        {}
      end

      # Builds the provider Response from a non-streaming chat completion.
      # When streaming was used, the response has already been assembled by
      # process_stream_chunk, so it is returned as-is.
      def chat_response(response, request_params = nil)
        return @response if prompt.options[:stream]

        message_json = response.dig("choices", 0, "message")
        # The message object may lack an id; fall back to the completion id.
        message_json["id"] = response.dig("id") if message_json["id"].blank?
        message = handle_message(message_json)

        update_context(prompt: prompt, message: message, response: response)

        @response = ActiveAgent::GenerationProvider::Response.new(
          prompt: prompt,
          message: message,
          raw_response: response,
          raw_request: request_params
        )
      end

      # Converts an OpenAI-format message hash into an ActionPrompt::Message,
      # carrying over any requested tool calls.
      def handle_message(message_json)
        ActiveAgent::ActionPrompt::Message.new(
          generation_id: message_json["id"],
          content: message_json["content"],
          role: message_json["role"].intern,
          action_requested: message_json["finish_reason"] == "tool_calls",
          raw_actions: message_json["tool_calls"] || [],
          requested_actions: handle_actions(message_json["tool_calls"]),
          content_type: prompt.output_schema.present? ? "application/json" : "text/plain"
        )
      end

      # Issues the chat request, wiring up the streaming callback when the
      # prompt or provider config asks for it.
      def chat_prompt(parameters: prompt_parameters)
        if prompt.options[:stream] || config["stream"]
          parameters[:stream] = provider_stream
          @streaming_request_params = parameters
        end
        chat_response(@client.chat(parameters: parameters), parameters)
      end
    end
  end
end
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
require "test_helper"
require "active_agent/generation_provider/x_ai_provider"

# Tests for xAI provider gem loading and configuration.
#
# NOTE: several tests mutate process-global state (ENV, ActiveAgent config);
# each of those restores it in an `ensure` so a failing assertion cannot leak
# state into subsequent tests.
class XAIProviderTest < ActiveAgentTestCase
  # Test the gem load rescue block
  test "gem load rescue block provides correct error message" do
    expected_message = "The 'ruby-openai >= 8.1.0' gem is required for XAIProvider. Please add it to your Gemfile and run `bundle install`."

    # Verify the rescue block pattern exists in the source code
    provider_file_path = File.expand_path("../../lib/active_agent/generation_provider/x_ai_provider.rb", Rails.root)
    provider_source = File.read(provider_file_path)

    assert_includes provider_source, "begin"
    assert_includes provider_source, 'gem "ruby-openai"'
    assert_includes provider_source, 'require "openai"'
    assert_includes provider_source, "rescue LoadError"
    assert_includes provider_source, expected_message

    # Reproduce the rescue pattern inline (no eval needed): a missing gem must
    # surface as a LoadError carrying the friendly message.
    error = assert_raises(LoadError) do
      begin
        gem "nonexistent-openai-gem"
        require "nonexistent-openai-gem"
      rescue LoadError
        raise LoadError, expected_message
      end
    end

    assert_equal expected_message, error.message
  end

  test "loads successfully when ruby-openai gem is available" do
    # The gem is present in the test environment, so the provider class should
    # be defined and instantiable with a valid config.
    assert defined?(ActiveAgent::GenerationProvider::XAIProvider)

    config = {
      "service" => "XAI",
      "api_key" => "test-key",
      "model" => "grok-2-latest"
    }

    assert_nothing_raised do
      ActiveAgent::GenerationProvider::XAIProvider.new(config)
    end
  end

  # Test configuration loading and presence
  test "raises error when xAI API key is missing" do
    # Clear the env fallbacks first; otherwise this test fails on any machine
    # where XAI_API_KEY/GROK_API_KEY happen to be exported.
    original_xai_key = ENV.delete("XAI_API_KEY")
    original_grok_key = ENV.delete("GROK_API_KEY")

    config = {
      "service" => "XAI",
      "model" => "grok-2-latest"
      # Missing api_key
    }

    error = assert_raises(ArgumentError) do
      ActiveAgent::GenerationProvider::XAIProvider.new(config)
    end

    assert_includes error.message, "XAI API key is required"
  ensure
    ENV["XAI_API_KEY"] = original_xai_key
    ENV["GROK_API_KEY"] = original_grok_key
  end

  test "loads configuration from active_agent.yml when present" do
    mock_config = {
      "test" => {
        "xai" => {
          "service" => "XAI",
          "api_key" => "test-api-key",
          "model" => "grok-2-latest",
          "temperature" => 0.7
        }
      }
    }

    ActiveAgent.instance_variable_set(:@config, mock_config)

    original_env = ENV["RAILS_ENV"]
    ENV["RAILS_ENV"] = "test"

    config = ApplicationAgent.configuration(:xai)

    assert_equal "XAI", config.config["service"]
    assert_equal "test-api-key", config.config["api_key"]
    assert_equal "grok-2-latest", config.config["model"]
    assert_equal 0.7, config.config["temperature"]
  ensure
    # Restore even when an assertion above fails.
    ENV["RAILS_ENV"] = original_env
  end

  test "loads configuration from environment-specific section" do
    mock_config = {
      "development" => {
        "xai" => {
          "service" => "XAI",
          "api_key" => "dev-api-key",
          "model" => "grok-2-latest"
        }
      },
      "test" => {
        "xai" => {
          "service" => "XAI",
          "api_key" => "test-api-key",
          "model" => "grok-2-latest"
        }
      }
    }

    ActiveAgent.instance_variable_set(:@config, mock_config)

    original_env = ENV["RAILS_ENV"]

    # Development section
    ENV["RAILS_ENV"] = "development"
    config = ApplicationAgent.configuration(:xai)
    assert_equal "dev-api-key", config.config["api_key"]

    # Test section
    ENV["RAILS_ENV"] = "test"
    config = ApplicationAgent.configuration(:xai)
    assert_equal "test-api-key", config.config["api_key"]
  ensure
    ENV["RAILS_ENV"] = original_env
  end

  test "xAI provider initialization with API key from environment variable" do
    original_xai_key = ENV["XAI_API_KEY"]
    original_grok_key = ENV["GROK_API_KEY"]

    config = {
      "service" => "XAI",
      "model" => "grok-2-latest"
    }

    # XAI_API_KEY takes precedence
    ENV["XAI_API_KEY"] = "env-xai-key"
    ENV["GROK_API_KEY"] = nil # assigning nil deletes the entry
    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    assert_equal "env-xai-key", provider.instance_variable_get(:@access_token)

    # GROK_API_KEY is the fallback
    ENV["XAI_API_KEY"] = nil
    ENV["GROK_API_KEY"] = "env-grok-key"
    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    assert_equal "env-grok-key", provider.instance_variable_get(:@access_token)
  ensure
    ENV["XAI_API_KEY"] = original_xai_key
    ENV["GROK_API_KEY"] = original_grok_key
  end

  test "xAI provider initialization with custom host" do
    config = {
      "service" => "XAI",
      "api_key" => "test-key",
      "model" => "grok-2-latest",
      "host" => "https://custom-xai-host.com"
    }

    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    client = provider.instance_variable_get(:@client)

    # The OpenAI client should be configured with the custom host
    assert_not_nil client
  end

  test "xAI provider defaults to grok-2-latest model" do
    config = {
      "service" => "XAI",
      "api_key" => "test-key"
      # Model not specified
    }

    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    assert_equal "grok-2-latest", provider.instance_variable_get(:@model_name)
  end

  test "xAI provider uses configured model" do
    config = {
      "service" => "XAI",
      "api_key" => "test-key",
      "model" => "grok-1"
    }

    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    assert_equal "grok-1", provider.instance_variable_get(:@model_name)
  end

  test "xAI provider defaults to correct API host" do
    config = {
      "service" => "XAI",
      "api_key" => "test-key"
    }

    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    assert_not_nil provider
    assert_equal "https://api.x.ai", ActiveAgent::GenerationProvider::XAIProvider::XAI_API_HOST
  end

  test "embed method raises NotImplementedError for xAI" do
    config = {
      "service" => "XAI",
      "api_key" => "test-key"
    }

    provider = ActiveAgent::GenerationProvider::XAIProvider.new(config)
    mock_prompt = Minitest::Mock.new

    error = assert_raises(NotImplementedError) do
      provider.embed(mock_prompt)
    end

    assert_includes error.message, "xAI does not currently support embeddings"
  end
end

0 commit comments

Comments
 (0)