diff --git a/clarifai/runners/models/model_builder.py b/clarifai/runners/models/model_builder.py index 41297a22..87024eb8 100644 --- a/clarifai/runners/models/model_builder.py +++ b/clarifai/runners/models/model_builder.py @@ -225,6 +225,8 @@ def get_model_proto(self) -> resources_pb2.Model: user_app_id=self.client.user_app_id, model_id=self.model_id, ) + # Add secrets to additional_fields to get request-type secrets + request.additional_fields.append("secrets") if self.model_version_id is not None: request.version_id = self.model_version_id resp: service_pb2.SingleModelResponse = self.client.STUB.GetModel(request) diff --git a/clarifai/runners/models/model_servicer.py b/clarifai/runners/models/model_servicer.py index a9715a77..d84d2b39 100644 --- a/clarifai/runners/models/model_servicer.py +++ b/clarifai/runners/models/model_servicer.py @@ -18,13 +18,15 @@ class ModelServicer(service_pb2_grpc.V2Servicer): This is the servicer that will handle the gRPC requests from either the dev server or runner loop. """ - def __init__(self, model): + def __init__(self, model, model_proto=None): """ Args: model: The class that will handle the model logic. Must implement predict(), generate(), stream(). + model_proto: The model proto containing model configuration including secrets. """ self.model = model + self.model_proto = model_proto # Try to create auth helper from environment variables if available self._auth_helper = None @@ -55,6 +57,10 @@ def PostModelOutputs( returns an output. """ + # Inject model proto if available and not already in request + if self.model_proto is not None and not request.HasField("model"): + request.model.CopyFrom(self.model_proto) + # Download any urls that are not already bytes. ensure_urls_downloaded(request, auth_helper=self._auth_helper) inject_secrets(request) @@ -80,6 +86,10 @@ def GenerateModelOutputs( This is the method that will be called when the servicer is run. It takes in an input and returns an output. """ + # Inject model proto if available and not already in request + if self.model_proto is not None and not request.HasField("model"): + request.model.CopyFrom(self.model_proto) + # Download any urls that are not already bytes. ensure_urls_downloaded(request, auth_helper=self._auth_helper) inject_secrets(request) @@ -108,8 +118,11 @@ def StreamModelOutputs( # Duplicate the iterator request, request_copy = tee(request) - # Download any urls that are not already bytes. + # Download any urls that are not already bytes and inject model proto for req in request: + # Inject model proto if available and not already in request + if self.model_proto is not None and not req.HasField("model"): + req.model.CopyFrom(self.model_proto) ensure_urls_downloaded(req, auth_helper=self._auth_helper) inject_secrets(req) diff --git a/clarifai/runners/server.py b/clarifai/runners/server.py index 21a462e7..3f27e37f 100644 --- a/clarifai/runners/server.py +++ b/clarifai/runners/server.py @@ -232,7 +232,9 @@ def serve( def start_servicer(self, port, pool_size, max_queue_size, max_msg_length, enable_tls): # initialize the servicer with the runner so that it gets the predict(), generate(), stream() classes. - self._servicer = ModelServicer(self._current_model) + self._servicer = ModelServicer( + self._current_model, model_proto=self._builder.get_model_proto() + ) server = GRPCServer( futures.ThreadPoolExecutor(