feat: TorchServe support (#34)
#### Motivation

The Triton runtime can be used with model-mesh to serve PyTorch TorchScript models, but it does not support arbitrary PyTorch models, i.e. eager mode. KServe "classic" already integrates with TorchServe, and it would be good to have that integration with model-mesh too so that these kinds of models can be used in distributed multi-model serving contexts.

#### Modifications

- Add adapter logic implementing the model-mesh management SPI using the TorchServe gRPC management API (a rough sketch follows this list)
- Build and include new adapter binary in the docker image
- Add mock server and basic unit tests
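
For illustration, here is a minimal Go sketch of how a `LoadModel` call on the model-mesh management SPI might be bridged to TorchServe's management API. All type, field, and method names below are stand-ins invented for the sketch, not the actual generated stubs or adapter code added in this commit:

```go
// Illustrative sketch only: these types stand in for the generated model-mesh
// SPI and TorchServe management-API stubs; names and fields are assumptions.
package adapter

import (
	"context"
	"fmt"
)

// Stand-ins for the model-mesh LoadModel SPI request/response messages.
type LoadModelRequest struct {
	ModelId   string
	ModelPath string
}
type LoadModelResponse struct{}

// Stand-in for the subset of TorchServe's gRPC management API used here.
type torchServeManagement interface {
	RegisterModel(ctx context.Context, modelName, url string) error
}

// TorchServeAdapter bridges the model-mesh management SPI to TorchServe.
type TorchServeAdapter struct {
	mgmt torchServeManagement
}

// LoadModel registers the model with TorchServe. The model's memory usage is
// deliberately not computed here; model-mesh asks for it afterwards via the
// separate ModelSize RPC, so the model is usable slightly sooner.
func (a *TorchServeAdapter) LoadModel(ctx context.Context, req *LoadModelRequest) (*LoadModelResponse, error) {
	if err := a.mgmt.RegisterModel(ctx, req.ModelId, req.ModelPath); err != nil {
		return nil, fmt.Errorf("registering model %q with TorchServe: %w", req.ModelId, err)
	}
	return &LoadModelResponse{}, nil
}
```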

Implementation notes:
- Model size (memory usage) is not returned from the `LoadModel` RPC; it is reported separately by the `ModelSize` RPC, so that the model is available for use slightly sooner
- TorchServe's `DescribeModel` RPC is used to determine the model's memory usage. If that isn't successful, the adapter falls back to using a multiple of the model's size on disk, similar to other runtimes (see the sketch after this list)
- The adapter writes the config file for TorchServe to consume
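
As a rough illustration of that sizing fallback, the logic might look like the following Go sketch; the `DescribeModel` lookup, the on-disk walk, and the multiplier value are assumptions for the sketch, not the exact code in this commit:

```go
// Illustrative sketch of the model-size estimation fallback; the DescribeModel
// lookup, disk walk, and multiplier are assumptions, not the exact adapter code.
package adapter

import (
	"context"
	"io/fs"
	"path/filepath"
)

// Assumed multiplier applied to the on-disk size when TorchServe can't report
// actual memory usage (other adapters use a similar heuristic).
const diskSizeMultiplier = 2

// describeModelMemory stands in for a call to TorchServe's DescribeModel RPC
// that returns the model's reported memory usage in bytes.
type describeModelMemory func(ctx context.Context, modelName string) (uint64, error)

// estimateModelSize prefers the memory usage reported by DescribeModel and
// falls back to a multiple of the model's size on disk if that call fails.
func estimateModelSize(ctx context.Context, describe describeModelMemory, modelName, modelPath string) (uint64, error) {
	if memBytes, err := describe(ctx, modelName); err == nil && memBytes > 0 {
		return memBytes, nil
	}

	// Fallback: sum the size of all files under the model directory.
	var diskBytes uint64
	err := filepath.WalkDir(modelPath, func(_ string, d fs.DirEntry, err error) error {
		if err != nil || d.IsDir() {
			return err
		}
		info, err := d.Info()
		if err != nil {
			return err
		}
		diskBytes += uint64(info.Size())
		return nil
	})
	if err != nil {
		return 0, err
	}
	return diskBytes * diskSizeMultiplier, nil
}
```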

TorchServe does not yet support the KServe V2 gRPC prediction API (only REST), which means that API can't currently be used with model-mesh. The native TorchServe gRPC inference interface can be used instead for the time being.
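
A caller using that native interface (the `InferenceAPIsService` defined in the proto added below, for which generated Go stubs are included in this commit) might look roughly like the following sketch; the import path, server address, and input payload are assumptions for illustration:

```go
// Rough usage sketch of the native TorchServe gRPC inference interface; the
// import path for the generated stubs, the address, and the payload are assumptions.
package main

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	torchserve "github.com/kserve/modelmesh-runtime-adapter/internal/proto/torchserve" // assumed import path
)

func main() {
	// Assumed address of TorchServe's gRPC inference endpoint.
	conn, err := grpc.Dial("localhost:7070", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("connecting to TorchServe: %v", err)
	}
	defer conn.Close()

	client := torchserve.NewInferenceAPIsServiceClient(conn)

	// PredictionsRequest carries the model name and a map of named input payloads.
	resp, err := client.Predictions(context.Background(), &torchserve.PredictionsRequest{
		ModelName: "my-model",
		Input:     map[string][]byte{"data": []byte(`{"instances": [[1.0, 2.0]]}`)},
	})
	if err != nil {
		log.Fatalf("prediction failed: %v", err)
	}
	fmt.Printf("received %d prediction bytes\n", len(resp.Prediction))
}
```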

A smaller PR to the main modelmesh-serving controller repo will be opened to enable use of TorchServe, which will include the ServingRuntime specification.

#### Result

TorchServe can be used seamlessly with ModelMesh Serving to serve PyTorch models, including eager-mode models.

Resolves #4
Contributes to kserve/modelmesh-serving#63

Signed-off-by: Nick Hill <[email protected]>
njhill authored Sep 23, 2022
1 parent f4c43a3 commit 9a61ddc
Showing 18 changed files with 3,366 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Dockerfile
@@ -86,7 +86,7 @@ RUN go build -o puller model-serving-puller/main.go
RUN go build -o triton-adapter model-mesh-triton-adapter/main.go
RUN go build -o mlserver-adapter model-mesh-mlserver-adapter/main.go
RUN go build -o ovms-adapter model-mesh-ovms-adapter/main.go

RUN go build -o torchserve-adapter model-mesh-torchserve-adapter/main.go

###############################################################################
# Stage 3: Copy build assets to create the smallest final runtime image
@@ -121,6 +121,8 @@ COPY --from=build /opt/app/triton-adapter /opt/app/
COPY --from=build /opt/app/mlserver-adapter /opt/app/
COPY --from=build /opt/app/model-mesh-triton-adapter/scripts/tf_pb.py /opt/scripts/
COPY --from=build /opt/app/ovms-adapter /opt/app/
COPY --from=build /opt/app/torchserve-adapter /opt/app/


# Don't define an entrypoint. This is a multi-purpose image so the user should specify which binary they want to run (e.g. /opt/app/puller or /opt/app/triton-adapter)
# ENTRYPOINT ["/opt/app/puller"]
330 changes: 330 additions & 0 deletions internal/proto/torchserve/inference.pb.go

(Generated file; diff contents not rendered.)

37 changes: 37 additions & 0 deletions internal/proto/torchserve/inference.proto
@@ -0,0 +1,37 @@
// Copied from https://github.com/pytorch/serve/blob/8c23585d2453f230c411721028ad4b07e58cc7dd/frontend/server/src/main/resources/proto/inference.proto

syntax = "proto3";

package org.pytorch.serve.grpc.inference;

import "google/protobuf/empty.proto";

option java_multiple_files = true;

message PredictionsRequest {
    // Name of model.
    string model_name = 1; //required

    // Version of model to run prediction on.
    string model_version = 2; //optional

    // Input data for model prediction.
    map<string, bytes> input = 3; //required
}

message PredictionResponse {
    // Response content for prediction.
    bytes prediction = 1;
}

message TorchServeHealthResponse {
    // TorchServe health
    string health = 1;
}

service InferenceAPIsService {
    // Check health status of the TorchServe server.
    rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {}

    // Predictions entry point to get inference using default model version.
    rpc Predictions(PredictionsRequest) returns (PredictionResponse) {}
}
