@@ -236,9 +236,28 @@ async def generate(
236236 timeout : float = 180.0 , # Increased timeout for more complete responses (3 minutes)
237237 repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
238238 top_k : int = 80 , # Added top_k parameter for better quality
239- do_sample : bool = True # Added do_sample parameter
239+ do_sample : bool = True , # Added do_sample parameter
240+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
240241 ) -> str :
241- """Generate text using the model with improved error handling"""
242+ """
243+ Generate text using the model with improved error handling.
244+
245+ Args:
246+ prompt: The prompt to generate text from
247+ model_id: Optional model ID to use
248+ stream: Whether to stream the response
249+ max_length: Maximum length of the generated text
250+ temperature: Temperature for sampling
251+ top_p: Top-p for nucleus sampling
252+ timeout: Request timeout in seconds
253+ repetition_penalty: Penalty for repetition (higher values = less repetition)
254+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
255+ do_sample: Whether to use sampling instead of greedy decoding
256+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
257+
258+ Returns:
259+ The generated text as a string.
260+ """
242261 # Update activity timestamp
243262 self ._update_activity ()
244263
@@ -254,6 +273,10 @@ async def generate(
254273 "do_sample" : do_sample
255274 }
256275
276+ # Add max_time parameter if provided
277+ if max_time is not None :
278+ payload ["max_time" ] = max_time
279+
257280 if stream :
258281 return self .stream_generate (
259282 prompt = prompt ,
@@ -311,7 +334,8 @@ async def stream_generate(
311334 retry_count : int = 3 , # Increased retry count for better reliability
312335 repetition_penalty : float = 1.15 , # Increased repetition penalty for better quality
313336 top_k : int = 80 , # Added top_k parameter for better quality
314- do_sample : bool = True # Added do_sample parameter
337+ do_sample : bool = True , # Added do_sample parameter
338+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
315339 ) -> AsyncGenerator [str , None ]:
316340 """
317341 Stream text generation with token-level streaming and robust error handling.
@@ -326,6 +350,8 @@ async def stream_generate(
326350 retry_count: Number of retries for network errors
327351 repetition_penalty: Penalty for repetition (higher values = less repetition)
328352 top_k: Top-k for sampling (higher values = more diverse vocabulary)
353+ do_sample: Whether to use sampling instead of greedy decoding
354+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
329355
330356 Returns:
331357 A generator that yields chunks of text as they are generated.
@@ -349,6 +375,10 @@ async def stream_generate(
349375 "do_sample" : do_sample
350376 }
351377
378+ # Add max_time parameter if provided
379+ if max_time is not None :
380+ payload ["max_time" ] = max_time
381+
352382 # Create a timeout for this specific request
353383 request_timeout = aiohttp .ClientTimeout (total = timeout )
354384
@@ -473,9 +503,27 @@ async def chat(
473503 top_p : float = 0.9 ,
474504 timeout : float = 180.0 , # Increased timeout for more complete responses (3 minutes)
475505 repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
476- top_k : int = 80 # Added top_k parameter for better quality
506+ top_k : int = 80 , # Added top_k parameter for better quality
507+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
477508 ) -> Dict [str , Any ]:
478- """Chat completion endpoint with improved error handling"""
509+ """
510+ Chat completion endpoint with improved error handling.
511+
512+ Args:
513+ messages: List of message dictionaries with 'role' and 'content' keys
514+ model_id: Optional model ID to use
515+ stream: Whether to stream the response
516+ max_length: Maximum length of the generated text
517+ temperature: Temperature for sampling
518+ top_p: Top-p for nucleus sampling
519+ timeout: Request timeout in seconds
520+ repetition_penalty: Penalty for repetition (higher values = less repetition)
521+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
522+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
523+
524+ Returns:
525+ The chat completion response as a dictionary.
526+ """
479527 # Update activity timestamp
480528 self ._update_activity ()
481529
@@ -490,6 +538,10 @@ async def chat(
490538 "top_k" : top_k
491539 }
492540
541+ # Add max_time parameter if provided
542+ if max_time is not None :
543+ payload ["max_time" ] = max_time
544+
493545 if stream :
494546 return self .stream_chat (
495547 messages = messages ,
@@ -538,9 +590,27 @@ async def stream_chat(
538590 timeout : float = 300.0 , # Increased timeout for more complete responses (5 minutes)
539591 retry_count : int = 3 , # Increased retry count for better reliability
540592 repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
541- top_k : int = 80 # Added top_k parameter for better quality
593+ top_k : int = 80 , # Added top_k parameter for better quality
594+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
542595 ) -> AsyncGenerator [Dict [str , Any ], None ]:
543- """Stream chat completion with robust error handling"""
596+ """
597+ Stream chat completion with robust error handling.
598+
599+ Args:
600+ messages: List of message dictionaries with 'role' and 'content' keys
601+ model_id: Optional model ID to use
602+ max_length: Maximum length of the generated text
603+ temperature: Temperature for sampling
604+ top_p: Top-p for nucleus sampling
605+ timeout: Request timeout in seconds
606+ retry_count: Number of retries for network errors
607+ repetition_penalty: Penalty for repetition (higher values = less repetition)
608+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
609+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
610+
611+ Returns:
612+ A generator that yields chunks of the chat completion response.
613+ """
544614 # Update activity timestamp
545615 self ._update_activity ()
546616
@@ -555,6 +625,10 @@ async def stream_chat(
555625 "top_k" : top_k
556626 }
557627
628+ # Add max_time parameter if provided
629+ if max_time is not None :
630+ payload ["max_time" ] = max_time
631+
558632 # Create a timeout for this specific request
559633 request_timeout = aiohttp .ClientTimeout (total = timeout )
560634
@@ -661,9 +735,26 @@ async def batch_generate(
661735 top_p : float = 0.9 ,
662736 timeout : float = 300.0 , # Increased timeout for more complete responses (5 minutes)
663737 repetition_penalty : float = 1.15 , # Added repetition penalty for better quality
664- top_k : int = 80 # Added top_k parameter for better quality
738+ top_k : int = 80 , # Added top_k parameter for better quality
739+ max_time : Optional [float ] = None # Added max_time parameter to limit generation time
665740 ) -> Dict [str , List [str ]]:
666- """Generate text for multiple prompts in parallel with improved error handling"""
741+ """
742+ Generate text for multiple prompts in parallel with improved error handling.
743+
744+ Args:
745+ prompts: List of prompts to generate text from
746+ model_id: Optional model ID to use
747+ max_length: Maximum length of the generated text
748+ temperature: Temperature for sampling
749+ top_p: Top-p for nucleus sampling
750+ timeout: Request timeout in seconds
751+ repetition_penalty: Penalty for repetition (higher values = less repetition)
752+ top_k: Top-k for sampling (higher values = more diverse vocabulary)
753+ max_time: Optional maximum time in seconds to spend generating (server-side timeout, defaults to 180 seconds if not provided)
754+
755+ Returns:
756+ Dictionary with the generated responses.
757+ """
667758 # Update activity timestamp
668759 self ._update_activity ()
669760
@@ -677,6 +768,10 @@ async def batch_generate(
677768 "top_k" : top_k
678769 }
679770
771+ # Add max_time parameter if provided
772+ if max_time is not None :
773+ payload ["max_time" ] = max_time
774+
680775 # Create a timeout for this specific request
681776 request_timeout = aiohttp .ClientTimeout (total = timeout )
682777
0 commit comments