Skip to content

Commit e77a442

Browse files
committed
♻️ Refactor JSON parsing with JsonSchema implementation
Remove String conversion implementations for response types This refactoring improves type safety and validation by: - Adding JsonSchema trait to response types for better validation - Removing custom From<String> implementations that were error-prone - Enhancing JSON parsing with provider-specific handling - Adding fallback extraction methods for malformed responses - Creating specialized parsing functions for different provider outputs The changes affect core LLM interaction code and response types while maintaining the same functionality with more robust error handling.
1 parent 08debd3 commit e77a442

File tree

4 files changed

+93
-71
lines changed

4 files changed

+93
-71
lines changed

src/changes/common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::config::Config;
66
use crate::git::GitRepo;
77
use crate::llm;
88
use anyhow::{Context, Result};
9+
use schemars::JsonSchema;
910
use serde::Serialize;
1011
use serde::de::DeserializeOwned;
1112
use std::fmt::Debug;
@@ -24,8 +25,7 @@ pub async fn generate_changes_content<T>(
2425
create_user_prompt: UserPromptFn,
2526
) -> Result<T>
2627
where
27-
T: DeserializeOwned + Serialize + Debug,
28-
String: Into<T>,
28+
T: DeserializeOwned + Serialize + Debug + JsonSchema,
2929
{
3030
// Create ChangeAnalyzer with Arc<GitRepo>
3131
let analyzer = ChangeAnalyzer::new(git_repo.clone())?;

src/commit/review.rs

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -248,37 +248,6 @@ pub struct GeneratedReview {
248248
pub best_practices: Option<DimensionAnalysis>,
249249
}
250250

251-
impl From<String> for GeneratedReview {
252-
fn from(s: String) -> Self {
253-
match serde_json::from_str(&s) {
254-
Ok(review) => review,
255-
Err(e) => {
256-
crate::log_debug!("Failed to parse review JSON: {}", e);
257-
crate::log_debug!("Input was: {}", s);
258-
Self {
259-
summary: "Error parsing code review".to_string(),
260-
code_quality: "There was an error parsing the code review from the AI."
261-
.to_string(),
262-
suggestions: vec!["Please try again.".to_string()],
263-
issues: vec![],
264-
positive_aspects: vec![],
265-
complexity: None,
266-
abstraction: None,
267-
deletion: None,
268-
hallucination: None,
269-
style: None,
270-
security: None,
271-
performance: None,
272-
duplication: None,
273-
error_handling: None,
274-
testing: None,
275-
best_practices: None,
276-
}
277-
}
278-
}
279-
}
280-
}
281-
282251
impl GeneratedReview {
283252
/// Formats a location string to ensure it includes file reference when possible
284253
///

src/commit/types.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,6 @@ pub struct GeneratedMessage {
1313
pub message: String,
1414
}
1515

16-
impl From<String> for GeneratedMessage {
17-
fn from(s: String) -> Self {
18-
match serde_json::from_str(&s) {
19-
Ok(message) => message,
20-
Err(e) => {
21-
eprintln!("Failed to parse JSON: {e}\nInput was: {s}");
22-
Self {
23-
emoji: None,
24-
title: "Error parsing commit message".to_string(),
25-
message: "There was an error parsing the commit message from the AI. Please try again.".to_string(),
26-
}
27-
}
28-
}
29-
}
30-
}
31-
3216
/// Formats a commit message from a `GeneratedMessage`
3317
pub fn format_commit_message(response: &GeneratedMessage) -> String {
3418
let mut message = String::new();

src/llm.rs

Lines changed: 91 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use llm::{
66
builder::{LLMBackend, LLMBuilder},
77
chat::ChatMessage,
88
};
9-
use serde::Serialize;
9+
use schemars::JsonSchema;
1010
use serde::de::DeserializeOwned;
1111
use std::collections::HashMap;
1212
use std::str::FromStr;
@@ -22,8 +22,7 @@ pub async fn get_message<T>(
2222
user_prompt: &str,
2323
) -> Result<T>
2424
where
25-
T: Serialize + DeserializeOwned + std::fmt::Debug,
26-
String: Into<T>,
25+
T: DeserializeOwned + JsonSchema,
2726
{
2827
log_debug!("Generating message using provider: {}", provider_name);
2928
log_debug!("System prompt: {}", system_prompt);
@@ -83,19 +82,17 @@ where
8382
.map_err(|e| anyhow!("Failed to build provider: {}", e))?;
8483

8584
// Generate the message
86-
let result = get_message_with_provider::<T>(provider, user_prompt).await?;
87-
88-
Ok(result)
85+
get_message_with_provider(provider, user_prompt, provider_name).await
8986
}
9087

9188
/// Generates a message using the given provider (mainly for testing purposes)
9289
pub async fn get_message_with_provider<T>(
9390
provider: Box<dyn LLMProvider + Send + Sync>,
9491
user_prompt: &str,
92+
provider_type: &str,
9593
) -> Result<T>
9694
where
97-
T: Serialize + DeserializeOwned + std::fmt::Debug,
98-
String: Into<T>,
95+
T: DeserializeOwned + JsonSchema,
9996
{
10097
log_debug!("Entering get_message_with_provider");
10198

@@ -104,26 +101,61 @@ where
104101
let result = Retry::spawn(retry_strategy, || async {
105102
log_debug!("Attempting to generate message");
106103

104+
// Enhanced prompt that requests specifically formatted JSON output
105+
let enhanced_prompt = if std::any::type_name::<T>() != std::any::type_name::<String>() {
106+
format!("{}\n\nPlease respond with a valid JSON object and nothing else. No explanations or text outside the JSON.", user_prompt)
107+
} else {
108+
user_prompt.to_string()
109+
};
110+
107111
// Create chat message with user prompt
108-
let messages = vec![ChatMessage::user().content(user_prompt.to_string()).build()];
112+
let mut messages = vec![ChatMessage::user().content(enhanced_prompt).build()];
113+
114+
// Special handling for Anthropic - use the "prefill" technique with "{"
115+
if provider_type.to_lowercase() == "anthropic" && std::any::type_name::<T>() != std::any::type_name::<String>() {
116+
messages.push(ChatMessage::assistant().content("Here is the JSON:\n{").build());
117+
}
109118

110119
match tokio::time::timeout(Duration::from_secs(30), provider.chat(&messages)).await {
111120
Ok(Ok(response)) => {
112121
log_debug!("Received response from provider");
113122
let response_text = response.text().unwrap_or_default();
114-
let cleaned_message = clean_json_from_llm(&response_text);
115-
116-
if std::any::type_name::<T>() == std::any::type_name::<String>() {
117-
// If T is String, return the raw string response
118-
Ok(cleaned_message.into())
119-
} else {
120-
// Attempt to deserialize the response
121-
match serde_json::from_str::<T>(&cleaned_message) {
122-
Ok(message) => Ok(message),
123-
Err(e) => {
124-
log_debug!("Deserialization error: {} message: {}", e, cleaned_message);
125-
Err(anyhow!("Deserialization error: {}", e))
123+
124+
// Provider-specific response parsing
125+
let result = match provider_type.to_lowercase().as_str() {
126+
// For Anthropic with brace prefixing
127+
"anthropic" => {
128+
if std::any::type_name::<T>() == std::any::type_name::<String>() {
129+
// For String type, we need to handle differently
130+
#[allow(clippy::unnecessary_to_owned)]
131+
let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
132+
.map_err(|e| anyhow!("String conversion error: {}", e))?;
133+
Ok(string_result)
134+
} else {
135+
parse_json_response_with_brace_prefix::<T>(&response_text)
126136
}
137+
},
138+
139+
// For all other providers - use appropriate parsing
140+
_ => {
141+
if std::any::type_name::<T>() == std::any::type_name::<String>() {
142+
// For String type, we need to handle differently
143+
#[allow(clippy::unnecessary_to_owned)]
144+
let string_result: T = serde_json::from_value(serde_json::Value::String(response_text.clone()))
145+
.map_err(|e| anyhow!("String conversion error: {}", e))?;
146+
Ok(string_result)
147+
} else {
148+
// First try direct parsing, then fall back to extraction
149+
parse_json_response::<T>(&response_text)
150+
}
151+
}
152+
};
153+
154+
match result {
155+
Ok(message) => Ok(message),
156+
Err(e) => {
157+
log_debug!("JSON parse error: {} text: {}", e, response_text);
158+
Err(anyhow!("JSON parse error: {}", e))
127159
}
128160
}
129161
}
@@ -141,7 +173,7 @@ where
141173

142174
match result {
143175
Ok(message) => {
144-
log_debug!("Deserialized message: {:?}", message);
176+
log_debug!("Generated message successfully");
145177
Ok(message)
146178
}
147179
Err(e) => {
@@ -151,6 +183,43 @@ where
151183
}
152184
}
153185

186+
/// Parse a provider's response that should be pure JSON
187+
fn parse_json_response<T: DeserializeOwned>(text: &str) -> Result<T> {
188+
match serde_json::from_str::<T>(text) {
189+
Ok(message) => Ok(message),
190+
Err(e) => {
191+
// Fallback to a more robust extraction if direct parsing fails
192+
log_debug!(
193+
"Direct JSON parse failed: {}. Attempting fallback extraction.",
194+
e
195+
);
196+
extract_and_parse_json(text)
197+
}
198+
}
199+
}
200+
201+
/// Parse a response from Anthropic that needs the prefixed "{"
202+
fn parse_json_response_with_brace_prefix<T: DeserializeOwned>(text: &str) -> Result<T> {
203+
// Add the opening brace that we prefilled in the prompt
204+
let json_text = format!("{{{}", text);
205+
match serde_json::from_str::<T>(&json_text) {
206+
Ok(message) => Ok(message),
207+
Err(e) => {
208+
log_debug!(
209+
"Brace-prefixed JSON parse failed: {}. Attempting fallback extraction.",
210+
e
211+
);
212+
extract_and_parse_json(text)
213+
}
214+
}
215+
}
216+
217+
/// Extracts and parses JSON from a potentially non-JSON response
218+
fn extract_and_parse_json<T: DeserializeOwned>(text: &str) -> Result<T> {
219+
let cleaned_json = clean_json_from_llm(text);
220+
serde_json::from_str(&cleaned_json).map_err(|e| anyhow!("JSON parse error: {}", e))
221+
}
222+
154223
/// Returns a list of available LLM providers as strings
155224
pub fn get_available_provider_names() -> Vec<String> {
156225
vec![

0 commit comments

Comments
 (0)