@@ -6,7 +6,7 @@ use llm::{
66 builder:: { LLMBackend , LLMBuilder } ,
77 chat:: ChatMessage ,
88} ;
9- use serde :: Serialize ;
9+ use schemars :: JsonSchema ;
1010use serde:: de:: DeserializeOwned ;
1111use std:: collections:: HashMap ;
1212use std:: str:: FromStr ;
@@ -22,8 +22,7 @@ pub async fn get_message<T>(
2222 user_prompt : & str ,
2323) -> Result < T >
2424where
25- T : Serialize + DeserializeOwned + std:: fmt:: Debug ,
26- String : Into < T > ,
25+ T : DeserializeOwned + JsonSchema ,
2726{
2827 log_debug ! ( "Generating message using provider: {}" , provider_name) ;
2928 log_debug ! ( "System prompt: {}" , system_prompt) ;
@@ -83,19 +82,17 @@ where
8382 . map_err ( |e| anyhow ! ( "Failed to build provider: {}" , e) ) ?;
8483
8584 // Generate the message
86- let result = get_message_with_provider :: < T > ( provider, user_prompt) . await ?;
87-
88- Ok ( result)
85+ get_message_with_provider ( provider, user_prompt, provider_name) . await
8986}
9087
9188/// Generates a message using the given provider (mainly for testing purposes)
9289pub async fn get_message_with_provider < T > (
9390 provider : Box < dyn LLMProvider + Send + Sync > ,
9491 user_prompt : & str ,
92+ provider_type : & str ,
9593) -> Result < T >
9694where
97- T : Serialize + DeserializeOwned + std:: fmt:: Debug ,
98- String : Into < T > ,
95+ T : DeserializeOwned + JsonSchema ,
9996{
10097 log_debug ! ( "Entering get_message_with_provider" ) ;
10198
@@ -104,26 +101,61 @@ where
104101 let result = Retry :: spawn ( retry_strategy, || async {
105102 log_debug ! ( "Attempting to generate message" ) ;
106103
104+ // Enhanced prompt that requests specifically formatted JSON output
105+ let enhanced_prompt = if std:: any:: type_name :: < T > ( ) != std:: any:: type_name :: < String > ( ) {
106+ format ! ( "{}\n \n Please respond with a valid JSON object and nothing else. No explanations or text outside the JSON." , user_prompt)
107+ } else {
108+ user_prompt. to_string ( )
109+ } ;
110+
107111 // Create chat message with user prompt
108- let messages = vec ! [ ChatMessage :: user( ) . content( user_prompt. to_string( ) ) . build( ) ] ;
112+ let mut messages = vec ! [ ChatMessage :: user( ) . content( enhanced_prompt) . build( ) ] ;
113+
114+ // Special handling for Anthropic - use the "prefill" technique with "{"
115+ if provider_type. to_lowercase ( ) == "anthropic" && std:: any:: type_name :: < T > ( ) != std:: any:: type_name :: < String > ( ) {
116+ messages. push ( ChatMessage :: assistant ( ) . content ( "Here is the JSON:\n {" ) . build ( ) ) ;
117+ }
109118
110119 match tokio:: time:: timeout ( Duration :: from_secs ( 30 ) , provider. chat ( & messages) ) . await {
111120 Ok ( Ok ( response) ) => {
112121 log_debug ! ( "Received response from provider" ) ;
113122 let response_text = response. text ( ) . unwrap_or_default ( ) ;
114- let cleaned_message = clean_json_from_llm ( & response_text) ;
115-
116- if std:: any:: type_name :: < T > ( ) == std:: any:: type_name :: < String > ( ) {
117- // If T is String, return the raw string response
118- Ok ( cleaned_message. into ( ) )
119- } else {
120- // Attempt to deserialize the response
121- match serde_json:: from_str :: < T > ( & cleaned_message) {
122- Ok ( message) => Ok ( message) ,
123- Err ( e) => {
124- log_debug ! ( "Deserialization error: {} message: {}" , e, cleaned_message) ;
125- Err ( anyhow ! ( "Deserialization error: {}" , e) )
123+
124+ // Provider-specific response parsing
125+ let result = match provider_type. to_lowercase ( ) . as_str ( ) {
126+ // For Anthropic with brace prefixing
127+ "anthropic" => {
128+ if std:: any:: type_name :: < T > ( ) == std:: any:: type_name :: < String > ( ) {
129+ // For String type, we need to handle differently
130+ #[ allow( clippy:: unnecessary_to_owned) ]
131+ let string_result: T = serde_json:: from_value ( serde_json:: Value :: String ( response_text. clone ( ) ) )
132+ . map_err ( |e| anyhow ! ( "String conversion error: {}" , e) ) ?;
133+ Ok ( string_result)
134+ } else {
135+ parse_json_response_with_brace_prefix :: < T > ( & response_text)
126136 }
137+ } ,
138+
139+ // For all other providers - use appropriate parsing
140+ _ => {
141+ if std:: any:: type_name :: < T > ( ) == std:: any:: type_name :: < String > ( ) {
142+ // For String type, we need to handle differently
143+ #[ allow( clippy:: unnecessary_to_owned) ]
144+ let string_result: T = serde_json:: from_value ( serde_json:: Value :: String ( response_text. clone ( ) ) )
145+ . map_err ( |e| anyhow ! ( "String conversion error: {}" , e) ) ?;
146+ Ok ( string_result)
147+ } else {
148+ // First try direct parsing, then fall back to extraction
149+ parse_json_response :: < T > ( & response_text)
150+ }
151+ }
152+ } ;
153+
154+ match result {
155+ Ok ( message) => Ok ( message) ,
156+ Err ( e) => {
157+ log_debug ! ( "JSON parse error: {} text: {}" , e, response_text) ;
158+ Err ( anyhow ! ( "JSON parse error: {}" , e) )
127159 }
128160 }
129161 }
@@ -141,7 +173,7 @@ where
141173
142174 match result {
143175 Ok ( message) => {
144- log_debug ! ( "Deserialized message: {:?}" , message ) ;
176+ log_debug ! ( "Generated message successfully" ) ;
145177 Ok ( message)
146178 }
147179 Err ( e) => {
@@ -151,6 +183,43 @@ where
151183 }
152184}
153185
186+ /// Parse a provider's response that should be pure JSON
187+ fn parse_json_response < T : DeserializeOwned > ( text : & str ) -> Result < T > {
188+ match serde_json:: from_str :: < T > ( text) {
189+ Ok ( message) => Ok ( message) ,
190+ Err ( e) => {
191+ // Fallback to a more robust extraction if direct parsing fails
192+ log_debug ! (
193+ "Direct JSON parse failed: {}. Attempting fallback extraction." ,
194+ e
195+ ) ;
196+ extract_and_parse_json ( text)
197+ }
198+ }
199+ }
200+
201+ /// Parse a response from Anthropic that needs the prefixed "{"
202+ fn parse_json_response_with_brace_prefix < T : DeserializeOwned > ( text : & str ) -> Result < T > {
203+ // Add the opening brace that we prefilled in the prompt
204+ let json_text = format ! ( "{{{}" , text) ;
205+ match serde_json:: from_str :: < T > ( & json_text) {
206+ Ok ( message) => Ok ( message) ,
207+ Err ( e) => {
208+ log_debug ! (
209+ "Brace-prefixed JSON parse failed: {}. Attempting fallback extraction." ,
210+ e
211+ ) ;
212+ extract_and_parse_json ( text)
213+ }
214+ }
215+ }
216+
217+ /// Extracts and parses JSON from a potentially non-JSON response
218+ fn extract_and_parse_json < T : DeserializeOwned > ( text : & str ) -> Result < T > {
219+ let cleaned_json = clean_json_from_llm ( text) ;
220+ serde_json:: from_str ( & cleaned_json) . map_err ( |e| anyhow ! ( "JSON parse error: {}" , e) )
221+ }
222+
154223/// Returns a list of available LLM providers as strings
155224pub fn get_available_provider_names ( ) -> Vec < String > {
156225 vec ! [
0 commit comments