diff --git a/edsl/inference_services/GoogleService.py b/edsl/inference_services/GoogleService.py
index cef74b01b..dd0427e7a 100644
--- a/edsl/inference_services/GoogleService.py
+++ b/edsl/inference_services/GoogleService.py
@@ -28,6 +28,7 @@
     },
 ]
 
+
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
     key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
@@ -40,14 +41,14 @@ class GoogleService(InferenceServiceABC):
     # @classmethod
     # def available(cls) -> List[str]:
     #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
-
+
     @classmethod
     def available(cls) -> List[str]:
-        model_list = []
+        model_list = []
         for m in genai.list_models():
             if "generateContent" in m.supported_generation_methods:
                 model_list.append(m.name.split("/")[-1])
-        return model_list
+        return model_list
 
     @classmethod
     def create_model(
@@ -87,7 +88,9 @@ def initialize(cls):
                 "GOOGLE_API_KEY environment variable is not set"
             )
         genai.configure(api_key=cls.api_token)
-        cls.generative_model = genai.GenerativeModel(cls._model_, safety_settings=safety_settings)
+        cls.generative_model = genai.GenerativeModel(
+            cls._model_, safety_settings=safety_settings
+        )
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -103,33 +106,45 @@ def get_generation_config(self) -> GenerationConfig:
         )
 
     async def async_execute_model_call(
-        self, user_prompt: str,
-        system_prompt: str = "",
-        files_list: Optional['Files'] = None
+        self,
+        user_prompt: str,
+        system_prompt: str = "",
+        files_list: Optional["Files"] = None,
     ) -> Dict[str, Any]:
         generation_config = self.get_generation_config()
         if files_list is None:
             files_list = []
-        if system_prompt is not None and system_prompt != "" and self._model_ != "gemini-pro":
+        if (
+            system_prompt is not None
+            and system_prompt != ""
+            and self._model_ != "gemini-pro"
+        ):
             try:
-                self.generative_model = genai.GenerativeModel(self._model_,
-                    safety_settings=safety_settings,
-                    system_instruction=system_prompt)
+                self.generative_model = genai.GenerativeModel(
+                    self._model_,
+                    safety_settings=safety_settings,
+                    system_instruction=system_prompt,
+                )
             except InvalidArgument as e:
-                print(f"This model, {self._model_}, does not support system_instruction")
+                print(
+                    f"This model, {self._model_}, does not support system_instruction"
+                )
                 print("Will add system_prompt to user_prompt")
                 user_prompt = f"{system_prompt}\n{user_prompt}"
-
         combined_prompt = [user_prompt]
         for file in files_list:
-            gen_ai_file = google.generativeai.types.file_types.File(file.external_locations['google'])
+            if "google" not in file.external_locations:
+                _ = file.upload_google()
+            gen_ai_file = google.generativeai.types.file_types.File(
+                file.external_locations["google"]
+            )
             combined_prompt.append(gen_ai_file)
-        response = await self.generative_model.generate_content_async(combined_prompt,
-            generation_config=generation_config
+        response = await self.generative_model.generate_content_async(
+            combined_prompt, generation_config=generation_config
         )
         return response.to_dict()