Skip to content

Commit

Permalink
feat: support keras file format
Browse files Browse the repository at this point in the history
Motivation

We don't currently support Keras .h5 format models.

Modifications

- Detect .h5 models in triton adapter logic and spawn a python process to convert to TF SavedModel
- Limit how many such conversion processes can run concurrently, to avoid causing OOM of adapter container or impacting other adapter/puller operation

Result

HDF5 format files can be served with Triton via conversion to TF SavedModel format.

Co-authored-by: nickhill <[email protected]>
  • Loading branch information
amnpandey and njhill committed Nov 20, 2021
1 parent f1f880a commit 7504650
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 13 deletions.
11 changes: 11 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,23 @@ LABEL name="model-serving-runtime-adapter" \
summary="Sidecar container which runs in the Model-Mesh Serving model server pods" \
description="Container which runs in each model serving pod and act as an intermediary between model-mesh and third-party model-server containers"

USER root
# install python to convert keras to tf
RUN microdnf install \
gcc \
gcc-c++ \
python38 && \
ln -sf /usr/bin/python3 /usr/bin/python && \
ln -sf /usr/bin/pip3 /usr/bin/pip && \
pip install tensorflow

USER ${USER}

# Copy over the binary and use it as the entrypoint
COPY --from=build /opt/app/puller /opt/app/
COPY --from=build /opt/app/triton-adapter /opt/app/
COPY --from=build /opt/app/mlserver-adapter /opt/app/
COPY --from=build /opt/app/model-mesh-triton-adapter/scripts/tf_pb.py /opt/scripts/

