Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clarifai/runners/models/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def get_model_proto(self) -> resources_pb2.Model:
user_app_id=self.client.user_app_id,
model_id=self.model_id,
)
# Add secrets to additional_fields to get request-type secrets
request.additional_fields.append("secrets")
if self.model_version_id is not None:
request.version_id = self.model_version_id
resp: service_pb2.SingleModelResponse = self.client.STUB.GetModel(request)
Expand Down
17 changes: 15 additions & 2 deletions clarifai/runners/models/model_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ class ModelServicer(service_pb2_grpc.V2Servicer):
This is the servicer that will handle the gRPC requests from either the dev server or runner loop.
"""

def __init__(self, model):
def __init__(self, model, model_proto=None):
"""
Args:
model: The class that will handle the model logic. Must implement predict(),
generate(), stream().
model_proto: The model proto containing model configuration including secrets.
"""
self.model = model
self.model_proto = model_proto

# Try to create auth helper from environment variables if available
self._auth_helper = None
Expand Down Expand Up @@ -55,6 +57,10 @@ def PostModelOutputs(
returns an output.
"""

# Inject model proto if available and not already in request
if self.model_proto is not None and not request.HasField("model"):
request.model.CopyFrom(self.model_proto)

# Download any urls that are not already bytes.
ensure_urls_downloaded(request, auth_helper=self._auth_helper)
inject_secrets(request)
Expand All @@ -80,6 +86,10 @@ def GenerateModelOutputs(
This is the method that will be called when the servicer is run. It takes in an input and
returns an output.
"""
# Inject model proto if available and not already in request
if self.model_proto is not None and not request.HasField("model"):
request.model.CopyFrom(self.model_proto)

# Download any urls that are not already bytes.
ensure_urls_downloaded(request, auth_helper=self._auth_helper)
inject_secrets(request)
Expand Down Expand Up @@ -108,8 +118,11 @@ def StreamModelOutputs(
# Duplicate the iterator
request, request_copy = tee(request)

# Download any urls that are not already bytes.
# Download any urls that are not already bytes and inject model proto
for req in request:
# Inject model proto if available and not already in request
if self.model_proto is not None and not req.HasField("model"):
req.model.CopyFrom(self.model_proto)
ensure_urls_downloaded(req, auth_helper=self._auth_helper)
inject_secrets(req)

Expand Down
4 changes: 3 additions & 1 deletion clarifai/runners/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def serve(

def start_servicer(self, port, pool_size, max_queue_size, max_msg_length, enable_tls):
# initialize the servicer with the runner so that it gets the predict(), generate(), stream() classes.
self._servicer = ModelServicer(self._current_model)
self._servicer = ModelServicer(
self._current_model, model_proto=self._builder.get_model_proto()
)

server = GRPCServer(
futures.ThreadPoolExecutor(
Expand Down
Loading