diff --git a/Dockerfile b/Dockerfile index 705fdc0e..d66c1383 100644 --- a/Dockerfile +++ b/Dockerfile @@ -102,12 +102,23 @@ LABEL name="model-serving-runtime-adapter" \ summary="Sidecar container which runs in the Model-Mesh Serving model server pods" \ description="Container which runs in each model serving pod and act as an intermediary between model-mesh and third-party model-server containers" +USER root +# install python to convert keras to tf +RUN microdnf install \ + gcc \ + gcc-c++ \ + python38 && \ + ln -sf /usr/bin/python3 /usr/bin/python && \ + ln -sf /usr/bin/pip3 /usr/bin/pip && \ + pip install tensorflow + USER ${USER} # Copy over the binary and use it as the entrypoint COPY --from=build /opt/app/puller /opt/app/ COPY --from=build /opt/app/triton-adapter /opt/app/ COPY --from=build /opt/app/mlserver-adapter /opt/app/ +COPY --from=build /opt/app/model-mesh-triton-adapter/scripts/tf_pb.py /opt/scripts/ # Don't define an entrypoint. This is a multi-purpose image so the user should specify which binary they want to run (e.g. /opt/app/puller or /opt/app/triton-adapter) # ENTRYPOINT ["/opt/app/puller"] diff --git a/go.mod b/go.mod index 4e99a1ec..13396e65 100644 --- a/go.mod +++ b/go.mod @@ -16,8 +16,7 @@ require ( go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.16.0 // indirect golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d // indirect - golang.org/x/sys v0.0.0-20210423082822-04245dca01da // indirect - golang.org/x/text v0.3.6 // indirect + golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 google.golang.org/genproto v0.0.0-20210317182105-75c7a8546eb9 // indirect google.golang.org/grpc v1.36.0 google.golang.org/protobuf v1.26.0 diff --git a/go.sum b/go.sum index 717266d0..91cbf33d 100644 --- a/go.sum +++ b/go.sum @@ -464,8 +464,6 @@ golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4 h1:b0LrWgu8+q7z4J+0Y3Umo5q1dL7NXBkKBWkaVkAq17E= -golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d h1:LO7XpTYMwTqxjLcGWPijK3vRXg1aWdlNOVOHRq45d7c= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -479,6 +477,7 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -517,9 +516,6 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201112073958-5cba982894dd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210317225723-c4fcb01b228e h1:XNp2Flc/1eWQGk5BLzqTAN7fQIwIbfyVTuVxXxZh73M= -golang.org/x/sys v0.0.0-20210317225723-c4fcb01b228e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -529,8 +525,6 @@ golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3 golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/model-mesh-triton-adapter/examples/models/tfmnistnew/model.h5 b/model-mesh-triton-adapter/examples/models/tfmnistnew/model.h5 new file mode 100644 index 00000000..7478f6b5 Binary files /dev/null and b/model-mesh-triton-adapter/examples/models/tfmnistnew/model.h5 differ diff --git a/model-mesh-triton-adapter/scripts/tf_pb.py b/model-mesh-triton-adapter/scripts/tf_pb.py new file mode 100644 index 00000000..1d57c6a3 --- /dev/null +++ b/model-mesh-triton-adapter/scripts/tf_pb.py @@ -0,0 +1,34 @@ +# Copyright 2021 IBM Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import tensorflow as tf +from tensorflow import keras + +def export_h5_to_pb(path_to_h5, export_path): + try: + model = tf.keras.models.load_model(path_to_h5) + os.makedirs(export_path, exist_ok = True) + model.save(export_path) + except (ImportError, IOError) as e: + print('Error raised when converting keras model: \n', e, file=sys.stderr) + exit(e.errno) + +print('Converting keras model to tensorflow model. Argument(s) passed: {}'.format(str(sys.argv))) +source_path = sys.argv[1] +target_path = sys.argv[2] + +export_h5_to_pb(source_path, target_path) +os.remove(source_path) +print('Successfully converted keras model to tensorflow model.') diff --git a/model-mesh-triton-adapter/server/rewritemodelpath.go b/model-mesh-triton-adapter/server/rewritemodelpath.go index e6dcb036..44bfac1b 100644 --- a/model-mesh-triton-adapter/server/rewritemodelpath.go +++ b/model-mesh-triton-adapter/server/rewritemodelpath.go @@ -14,6 +14,7 @@ package server import ( + "context" "errors" "fmt" "io/ioutil" @@ -53,7 +54,7 @@ var modelTypeToFileNameMapping = map[string]string{ "pytorch": "model.pt", } -func rewriteModelPath(rootModelDir, modelID, modelType string, log logr.Logger) error { +func rewriteModelPath(ctx context.Context, rootModelDir, modelID, modelType string, log logr.Logger) error { // convert to lower case and remove anything after the : modelType = strings.ToLower(strings.Split(modelType, ":")[0]) @@ -72,6 +73,11 @@ func rewriteModelPath(rootModelDir, modelID, modelType string, log logr.Logger) log.Error(err, "Unable to securely join", "rootModelDir", rootModelDir, "modelID", modelID) return err } + + if err = convertKerasToTF(sourceModelIDDir, ctx, log); err != nil { + return fmt.Errorf("Error while converting keras model %s to tensorflow: %w", sourceModelIDDir, err) + } + files, err := ioutil.ReadDir(sourceModelIDDir) if err != nil { return fmt.Errorf("Could not read files in dir %s: %w", sourceModelIDDir, err) diff --git a/model-mesh-triton-adapter/server/rewritemodelpath_test.go b/model-mesh-triton-adapter/server/rewritemodelpath_test.go index 95b858cf..a08c8c13 100644 --- a/model-mesh-triton-adapter/server/rewritemodelpath_test.go +++ b/model-mesh-triton-adapter/server/rewritemodelpath_test.go @@ -14,6 +14,7 @@ package server import ( + "context" "encoding/json" "io/ioutil" "os" @@ -156,7 +157,8 @@ func TestRewriteModelPath(t *testing.T) { tt.generateSourceDirectory(t) // run function under test - err = rewriteModelPath(generatedTestdataDir, tt.ModelID, tt.InputModelType, log) + ctx := context.Background() + err = rewriteModelPath(ctx, generatedTestdataDir, tt.ModelID, tt.InputModelType, log) if tt.ExpectError && err == nil { t.Fatal("ExpectError is true, but no error was returned") @@ -186,8 +188,9 @@ func TestRewriteModelPathMultiple(t *testing.T) { } // next run the function under test for all the models + ctx := context.Background() for _, tt := range rewriteModelPathTests { - err = rewriteModelPath(generatedTestdataDir, tt.ModelID, tt.InputModelType, log) + err = rewriteModelPath(ctx, generatedTestdataDir, tt.ModelID, tt.InputModelType, log) if tt.ExpectError && err == nil { t.Fatal("ExpectError is true, but no error was returned") } diff --git a/model-mesh-triton-adapter/server/server.go b/model-mesh-triton-adapter/server/server.go index 1d99a793..31e6dabb 100644 --- a/model-mesh-triton-adapter/server/server.go +++ b/model-mesh-triton-adapter/server/server.go @@ -107,7 +107,7 @@ func (s *TritonAdapterServer) LoadModel(ctx context.Context, req *mmesh.LoadMode } } - err := rewriteModelPath(s.AdapterConfig.RootModelDir, req.ModelId, modelType, log) + err := rewriteModelPath(ctx, s.AdapterConfig.RootModelDir, req.ModelId, modelType, log) if err != nil { log.Error(err, "Failed to create model directory and load model") diff --git a/model-mesh-triton-adapter/server/utils.go b/model-mesh-triton-adapter/server/utils.go index c0325d0d..20fdcc96 100644 --- a/model-mesh-triton-adapter/server/utils.go +++ b/model-mesh-triton-adapter/server/utils.go @@ -14,13 +14,35 @@ package server import ( + "bufio" + "context" "fmt" + "io" "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + + "github.com/go-logr/logr" + "golang.org/x/sync/semaphore" triton "github.com/kserve/modelmesh-runtime-adapter/internal/proto/triton" "google.golang.org/protobuf/encoding/prototext" ) +var sem *semaphore.Weighted + +func init() { + if m, ok := os.LookupEnv("MAX_CONC_KERAS_CONV_PROCS"); !ok { + sem = semaphore.NewWeighted(2) // default + } else if n, err := strconv.Atoi(m); err != nil { + sem = semaphore.NewWeighted(int64(n)) + } else { + panic("MAX_CONC_KERAS_CONV_PROCS env var must have int value") + } +} + func writeConfigPbtxt(filename string, modelConfig *triton.ModelConfig) error { var err error @@ -39,3 +61,69 @@ func writeConfigPbtxt(filename string, modelConfig *triton.ModelConfig) error { } return nil } + +func convertKerasToTF(sourceModelIDDir string, ctx context.Context, loggr logr.Logger) error { + // check if keras and return the file name + kerasFile, err := checkAndReturnModelFile(sourceModelIDDir) + if err != nil { + return err + } + if kerasFile == "" { + //not a keras model + return nil + } + targetPath := filepath.Join(sourceModelIDDir, "model.savedmodel") + cmd := exec.Command("python", "/opt/scripts/tf_pb.py", kerasFile, targetPath) + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("Failed to create stdout pipe: %w ", err) + } + + if err = sem.Acquire(ctx, 1); err != nil { + return fmt.Errorf("Failed to acquire semaphore for keras conversion process: %w", err) + } + defer sem.Release(1) + + if err = cmd.Start(); err != nil { + return fmt.Errorf("Failed to start python process for keras model conversion: %w ", err) + } + go copyOutput(stdout, loggr) + + err = cmd.Wait() + if exitErr, ok := err.(*exec.ExitError); ok && len(exitErr.Stderr) != 0 { + loggr.Error(err, "keras model conversion failed: %s", exitErr.Stderr) + return fmt.Errorf("keras model conversion failed: %s: %w", exitErr.Stderr, err) + } else if err != nil { + loggr.Error(err, "keras model conversion failed") + return fmt.Errorf("keras model conversion failed: %w", err) + } + return nil +} + +func copyOutput(r io.Reader, loggr logr.Logger) { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + loggr.Info(scanner.Text()) + } +} + +func checkAndReturnModelFile(sourceModelIDDir string) (string, error) { + files, err := ioutil.ReadDir(sourceModelIDDir) + if err != nil { + return "", fmt.Errorf("could not read files in dir %s: %w", sourceModelIDDir, err) + } + var modelFilePath string + var extFiles []string + for _, file := range files { + if filepath.Ext(file.Name()) == ".h5" { + modelFilePath = filepath.Join(sourceModelIDDir, file.Name()) + } else if file.Name() != "_schema.json" { + extFiles = append(extFiles, file.Name()) + } + } + if modelFilePath != "" && len(extFiles) != 0 { + return "", fmt.Errorf("model dir contains other files in addition to a keras model %s: %v", + modelFilePath, extFiles) + } + return modelFilePath, nil +}