Skip to content

Commit

Permalink
Upload file to Google if it's missing
Browse files Browse the repository at this point in the history
johnjosephhorton committed Sep 30, 2024
1 parent 819125f commit 2b3934b
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions edsl/inference_services/GoogleService.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
},
]


class GoogleService(InferenceServiceABC):
_inference_service_ = "google"
key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
@@ -40,14 +41,14 @@ class GoogleService(InferenceServiceABC):
# @classmethod
# def available(cls) -> List[str]:
# return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]

@classmethod
def available(cls) -> List[str]:
model_list = []
model_list = []
for m in genai.list_models():
if "generateContent" in m.supported_generation_methods:
model_list.append(m.name.split("/")[-1])
return model_list
return model_list

@classmethod
def create_model(
@@ -87,7 +88,9 @@ def initialize(cls):
"GOOGLE_API_KEY environment variable is not set"
)
genai.configure(api_key=cls.api_token)
cls.generative_model = genai.GenerativeModel(cls._model_, safety_settings=safety_settings)
cls.generative_model = genai.GenerativeModel(
cls._model_, safety_settings=safety_settings
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -103,33 +106,45 @@ def get_generation_config(self) -> GenerationConfig:
)

async def async_execute_model_call(
self, user_prompt: str,
system_prompt: str = "",
files_list: Optional['Files'] = None
self,
user_prompt: str,
system_prompt: str = "",
files_list: Optional["Files"] = None,
) -> Dict[str, Any]:
generation_config = self.get_generation_config()

if files_list is None:
files_list = []

if system_prompt is not None and system_prompt != "" and self._model_ != "gemini-pro":
if (
system_prompt is not None
and system_prompt != ""
and self._model_ != "gemini-pro"
):
try:
self.generative_model = genai.GenerativeModel(self._model_,
safety_settings=safety_settings,
system_instruction=system_prompt)
self.generative_model = genai.GenerativeModel(
self._model_,
safety_settings=safety_settings,
system_instruction=system_prompt,
)
except InvalidArgument as e:
print(f"This model, {self._model_}, does not support system_instruction")
print(
f"This model, {self._model_}, does not support system_instruction"
)
print("Will add system_prompt to user_prompt")
user_prompt = f"{system_prompt}\n{user_prompt}"


combined_prompt = [user_prompt]
for file in files_list:
gen_ai_file = google.generativeai.types.file_types.File(file.external_locations['google'])
if "google" not in file.external_locations:
_ = file.upload_google()
gen_ai_file = google.generativeai.types.file_types.File(
file.external_locations["google"]
)
combined_prompt.append(gen_ai_file)

response = await self.generative_model.generate_content_async(combined_prompt,
generation_config=generation_config
response = await self.generative_model.generate_content_async(
combined_prompt, generation_config=generation_config
)
return response.to_dict()

0 comments on commit 2b3934b

Please sign in to comment.