diff --git a/vertexai/preview/prompts.py b/vertexai/preview/prompts.py
index a8b5667a46..afda8ebe33 100644
--- a/vertexai/preview/prompts.py
+++ b/vertexai/preview/prompts.py
@@ -16,7 +16,21 @@
 from vertexai.prompts._prompts import (
     Prompt,
 )
+from vertexai.prompts._prompt_management import (
+    create_version,
+    delete,
+    get,
+    list_prompts as list,
+    list_versions,
+    restore_version,
+)
 
 __all__ = [
     "Prompt",
+    "delete",
+    "create_version",
+    "get",
+    "list",
+    "list_versions",
+    "restore_version",
 ]
diff --git a/vertexai/prompts/_prompt_management.py b/vertexai/prompts/_prompt_management.py
new file mode 100644
index 0000000000..580328e97c
--- /dev/null
+++ b/vertexai/prompts/_prompt_management.py
@@ -0,0 +1,739 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer as aiplatform_initializer
+from google.cloud.aiplatform.compat.types import dataset as gca_dataset
+from google.cloud.aiplatform_v1.types import (
+    dataset_version as gca_dataset_version,
+)
+from google.cloud.aiplatform_v1beta1.types import (
+    prediction_service as gapic_prediction_service_types,
+)
+from vertexai.generative_models import (
+    Part,
+    Image,
+    GenerativeModel,
+    Tool,
+    ToolConfig,
+)
+from vertexai.generative_models._generative_models import (
+    _proto_to_dict,
+    _dict_to_proto,
+    _tool_types_to_gapic_tools,
+    PartsType,
+)
+from vertexai.prompts._prompts import Prompt
+from google.protobuf import field_mask_pb2 as field_mask
+
+import dataclasses
+from typing import (
+    Any,
+    Dict,
+    Optional,
+)
+
+_LOGGER = base.Logger(__name__)
+_dataset_client_value = None
+
+DEFAULT_API_SCHEMA_VERSION = "1.0.0"
+PROMPT_SCHEMA_URI = (
+    "gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml"
+)
+
+
+def _format_function_declaration_parameters(obj: Any):
+    """Recursively renames `type_` and `format_` keys, returning a new structure."""
+    if isinstance(obj, (str, int, float)):
+        return obj
+    if isinstance(obj, dict):
+        new = obj.__class__()
+        for key, value in obj.items():
+            key = key.replace("type_", "type")
+            key = key.replace("format_", "format")
+            new[key] = _format_function_declaration_parameters(value)
+    elif isinstance(obj, (list, set, tuple)):
+        new = obj.__class__(
+            _format_function_declaration_parameters(value) for value in obj
+        )
+    else:
+        return obj
+    return new
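The helper above is easiest to see on a small input. A minimal sketch (the nested parameters dict is hypothetical, shaped like proto-serialized function declaration parameters):

```python
from vertexai.prompts._prompt_management import _format_function_declaration_parameters

# Hypothetical parameters dict, as produced by proto serialization where
# reserved words are suffixed with an underscore.
params = {
    "type_": "OBJECT",
    "properties": {"temperature": {"type_": "NUMBER", "format_": "float"}},
}

# A new structure is returned with the proto-style keys renamed.
cleaned = _format_function_declaration_parameters(params)
# {'type': 'OBJECT', 'properties': {'temperature': {'type': 'NUMBER', 'format': 'float'}}}
```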
+ """ + + variables: dict[str, list[Part]] + + def to_dict(self) -> Dict[str, Any]: + dct = {} + for variable_name in self.variables: + dct[variable_name] = { + "partList": { + "parts": [part.to_dict() for part in self.variables[variable_name]] + } + } + return dct + + @classmethod + def from_dict(cls, dct: Dict[str, Any]) -> "Arguments": + variables = {} + for variable_name in dct: + variables[variable_name] = [ + Part.from_dict(part) for part in dct[variable_name]["partList"]["parts"] + ] + arguments = cls(variables=variables) + return arguments + + +@dataclasses.dataclass +class Execution: + """Execution. Child of MultimodalPrompt. + + Attributes: + arguments: The arguments of the execution. + """ + + arguments: Arguments + + def __init__(self, arguments: dict[str, list[Part]]): + self.arguments = Arguments(variables=arguments) + + def to_dict(self) -> Dict[str, Any]: + dct = {} + dct["arguments"] = self.arguments.to_dict() + return dct + + @classmethod + def from_dict(cls, dct: Dict[str, Any]) -> "Execution": + arguments = dct.get("arguments", None) + execution = cls(arguments=arguments) + return execution + + +@dataclasses.dataclass +class MultimodalPrompt: + """MultimodalPrompt. Child of PromptDatasetMetadata. + + Attributes: + prompt_message: The schema for the prompt. A subset of the GenerateContentRequest schema. + api_schema_version: The api schema version of the prompt when it was last modified. + executions: Contains data related to an execution of a prompt (ex. variables) + """ + + prompt_message: gapic_prediction_service_types.GenerateContentRequest + api_schema_version: Optional[str] = DEFAULT_API_SCHEMA_VERSION + executions: Optional[list[Execution]] = None + + def to_dict(self) -> Dict[str, Any]: + dct = {"multimodalPrompt": {}} + dct["apiSchemaVersion"] = self.api_schema_version + dct["multimodalPrompt"]["promptMessage"] = _proto_to_dict(self.prompt_message) + + # Fix type_ and format_ fields + if dct["multimodalPrompt"]["promptMessage"].get("tools", None): + tools = dct["multimodalPrompt"]["promptMessage"]["tools"] + for tool in tools: + for function_declaration in tool.get("function_declarations", []): + function_declaration[ + "parameters" + ] = _format_function_declaration_parameters( + function_declaration["parameters"] + ) + + if self.executions and self.executions[0]: + # Only add variable sets if they are non empty. + execution_dcts = [] + for execution in self.executions: + exeuction_dct = execution.to_dict() + if exeuction_dct and exeuction_dct["arguments"]: + execution_dcts.append(exeuction_dct) + if execution_dcts: + dct["executions"] = execution_dcts + return dct + + @classmethod + def from_dict(cls, dct: Dict[str, Any]) -> "MultimodalPrompt": + api_schema_version = dct.get("apiSchemaVersion", DEFAULT_API_SCHEMA_VERSION) + if int(api_schema_version.split(".")[0]) > int( + DEFAULT_API_SCHEMA_VERSION.split(".")[0] + ): + # Disallow loading prompts with lower major schema version + raise ValueError( + "This prompt was saved with a newer schema version and cannot be loaded." 
+
+
+@dataclasses.dataclass
+class PromptDatasetMetadata:
+    """PromptDatasetMetadata.
+
+    Attributes:
+        prompt_type: Required. SDK only supports "freeform" or "multimodal_freeform"
+        prompt_api_schema: Required. SDK only supports multimodalPrompt
+    """
+
+    prompt_type: str
+    prompt_api_schema: MultimodalPrompt
+
+    def to_dict(self) -> Dict[str, Any]:
+        dct = {}
+        dct["promptType"] = self.prompt_type
+        dct["promptApiSchema"] = self.prompt_api_schema.to_dict()
+        return dct
+
+    @classmethod
+    def from_dict(cls, dct: Dict[str, Any]) -> "PromptDatasetMetadata":
+        metadata = cls(
+            prompt_type=dct.get("promptType", None),
+            prompt_api_schema=MultimodalPrompt.from_dict(
+                dct.get("promptApiSchema", None)
+            ),
+        )
+        return metadata
+
+
+@dataclasses.dataclass
+class PromptMetadata:
+    """Metadata containing the display name and prompt id of a prompt.
+
+    Returned by the `list_prompts` method.
+
+    Attributes:
+        name: The display name of the prompt.
+        prompt_id: The id of the prompt.
+    """
+
+    name: str
+    prompt_id: str
+
+
+@dataclasses.dataclass
+class PromptVersionMetadata:
+    """Metadata containing the display name, prompt id, and version id of a prompt version.
+
+    Returned by the `list_versions` method.
+
+    Attributes:
+        name: The display name of the prompt version.
+        prompt_id: The id of the prompt.
+        version_id: The version id of the prompt.
+    """
+
+    name: str
+    prompt_id: str
+    version_id: str
+
+
+def create_version(
+    prompt: Prompt,
+    prompt_id: Optional[str] = None,
+    version_name: Optional[str] = None,
+) -> Prompt:
+    """Creates a Prompt or Prompt Version in the online prompt store.
+
+    Args:
+        prompt: The Prompt object to create a new version of.
+        prompt_id: The id of the prompt resource to create a new version under.
+            If it is not provided and the prompt has no prompt resource
+            associated with it, a new prompt resource will be created.
+        version_name: Optional display name of the new prompt version.
+            If not specified, a default name including a timestamp will be used.
+
+    Returns:
+        A new Prompt object with a reference to the newly created or updated
+        prompt resource. This new Prompt object is nearly identical to the
+        original Prompt object, except it has references to the new
+        prompt version.
+    """
+    if not (prompt_id or prompt._dataset):
+        # Case 1: Neither prompt_id nor prompt._dataset exists, so we
+        # create a new prompt resource
+        return _create_prompt_resource(prompt=prompt, version_name=version_name)
+
+    # Case 2: No prompt_id override is given, so we update the existing prompt resource
+    if not prompt_id:
+        return _create_prompt_version_resource(prompt=prompt, version_name=version_name)
+
+    # Case 3: Save a new version to the prompt_id provided as an arg.
+    # prompt_id is guaranteed to exist because Cases 1 & 2 were handled first.
+
+    # Store the original prompt resource name, if it exists
+    original_prompt_resource = None if not prompt._dataset else prompt._dataset.name
+
+    # Create a gapic dataset object if it doesn't exist
+    if not prompt._dataset:
+        project = aiplatform_initializer.global_config.project
+        location = aiplatform_initializer.global_config.location
+        name = f"projects/{project}/locations/{location}/datasets/{prompt_id}"
+        dataset_metadata = _format_dataset_metadata_dict(prompt=prompt)
+
+        prompt._dataset = gca_dataset.Dataset(
+            name=name,
+            display_name=prompt.prompt_name or "Untitled Prompt",
+            metadata_schema_uri=PROMPT_SCHEMA_URI,
+            metadata=dataset_metadata,
+            model_reference=prompt.model_name,
+        )
+
+    # Override the dataset prompt id with the new prompt id
+    project = aiplatform_initializer.global_config.project
+    location = aiplatform_initializer.global_config.location
+    prompt._dataset.name = (
+        f"projects/{project}/locations/{location}/datasets/{prompt_id}"
+    )
+    result = _create_prompt_version_resource(prompt=prompt, version_name=version_name)
+
+    # Restore the original prompt resource name. This is a no-op if there
+    # was no original prompt resource name.
+    prompt._dataset.name = original_prompt_resource
+    return result
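Taken together, the three cases above support a simple save flow. A usage sketch through the public `vertexai.preview.prompts` surface (project and ids are placeholders):

```python
import vertexai
from vertexai.preview import prompts
from vertexai.preview.prompts import Prompt

vertexai.init(project="my-project", location="us-central1")

prompt = Prompt(
    prompt_data="Hello, {name}!",
    variables=[{"name": "Alice"}],
    prompt_name="greeting",
)

# Case 1: no prompt resource exists yet, so one is created.
prompt_v1 = prompts.create_version(prompt=prompt, version_name="v1")

# Case 2: saving the returned prompt again adds a version to the same resource.
prompt_v1.prompt_data = "Hi, {name}!"
prompt_v2 = prompts.create_version(prompt=prompt_v1, version_name="v2")

# Case 3: an explicit prompt_id targets a different, existing resource.
other = prompts.create_version(prompt=prompt_v1, prompt_id="123456789", version_name="v1")
```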
+ """ + if not (prompt_id or prompt._dataset): + # Case 1: Neither prompt id nor prompt._dataset exists, so we + # create a new prompt resource + return _create_prompt_resource(prompt=prompt, version_name=version_name) + + # Case 2: No prompt_id override is given, so we update the existing prompt resource + if not prompt_id: + return _create_prompt_version_resource(prompt=prompt, version_name=version_name) + + # Case 3: Save a new version to the prompt_id provided as an arg + # prompt_id is guaranteed to exist due to Cases 1 & 2 being handled before + + # Store the original prompt resource name, if it exists + original_prompt_resource = None if not prompt._dataset else prompt._dataset.name + + # Create a gapic dataset object if it doesn't exist + if not prompt._dataset: + project = aiplatform_initializer.global_config.project + location = aiplatform_initializer.global_config.location + name = f"projects/{project}/locations/{location}/datasets/{prompt_id}" + dataset_metadata = _format_dataset_metadata_dict(prompt=prompt) + + prompt._dataset = gca_dataset.Dataset( + name=name, + display_name=prompt.prompt_name or "Untitled Prompt", + metadata_schema_uri=PROMPT_SCHEMA_URI, + metadata=dataset_metadata, + model_reference=prompt.model_name, + ) + + # Override the dataset prompt id with the new prompt id + project = aiplatform_initializer.global_config.project + location = aiplatform_initializer.global_config.location + prompt._dataset.name = ( + f"projects/{project}/locations/{location}/datasets/{prompt_id}" + ) + result = _create_prompt_version_resource(prompt=prompt, version_name=version_name) + + # Restore the original prompt resource name. This is a no-op if there + # was no original prompt resource name. + prompt._dataset.name = original_prompt_resource + return result + + +def _check_multimodal_contents(prompt_data: PartsType): + if isinstance(prompt_data, Image): + return "multimodal_freeform" + elif isinstance(prompt_data, list): + for part in prompt_data: + check = _check_multimodal_contents(part) + if check == "multimodal_freeform": + return "multimodal_freeform" + elif isinstance(prompt_data, Part): + if "text" not in prompt_data._raw_part: + return "multimodal_freeform" + return "freeform" + + +def _format_dataset_metadata_dict(prompt: Prompt) -> dict[str, Any]: + """Helper function to convert the configs and prompt data stored in the Prompt object to a dataset metadata dict.""" + model = GenerativeModel(model_name=prompt.model_name) + prompt_message = model._prepare_request( + contents=prompt.prompt_data or "temporary data", + model=prompt.model_name, + system_instruction=prompt.system_instruction, + tools=prompt.tools, + tool_config=prompt.tool_config, + safety_settings=prompt.safety_settings, + generation_config=prompt.generation_config, + ) + # Remove temporary contents + if not prompt.prompt_data: + prompt_message.contents = None + + # Stopgap solution to check for multimodal contents to set flag for UI + if prompt.prompt_data: + prompt_type = _check_multimodal_contents(prompt.prompt_data) + else: + prompt_type = "freeform" + + return PromptDatasetMetadata( + prompt_type=prompt_type, + prompt_api_schema=MultimodalPrompt( + prompt_message=prompt_message, + executions=[Execution(variable_set) for variable_set in prompt.variables], + ), + ).to_dict() + + +def _create_dataset(prompt: Prompt, parent: str) -> gca_dataset.Dataset: + dataset_metadata = _format_dataset_metadata_dict(prompt=prompt) + dataset = gca_dataset.Dataset( + name=parent, + display_name=prompt.prompt_name or 
"Untitled Prompt", + metadata_schema_uri=PROMPT_SCHEMA_URI, + metadata=dataset_metadata, + model_reference=prompt.model_name, + ) + operation = prompt._dataset_client.create_dataset( + parent=parent, + dataset=dataset, + ) + dataset = operation.result() + + # Purge labels + dataset.labels = None + return dataset + + +def _create_dataset_version( + prompt: Prompt, parent: str, version_name: Optional[str] = None +): + dataset_version = gca_dataset_version.DatasetVersion( + display_name=version_name, + ) + + dataset_version = prompt._dataset_client.create_dataset_version( + parent=parent, + dataset_version=dataset_version, + ) + return dataset_version.result() + + +def _update_dataset( + prompt: Prompt, + dataset: gca_dataset.Dataset, +) -> gca_dataset_version.DatasetVersion: + dataset.metadata = _format_dataset_metadata_dict(prompt=prompt) + + mask_paths = ["modelReference", "metadata"] + if dataset.display_name != "Untitled Prompt": + mask_paths.append("displayName") + + updated_dataset = prompt._dataset_client.update_dataset( + dataset=dataset, + update_mask=field_mask.FieldMask(paths=mask_paths), + ) + # Remove etag to avoid error for repeated dataset updates + updated_dataset.etag = None + return updated_dataset + + +def _create_prompt_resource( + prompt: Prompt, version_name: Optional[str] = None +) -> Prompt: + project = aiplatform_initializer.global_config.project + location = aiplatform_initializer.global_config.location + + # Step 1: Create prompt dataset API call + parent = f"projects/{project}/locations/{location}" + dataset = _create_dataset(prompt=prompt, parent=parent) + + # Step 2: Create prompt version API call + dataset_version = _create_dataset_version( + prompt=prompt, + parent=dataset.name, + version_name=version_name, + ) + + # Step 3: Create new Prompt object to return + new_prompt = Prompt._clone(prompt=prompt) + new_prompt._dataset = dataset + new_prompt._version_id = dataset_version.name.split("/")[-1] + new_prompt._version_name = dataset_version.display_name + prompt_id = new_prompt._dataset.name.split("/")[5] + + _LOGGER.info( + f"Created prompt resource with id {prompt_id} with version number {new_prompt._version_id}" + ) + return new_prompt + + +def _create_prompt_version_resource( + prompt: Prompt, + version_name: Optional[str] = None, +) -> Prompt: + # Step 1: Update prompt API call + updated_dataset = _update_dataset(prompt=prompt, dataset=prompt._dataset) + + # Step 2: Create prompt version API call + dataset_version = _create_dataset_version( + prompt=prompt, + parent=updated_dataset.name, + version_name=version_name, + ) + + # Step 3: Create new Prompt object to return + new_prompt = Prompt._clone(prompt=prompt) + new_prompt._dataset = updated_dataset + new_prompt._version_id = dataset_version.name.split("/")[-1] + new_prompt._version_name = dataset_version.display_name + prompt_id = prompt._dataset.name.split("/")[5] + + _LOGGER.info( + f"Updated prompt resource with id {prompt_id} as version number {new_prompt._version_id}" + ) + return new_prompt + + +def _get_prompt_resource(prompt: Prompt, prompt_id: str) -> gca_dataset.Dataset: + """Helper function to get a prompt resource from a prompt id.""" + project = aiplatform_initializer.global_config.project + location = aiplatform_initializer.global_config.location + name = f"projects/{project}/locations/{location}/datasets/{prompt_id}" + dataset = prompt._dataset_client.get_dataset(name=name) + return dataset + + +def _get_prompt_resource_from_version( + prompt: Prompt, prompt_id: str, version_id: str +) -> 
+
+
+def restore_version(prompt_id: str, version_id: str) -> PromptVersionMetadata:
+    """Restores a previous version of the prompt resource.
+
+    The restored version becomes a new latest version of the prompt resource.
+
+    Args:
+        prompt_id: The id of the prompt resource to restore a version of.
+        version_id: The version id of the online prompt resource.
+
+    Returns:
+        A PromptVersionMetadata object referencing the newly created version.
+    """
+
+    # Step 1: Make the restore dataset version API call
+    project = aiplatform_initializer.global_config.project
+    location = aiplatform_initializer.global_config.location
+    name = f"projects/{project}/locations/{location}/datasets/{prompt_id}/datasetVersions/{version_id}"
+
+    # Create a temporary Prompt object for a dataset client
+    temp_prompt = Prompt()
+    operation = temp_prompt._dataset_client.restore_dataset_version(name=name)
+    result = operation.result()
+    new_version_id = result.name.split("/")[-1]
+    prompt_id = result.name.split("/")[5]
+
+    _LOGGER.info(
+        f"Restored prompt version {version_id} under prompt id {prompt_id} as version number {new_version_id}"
+    )
+
+    # Step 2: Create a PromptVersionMetadata object to return
+    return PromptVersionMetadata(
+        name=result.display_name,
+        prompt_id=result.name.split("/")[5],
+        version_id=new_version_id,
+    )
+
+
+def get(prompt_id: str, version_id: Optional[str] = None) -> Prompt:
+    """Creates a Prompt object from an online resource.
+
+    Args:
+        prompt_id: The id of the prompt resource.
+        version_id: Optional version id of the prompt resource.
+            If not specified, the latest version will be used.
+
+    Returns:
+        A prompt loaded from the online resource as a `Prompt` object.
+    """
+    prompt = Prompt()
+    if version_id:
+        dataset = _get_prompt_resource_from_version(
+            prompt=prompt,
+            prompt_id=prompt_id,
+            version_id=version_id,
+        )
+    else:
+        dataset = _get_prompt_resource(prompt=prompt, prompt_id=prompt_id)
+
+    # Remove etag to avoid errors from repeated dataset updates
+    dataset.etag = None
+
+    prompt._dataset = dataset
+    prompt._version_id = version_id
+
+    dataset_dict = _proto_to_dict(dataset)
+
+    metadata = PromptDatasetMetadata.from_dict(dataset_dict["metadata"])
+    _populate_fields_from_metadata(prompt=prompt, metadata=metadata)
+    return prompt
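A retrieval sketch for the two read paths above, plus a restore (ids are placeholders):

```python
from vertexai.preview import prompts

# Latest version of the prompt resource.
prompt = prompts.get(prompt_id="123456789")

# A pinned version of the same resource.
pinned = prompts.get(prompt_id="123456789", version_id="2")

# Restoring version 1 creates a new latest version with version 1's contents.
metadata = prompts.restore_version(prompt_id="123456789", version_id="1")
print(metadata.version_id)
```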
+ """ + prompt = Prompt() + if version_id: + dataset = _get_prompt_resource_from_version( + prompt=prompt, + prompt_id=prompt_id, + version_id=version_id, + ) + else: + dataset = _get_prompt_resource(prompt=prompt, prompt_id=prompt_id) + + # Remove etag to avoid error for repeated dataset updates + dataset.etag = None + + prompt._dataset = dataset + prompt._version_id = version_id + + dataset_dict = _proto_to_dict(dataset) + + metadata = PromptDatasetMetadata.from_dict(dataset_dict["metadata"]) + _populate_fields_from_metadata(prompt=prompt, metadata=metadata) + return prompt + + +def _populate_fields_from_metadata( + prompt: Prompt, metadata: PromptDatasetMetadata +) -> None: + """Helper function to populate Promptfields from metadata object""" + # Populate model_name (Required, raw deserialized type is str) + prompt.model_name = metadata.prompt_api_schema.prompt_message.model + + # Populate prompt_data (raw deserialized type is list[Content]) + contents = metadata.prompt_api_schema.prompt_message.contents + if contents: + if len(contents) > 1: + raise ValueError("Multi-turn prompts are not supported yet.") + prompt_data = [Part._from_gapic(part) for part in list(contents[0].parts)] + + # Unwrap single text part into str + if len(prompt_data) == 1 and "text" in prompt_data[0]._raw_part: + prompt.prompt_data = prompt_data[0].text + else: + prompt.prompt_data = prompt_data + + # Populate system_instruction (raw deserialized type is single Content) + system_instruction = metadata.prompt_api_schema.prompt_message.system_instruction + if system_instruction: + system_instruction_parts = [ + Part._from_gapic(part) for part in list(system_instruction.parts) + ] + # Unwrap single text part into str + if len(system_instruction_parts) == 1 and system_instruction_parts[0].text: + prompt.system_instruction = system_instruction_parts[0].text + else: + prompt.system_instruction = system_instruction_parts + + # Populate variables + executions = metadata.prompt_api_schema.executions + variables = [] + if executions: + for execution in executions: + serialized_variable_set = execution.arguments + variable_set = {} + if serialized_variable_set: + for name, value in serialized_variable_set.variables.items(): + # Parts are dicts, not gapic objects for variables + variable_set[name] = [ + Part.from_dict(part) + for part in list(value["partList"]["parts"]) + ] + variables.append(variable_set) + + # Unwrap variable single text part into str + for variable_set in variables: + for name, value in variable_set.items(): + if len(value) == 1 and "text" in value[0]._raw_part: + variable_set[name] = value[0].text + prompt.variables = variables + + # Populate generation_config (raw deserialized type is GenerationConfig) + generation_config = metadata.prompt_api_schema.prompt_message.generation_config + if generation_config: + prompt.generation_config = generation_config + + # Populate safety_settings (raw deserialized type is RepeatedComposite of SafetySetting) + safety_settings = metadata.prompt_api_schema.prompt_message.safety_settings + if safety_settings: + prompt.safety_settings = list(safety_settings) + + # Populate tools (raw deserialized type is RepeatedComposite of Tool) + tools = metadata.prompt_api_schema.prompt_message.tools + if tools: + prompt.tools = list(tools) + + # Populate tool_config (raw deserialized type is ToolConfig) + tool_config = metadata.prompt_api_schema.prompt_message.tool_config + if tool_config: + prompt.tool_config = ToolConfig._from_gapic(tool_config) + + +def list_prompts() -> 
diff --git a/vertexai/prompts/_prompts.py b/vertexai/prompts/_prompts.py
index c70e37ba2f..104d59d828 100644
--- a/vertexai/prompts/_prompts.py
+++ b/vertexai/prompts/_prompts.py
@@ -15,8 +15,11 @@
 # limitations under the License.
 #
 
+from copy import deepcopy
+
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer as aiplatform_initializer
+from google.cloud.aiplatform.compat.services import dataset_service_client
 from vertexai.generative_models import (
     Content,
     Image,
@@ -51,6 +54,7 @@
 
 _LOGGER = base.Logger(__name__)
 
+DEFAULT_MODEL_NAME = "gemini-1.5-flash-002"
 VARIABLE_NAME_REGEX = r"(\{[^\W0-9]\w*\})"
 
 
@@ -126,9 +130,10 @@ class Prompt:
 
     def __init__(
         self,
-        prompt_data: PartsType,
+        prompt_data: Optional[PartsType] = None,
         *,
         variables: Optional[List[Dict[str, PartsType]]] = None,
+        prompt_name: Optional[str] = None,
         generation_config: Optional[GenerationConfig] = None,
         model_name: Optional[str] = None,
         safety_settings: Optional[SafetySetting] = None,
@@ -141,9 +146,11 @@ def __init__(
         Args:
             prompt: A PartsType prompt which may be a template with variables or a prompt with no variables.
             variables: A list of dictionaries containing the variable names and values.
+            prompt_name: The name of the prompt if stored in an online resource.
             generation_config: A GenerationConfig object containing parameters for generation.
             model_name: Model Garden model resource name.
                 Alternatively, a tuned model endpoint resource name can be provided.
+                If no model is provided, a default model will be used.
             safety_settings: A SafetySetting object containing safety settings for generation.
             system_instruction: A PartsType object representing the system instruction.
             tools: A list of Tool objects for function calling.
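With `prompt_data` now optional, a Prompt can be declared from configuration alone and filled in later. A brief sketch (the names are illustrative):

```python
from vertexai.preview.prompts import Prompt

# prompt_name labels the online resource once the prompt is saved.
prompt = Prompt(prompt_name="summarizer", model_name="gemini-1.5-flash-002")
prompt.prompt_data = "Summarize the following article: {article}"
prompt.variables = [{"article": "..."}]
```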
@@ -158,8 +165,16 @@ def __init__(
         self._tools = None
         self._tool_config = None
 
+        # Prompt Management
+        self._dataset_client_value = None
+        self._dataset = None
+        self._prompt_name = None
+        self._version_id = None
+        self._version_name = None
+
         self.prompt_data = prompt_data
         self.variables = variables if variables else [{}]
+        self.prompt_name = prompt_name
         self.model_name = model_name
         self.generation_config = generation_config
         self.safety_settings = safety_settings
@@ -168,20 +183,27 @@ def __init__(
         self.tool_config = tool_config
 
     @property
-    def prompt_data(self) -> PartsType:
+    def prompt_data(self) -> Optional[PartsType]:
        return self._prompt_data
 
     @property
     def variables(self) -> Optional[List[Dict[str, PartsType]]]:
         return self._variables
 
+    @property
+    def prompt_name(self) -> Optional[str]:
+        return self._prompt_name
+
     @property
     def generation_config(self) -> Optional[GenerationConfig]:
         return self._generation_config
 
     @property
     def model_name(self) -> Optional[str]:
-        return self._model_name
+        if self._model_name:
+            return self._model_name
+        else:
+            return Prompt._format_model_resource_name(DEFAULT_MODEL_NAME)
 
     @property
     def safety_settings(self) -> Optional[List[SafetySetting]]:
@@ -199,6 +221,20 @@ def tools(self) -> Optional[List[Tool]]:
     def tool_config(self) -> Optional[ToolConfig]:
         return self._tool_config
 
+    @property
+    def prompt_id(self) -> Optional[str]:
+        if self._dataset:
+            return self._dataset.name.split("/")[-1]
+        return None
+
+    @property
+    def version_id(self) -> Optional[str]:
+        return self._version_id
+
+    @property
+    def version_name(self) -> Optional[str]:
+        return self._version_name
+
     @prompt_data.setter
     def prompt_data(self, prompt_data: PartsType) -> None:
         """Overwrites the existing saved local prompt_data.
@@ -206,7 +242,8 @@ def prompt_data(self, prompt_data: PartsType) -> None:
         Args:
             prompt_data: A PartsType prompt.
         """
-        self._validate_parts_type_data(prompt_data)
+        if prompt_data is not None:
+            self._validate_parts_type_data(prompt_data)
         self._prompt_data = prompt_data
 
     @variables.setter
@@ -226,6 +263,14 @@ def variables(self, variables: List[Dict[str, PartsType]]) -> None:
                 f"Variables must be a list of dictionaries, not {type(variables)}"
             )
 
+    @prompt_name.setter
+    def prompt_name(self, prompt_name: Optional[str]) -> None:
+        """Overwrites the existing saved local prompt_name."""
+        if prompt_name:
+            self._prompt_name = prompt_name
+        else:
+            self._prompt_name = None
+
     @model_name.setter
     def model_name(self, model_name: Optional[str]) -> None:
         """Overwrites the existing saved local model_name."""
@@ -370,6 +415,10 @@ def assemble_contents(self, **variables_dict: PartsType) -> List[Content]:
             )
         ```
         """
+        # If prompt_data is None, raise an error.
+        if self.prompt_data is None:
+            raise ValueError("prompt_data must not be empty.")
+
         variables_dict = variables_dict.copy()
 
         # If there are no variables, return the prompt_data as a Content object.
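The new guard means assembling an empty prompt fails fast. A sketch of both outcomes:

```python
from vertexai.preview.prompts import Prompt

filled = Prompt(prompt_data="Hello, {name}!")
contents = filled.assemble_contents(name="Alice")  # [Content] with "Hello, Alice!"

empty = Prompt()  # legal now that prompt_data is optional
try:
    empty.assemble_contents()
except ValueError as err:
    print(err)  # prompt_data must not be empty.
```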
@@ -541,10 +590,15 @@ def generate_content(
             )
         ```
         """
+        if not (model_name or self._model_name):
+            _LOGGER.info(
+                "No model name specified, falling back to default model: %s",
+                self.model_name,
+            )
+        model_name = model_name or self.model_name
         generation_config = generation_config or self.generation_config
         safety_settings = safety_settings or self.safety_settings
-        model_name = model_name or self.model_name
         tools = tools or self.tools
         tool_config = tool_config or self.tool_config
         system_instruction = system_instruction or self.system_instruction
 
@@ -567,14 +621,69 @@ def generate_content(
             stream=stream,
         )
 
+    @property
+    def _dataset_client(self) -> dataset_service_client.DatasetServiceClient:
+        if not getattr(self, "_dataset_client_value", None):
+            self._dataset_client_value = (
+                aiplatform_initializer.global_config.create_client(
+                    client_class=dataset_service_client.DatasetServiceClient,
+                )
+            )
+        return self._dataset_client_value
+
+    @classmethod
+    def _clone(cls, prompt: "Prompt") -> "Prompt":
+        """Returns a copy of the Prompt."""
+        return Prompt(
+            prompt_data=prompt.prompt_data,
+            variables=deepcopy(prompt.variables),
+            generation_config=deepcopy(prompt.generation_config),
+            safety_settings=deepcopy(prompt.safety_settings),
+            tools=deepcopy(prompt.tools),
+            tool_config=deepcopy(prompt.tool_config),
+            system_instruction=prompt.system_instruction,
+            model_name=prompt.model_name,
+        )
+
     def get_unassembled_prompt_data(self) -> PartsType:
         """Returns the prompt data, without any variables replaced."""
         return self.prompt_data
 
     def __str__(self) -> str:
         """Returns the prompt data as a string, without any variables replaced."""
-        return str(self.prompt_data)
+        return str(self.prompt_data or "")
 
     def __repr__(self) -> str:
         """Returns a string representation of the unassembled prompt."""
-        return f"Prompt(prompt_data='{self.prompt_data}', variables={self.variables})"
+        result = "Prompt("
+        if self.prompt_data:
+            result += f"prompt_data='{self.prompt_data}', "
+        if self.variables and self.variables[0]:
+            result += f"variables={self.variables}, "
+        if self.system_instruction:
+            result += f"system_instruction={self.system_instruction}, "
+        if self._model_name:
+            # Don't display the default model in repr
+            result += f"model_name={self._model_name}, "
+        if self.generation_config:
+            result += f"generation_config={self.generation_config}, "
+        if self.safety_settings:
+            result += f"safety_settings={self.safety_settings}, "
+        if self.tools:
+            result += f"tools={self.tools}, "
+        if self.tool_config:
+            result += f"tool_config={self.tool_config}, "
+        if self.prompt_id:
+            result += f"prompt_id={self.prompt_id}, "
+        if self.version_id:
+            result += f"version_id={self.version_id}, "
+        if self.prompt_name:
+            result += f"prompt_name={self.prompt_name}, "
+        if self.version_name:
+            result += f"version_name={self.version_name}, "
+
+        # Remove the trailing ", "
+        if result[-2:] == ", ":
+            result = result[:-2]
+        result += ")"
+        return result
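End to end, the default-model fallback and the management surface compose as follows (a sketch; the project id and printed output are illustrative):

```python
import vertexai
from vertexai.preview import prompts
from vertexai.preview.prompts import Prompt

vertexai.init(project="my-project", location="us-central1")

# No model_name set: generate_content logs the fallback and uses the
# pinned default (DEFAULT_MODEL_NAME above).
prompt = Prompt(prompt_data="Hello, {name}!", variables=[{"name": "Alice"}])
response = prompt.generate_content(contents=prompt.assemble_contents(name="Alice"))
print(response.text)

# Saving the prompt afterwards attaches prompt_id/version_id, which now
# show up in the repr of the returned clone.
prompt_v1 = prompts.create_version(prompt=prompt)
print(repr(prompt_v1))
```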