# Don't define an entrypoint. This is a multi-purpose image so the user should specify which binary they want to run (e.g. /opt/app/puller or /opt/app/triton-adapter)
# ENTRYPOINT ["/opt/app/puller"]
3 changes: 1 addition & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ require (
go.uber.org/multierr v1.6.0 // indirect
go.uber.org/zap v1.16.0 // indirect
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d // indirect
golang.org/x/sys v0.0.0-20210423082822-04245dca01da // indirect
golang.org/x/text v0.3.6 // indirect
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9
google.golang.org/genproto v0.0.0-20210317182105-75c7a8546eb9 // indirect
google.golang.org/grpc v1.36.0
google.golang.org/protobuf v1.26.0
Expand Down
8 changes: 1 addition & 7 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,6 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4 h1:b0LrWgu8+q7z4J+0Y3Umo5q1dL7NXBkKBWkaVkAq17E=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d h1:LO7XpTYMwTqxjLcGWPijK3vRXg1aWdlNOVOHRq45d7c=
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
Expand All @@ -479,6 +477,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
Expand Down Expand Up @@ -517,9 +516,6 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201112073958-5cba982894dd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210317225723-c4fcb01b228e h1:XNp2Flc/1eWQGk5BLzqTAN7fQIwIbfyVTuVxXxZh73M=
golang.org/x/sys v0.0.0-20210317225723-c4fcb01b228e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
Expand All @@ -529,8 +525,6 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
Expand Down
Binary file not shown.
34 changes: 34 additions & 0 deletions model-mesh-triton-adapter/scripts/tf_pb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2021 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tensorflow as tf
from tensorflow import keras

def export_h5_to_pb(path_to_h5, export_path):
try:
model = tf.keras.models.load_model(path_to_h5)
os.makedirs(export_path, exist_ok = True)
model.save(export_path)
except (ImportError, IOError) as e:
print('Error raised when converting keras model: \n', e, file=sys.stderr)
exit(e.errno)

print('Converting keras model to tensorflow model. Argument(s) passed: {}'.format(str(sys.argv)))
source_path = sys.argv[1]
target_path = sys.argv[2]

export_h5_to_pb(source_path, target_path)
os.remove(source_path)
print('Successfully converted keras model to tensorflow model.')
8 changes: 7 additions & 1 deletion model-mesh-triton-adapter/server/rewritemodelpath.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package server

import (
"context"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -53,7 +54,7 @@ var modelTypeToFileNameMapping = map[string]string{
"pytorch": "model.pt",
}

func rewriteModelPath(rootModelDir, modelID, modelType string, log logr.Logger) error {
func rewriteModelPath(ctx context.Context, rootModelDir, modelID, modelType string, log logr.Logger) error {
// convert to lower case and remove anything after the :
modelType = strings.ToLower(strings.Split(modelType, ":")[0])

Expand All @@ -72,6 +73,11 @@ func rewriteModelPath(rootModelDir, modelID, modelType string, log logr.Logger)
log.Error(err, "Unable to securely join", "rootModelDir", rootModelDir, "modelID", modelID)
return err
}

if err = convertKerasToTF(sourceModelIDDir, ctx, log); err != nil {
return fmt.Errorf("Error while converting keras model %s to tensorflow: %w", sourceModelIDDir, err)
}

files, err := ioutil.ReadDir(sourceModelIDDir)
if err != nil {
return fmt.Errorf("Could not read files in dir %s: %w", sourceModelIDDir, err)
Expand Down
7 changes: 5 additions & 2 deletions model-mesh-triton-adapter/server/rewritemodelpath_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package server

import (
"context"
"encoding/json"
"io/ioutil"
"os"
Expand Down Expand Up @@ -156,7 +157,8 @@ func TestRewriteModelPath(t *testing.T) {
tt.generateSourceDirectory(t)

// run function under test
err = rewriteModelPath(generatedTestdataDir, tt.ModelID, tt.InputModelType, log)
ctx := context.Background()
err = rewriteModelPath(ctx, generatedTestdataDir, tt.ModelID, tt.InputModelType, log)

if tt.ExpectError && err == nil {
t.Fatal("ExpectError is true, but no error was returned")
Expand Down Expand Up @@ -186,8 +188,9 @@ func TestRewriteModelPathMultiple(t *testing.T) {
}

// next run the function under test for all the models
ctx := context.Background()
for _, tt := range rewriteModelPathTests {
err = rewriteModelPath(generatedTestdataDir, tt.ModelID, tt.InputModelType, log)
err = rewriteModelPath(ctx, generatedTestdataDir, tt.ModelID, tt.InputModelType, log)
if tt.ExpectError && err == nil {
t.Fatal("ExpectError is true, but no error was returned")
}
Expand Down
2 changes: 1 addition & 1 deletion model-mesh-triton-adapter/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (s *TritonAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadMode
}
}

err := rewriteModelPath(s.AdapterConfig.RootModelDir, req.ModelId, modelType, log)
err := rewriteModelPath(ctx, s.AdapterConfig.RootModelDir, req.ModelId, modelType, log)

if err != nil {
log.Error(err, "Failed to create model directory and load model")
Expand Down
88 changes: 88 additions & 0 deletions model-mesh-triton-adapter/server/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,35 @@
package server

import (
"bufio"
"context"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strconv"

"github.com/go-logr/logr"
"golang.org/x/sync/semaphore"

triton "github.com/kserve/modelmesh-runtime-adapter/internal/proto/triton"
"google.golang.org/protobuf/encoding/prototext"
)

var sem *semaphore.Weighted

func init() {
if m, ok := os.LookupEnv("MAX_CONC_KERAS_CONV_PROCS"); !ok {
sem = semaphore.NewWeighted(2) // default
} else if n, err := strconv.Atoi(m); err != nil {
sem = semaphore.NewWeighted(int64(n))
} else {
panic("MAX_CONC_KERAS_CONV_PROCS env var must have int value")
}
}

func writeConfigPbtxt(filename string, modelConfig *triton.ModelConfig) error {
var err error

Expand All @@ -39,3 +61,69 @@ func writeConfigPbtxt(filename string, modelConfig *triton.ModelConfig) error {
}
return nil
}

func convertKerasToTF(sourceModelIDDir string, ctx context.Context, loggr logr.Logger) error {
// check if keras and return the file name
kerasFile, err := checkAndReturnModelFile(sourceModelIDDir)
if err != nil {
return err
}
if kerasFile == "" {
//not a keras model
return nil
}
targetPath := filepath.Join(sourceModelIDDir, "model.savedmodel")
cmd := exec.Command("python", "/opt/scripts/tf_pb.py", kerasFile, targetPath)
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("Failed to create stdout pipe: %w ", err)
}

if err = sem.Acquire(ctx, 1); err != nil {
return fmt.Errorf("Failed to acquire semaphore for keras conversion process: %w", err)
}
defer sem.Release(1)

if err = cmd.Start(); err != nil {
return fmt.Errorf("Failed to start python process for keras model conversion: %w ", err)
}
go copyOutput(stdout, loggr)

err = cmd.Wait()
if exitErr, ok := err.(*exec.ExitError); ok && len(exitErr.Stderr) != 0 {
loggr.Error(err, "keras model conversion failed: %s", exitErr.Stderr)
return fmt.Errorf("keras model conversion failed: %s: %w", exitErr.Stderr, err)
} else if err != nil {
loggr.Error(err, "keras model conversion failed")
return fmt.Errorf("keras model conversion failed: %w", err)
}
return nil
}

func copyOutput(r io.Reader, loggr logr.Logger) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
loggr.Info(scanner.Text())
}
}

func checkAndReturnModelFile(sourceModelIDDir string) (string, error) {
files, err := ioutil.ReadDir(sourceModelIDDir)
if err != nil {
return "", fmt.Errorf("could not read files in dir %s: %w", sourceModelIDDir, err)
}
var modelFilePath string
var extFiles []string
for _, file := range files {
if filepath.Ext(file.Name()) == ".h5" {
modelFilePath = filepath.Join(sourceModelIDDir, file.Name())
} else if file.Name() != "_schema.json" {
extFiles = append(extFiles, file.Name())
}
}
if modelFilePath != "" && len(extFiles) != 0 {
return "", fmt.Errorf("model dir contains other files in addition to a keras model %s: %v",
modelFilePath, extFiles)
}
return modelFilePath, nil
}

0 comments on commit 7504650

Please sign in to comment.