Skip to content

Commit

Permalink
chore: Clearer error message when predictor storage missing and refac…
Browse files Browse the repository at this point in the history
…toring

* chore: clearer error message  when predictor storage missing

* chore: fix error message

* refactor: move parsing json out

* chore: separate storage params from s3 params

* refactor: move schema path out of downloadFromCos

* refactor: create map for storage params in model key map

* syntax: fix misspelling in var

* chore: remove unneeded assignment and instantiate map

* chore: fix assignment of variable

* test: add some unit tests
  • Loading branch information
anhuong authored and kserve-oss-bot committed Aug 27, 2021
1 parent a9795c4 commit 459a917
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 36 deletions.
43 changes: 41 additions & 2 deletions model-serving-puller/puller/puller.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import (
const jsonAttrModelKeyStorageKey = "storage_key"
const jsonAttrModelKeyBucket = "bucket"
const jsonAttrModelKeyDiskSizeBytes = "disk_size_bytes"
const jsonAtrrModelSchemaPath = "schema_path"
const jsonAttrModelSchemaPath = "schema_path"
const jsonAttrStorageParams = "storage_params"

// Puller represents the GRPC server and its configuration
type Puller struct {
Expand Down Expand Up @@ -74,8 +75,46 @@ func NewPullerFromConfig(log logr.Logger, config *PullerConfiguration) *Puller {
// rewritten to a local file path and the size of the model on disk is added to
// the model metadata.
func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.LoadModelRequest, error) {
// parse json
var modelKey map[string]interface{}
parseErr := json.Unmarshal([]byte(req.ModelKey), &modelKey)
if parseErr != nil {
return nil, fmt.Errorf("Invalid modelKey in LoadModelRequest. ModelKey value '%s' is not valid JSON: %s", req.ModelKey, parseErr)
}
schemaPath, ok := modelKey[jsonAttrModelSchemaPath].(string)
if !ok {
if modelKey[jsonAttrModelSchemaPath] != nil {
return nil, fmt.Errorf("Invalid schemaPath in LoadModelRequest, '%s' attribute must have a string value. Found value %v", jsonAttrModelSchemaPath, modelKey[jsonAttrModelSchemaPath])
}
}
storageKey, ok := modelKey[jsonAttrModelKeyStorageKey].(string)
if !ok {
return nil, fmt.Errorf("Predictor Storage field missing")
}

// get storage config
storageConfig, err := s.PullerConfig.GetStorageConfiguration(storageKey, s.Log)
if err != nil {
return nil, err
}

// get storage params
storageParams, ok := modelKey[jsonAttrStorageParams].(map[string]interface{})
if !ok {
// backwards compatibility: if storage_params does not exist
bucketName, ok := modelKey[jsonAttrModelKeyBucket].(string)
if !ok {
if modelKey[jsonAttrModelKeyBucket] != nil {
return nil, fmt.Errorf("Invalid modelKey in LoadModelRequest, '%s' attribute must have a string value. Found value %v", jsonAttrModelKeyBucket, modelKey[jsonAttrModelKeyBucket])
}
}
storageParams = make(map[string]interface{})
storageParams[jsonAttrModelKeyBucket] = bucketName
modelKey[jsonAttrStorageParams] = storageParams
}

// download the model
localPath, pullerErr := s.DownloadFromCOS(req)
localPath, pullerErr := s.DownloadFromCOS(req.ModelId, req.ModelPath, schemaPath, storageKey, storageConfig, storageParams)
if pullerErr != nil {
return nil, status.Errorf(status.Code(pullerErr), "Failed to pull model from storage due to error: %s", pullerErr)
}
Expand Down
134 changes: 134 additions & 0 deletions model-serving-puller/puller/puller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"

"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
"github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/generated/mocks"

"sigs.k8s.io/controller-runtime/pkg/log/zap"
Expand Down Expand Up @@ -194,3 +195,136 @@ func Test_CleanCache_DeletesFakeKeys(t *testing.T) {
p.CleanCache()
assert.Equal(t, 2, len(p.s3DownloaderCache))
}

// Test_DownloadFromCOS_ErrorBucketDoesNotExist calls DownloadFromCOS with
// storage params that contain no "bucket" entry and expects an empty path
// together with the "bucket was not specified" error.
// NOTE(review): this presumably relies on StorageConfigTest having no default
// bucket configured — confirm against its definition.
func Test_DownloadFromCOS_ErrorBucketDoesNotExist(t *testing.T) {
	const wantErr = "Storage bucket was not specified in the LoadModel request and there is no default bucket in the storage configuration"

	puller, _, ctrl := newPullerWithMock(t)
	defer ctrl.Finish()

	// Deliberately empty: no bucket is supplied by the request.
	noBucket := map[string]interface{}{}
	gotPath, gotErr := puller.DownloadFromCOS("modelID", "object/path", "", "storageKey", &StorageConfigTest, noBucket)

	assert.Equal(t, "", gotPath)
	assert.EqualError(t, gotErr, wantErr)
}

// Test_DownloadFromCOS_Success drives DownloadFromCOS with an explicit bucket
// and a mocked downloader, and expects the returned path to be
// <RootModelDir>/<modelID>/<objectPath> with no error.
func Test_DownloadFromCOS_Success(t *testing.T) {
	const (
		modelID    = "myModelID"
		objectPath = "myPath"
		bucket     = "bucket1"
	)

	puller, downloaderMock, ctrl := newPullerWithMock(t)
	defer ctrl.Finish()

	wantPath := filepath.Join(puller.PullerConfig.RootModelDir, modelID, objectPath)

	// The listing and the download are each expected exactly once.
	downloaderMock.EXPECT().
		ListObjectsUnderPrefix(bucket, objectPath).
		Return([]string{objectPath}, nil).
		Times(1)
	downloaderMock.EXPECT().
		DownloadWithIterator(gomock.Any(), gomock.Any()).
		Return(nil).
		Times(1)

	storageParams := map[string]interface{}{"bucket": bucket}
	gotPath, gotErr := puller.DownloadFromCOS(modelID, objectPath, "", StorageKeyTest, &StorageConfigTest, storageParams)

	assert.Equal(t, wantPath, gotPath)
	assert.Nil(t, gotErr)
}

// Test_ProcessLoadModelRequest_Success sends a request whose model key uses
// the legacy top-level "bucket" attribute and expects the returned request to
// carry the rewritten local model path and a model key that includes
// "disk_size_bytes".
func Test_ProcessLoadModelRequest_Success(t *testing.T) {
	puller, downloaderMock, ctrl := newPullerWithMock(t)
	defer ctrl.Finish()

	in := &mmesh.LoadModelRequest{
		ModelId:   "testmodel",
		ModelPath: "model.zip",
		ModelType: "tensorflow",
		ModelKey:  `{"storage_key": "myStorage", "bucket": "bucket1"}`,
	}

	// The model path is expected to be rewritten under the puller's root model dir.
	localModelPath := filepath.Join(puller.PullerConfig.RootModelDir, "testmodel/model.zip")
	want := &mmesh.LoadModelRequest{
		ModelId:   "testmodel",
		ModelPath: localModelPath,
		ModelType: "tensorflow",
		ModelKey:  `{"bucket":"bucket1","disk_size_bytes":0,"storage_key":"myStorage"}`,
	}

	downloaderMock.EXPECT().
		ListObjectsUnderPrefix("bucket1", "model.zip").
		Return([]string{"model.zip"}, nil).
		Times(1)
	downloaderMock.EXPECT().
		DownloadWithIterator(gomock.Any(), gomock.Any()).
		Return(nil).
		Times(1)

	got, gotErr := puller.ProcessLoadModelRequest(in)

	assert.Equal(t, want, got)
	assert.Nil(t, gotErr)
}

// Test_ProcessLoadModelRequest_SuccessWithStorageParams sends a request whose
// model key nests the bucket under "storage_params" and expects the returned
// request to carry the rewritten local model path and a model key that keeps
// "storage_params" and adds "disk_size_bytes".
func Test_ProcessLoadModelRequest_SuccessWithStorageParams(t *testing.T) {
	puller, downloaderMock, ctrl := newPullerWithMock(t)
	defer ctrl.Finish()

	in := &mmesh.LoadModelRequest{
		ModelId:   "testmodel",
		ModelPath: "model.zip",
		ModelType: "tensorflow",
		ModelKey:  `{"storage_params":{"bucket":"bucket1"}, "storage_key": "myStorage"}`,
	}

	// The model path is expected to be rewritten under the puller's root model dir.
	localModelPath := filepath.Join(puller.PullerConfig.RootModelDir, "testmodel/model.zip")
	want := &mmesh.LoadModelRequest{
		ModelId:   "testmodel",
		ModelPath: localModelPath,
		ModelType: "tensorflow",
		ModelKey:  `{"disk_size_bytes":0,"storage_key":"myStorage","storage_params":{"bucket":"bucket1"}}`,
	}

	downloaderMock.EXPECT().
		ListObjectsUnderPrefix("bucket1", "model.zip").
		Return([]string{"model.zip"}, nil).
		Times(1)
	downloaderMock.EXPECT().
		DownloadWithIterator(gomock.Any(), gomock.Any()).
		Return(nil).
		Times(1)

	got, gotErr := puller.ProcessLoadModelRequest(in)

	assert.Equal(t, want, got)
	assert.Nil(t, gotErr)
}

// Test_ProcessLoadModelRequest_FailInvalidModelKey passes a ModelKey that is
// not valid JSON and expects ProcessLoadModelRequest to return a nil request
// and an error whose message contains the "is not valid JSON" prefix. Only the
// prefix is matched because the full message ends with the JSON parser's own
// error text.
func Test_ProcessLoadModelRequest_FailInvalidModelKey(t *testing.T) {
	request := &mmesh.LoadModelRequest{
		ModelId:   "testmodel",
		ModelPath: "model.zip",
		ModelType: "tensorflow",
		ModelKey:  `{}{"storage_params":{"bucket":"bucket1"}, "storage_key": "myStorage"}`,
	}
	expectedError := fmt.Sprintf("Invalid modelKey in LoadModelRequest. ModelKey value '%s' is not valid JSON", request.ModelKey)

	p, _, mockCtrl := newPullerWithMock(t)
	defer mockCtrl.Finish()

	returnRequest, err := p.ProcessLoadModelRequest(request)
	assert.Nil(t, returnRequest)
	// Guard the dereference: if the call unexpectedly succeeds, err is nil and
	// err.Error() would panic instead of producing a readable test failure.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), expectedError)
	}
}

func Test_ProcessLoadModelRequest_FailInvalidSchemaPath(t *testing.T) {
request := &mmesh.LoadModelRequest{
ModelId: "testmodel",
ModelPath: "model.zip",
ModelType: "tensorflow",
ModelKey: `{"storage_params":{"bucket":"bucket1"}, "storage_key": "myStorage", "schema_path": 2}`,
}
expectedError := "Invalid schemaPath in LoadModelRequest, 'schema_path' attribute must have a string value. Found value 2"

p, _, mockCtrl := newPullerWithMock(t)
defer mockCtrl.Finish()

returnRequest, err := p.ProcessLoadModelRequest(request)
assert.Nil(t, returnRequest)
assert.EqualError(t, err, expectedError)
}

func Test_ProcessLoadModelRequest_FailMissingStorageKey(t *testing.T) {
request := &mmesh.LoadModelRequest{
ModelId: "testmodel",
ModelPath: "model.zip",
ModelType: "tensorflow",
ModelKey: `{"storage_params":{"bucket":"bucket1"}}`,
}
expectedError := "Predictor Storage field missing"

p, _, mockCtrl := newPullerWithMock(t)
defer mockCtrl.Finish()

returnRequest, err := p.ProcessLoadModelRequest(request)
assert.Nil(t, returnRequest)
assert.EqualError(t, err, expectedError)
}
40 changes: 6 additions & 34 deletions model-serving-puller/puller/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package puller

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
Expand All @@ -30,7 +29,6 @@ import (
"github.com/IBM/ibm-cos-sdk-go/service/s3"
"github.com/IBM/ibm-cos-sdk-go/service/s3/s3manager"
"github.com/kserve/modelmesh-runtime-adapter/internal/envconfig"
"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
"github.com/kserve/modelmesh-runtime-adapter/util"
)

Expand Down Expand Up @@ -120,28 +118,11 @@ func NewS3Downloader(config *StorageConfiguration, downloadConcurrency int, log
return &s3Downloader{downloader: downloader, client: s3Client, config: config, Log: log}, nil
}

func (s *Puller) DownloadFromCOS(req *mmesh.LoadModelRequest) (string, error) {
modelID := req.ModelId
objPath := req.ModelPath

var modelKey map[string]interface{}
parseErr := json.Unmarshal([]byte(req.ModelKey), &modelKey)
if parseErr != nil {
return "", fmt.Errorf("Invalid modelKey in LoadModelRequest. ModelKey value '%s' is not valid JSON: %s", req.ModelKey, parseErr)
}
storageKey, ok := modelKey[jsonAttrModelKeyStorageKey].(string)
if !ok {
return "", fmt.Errorf("Invalid modelKey in LoadModelRequest, '%s' attribute must exist and have a string value. Found value %v", jsonAttrModelKeyStorageKey, modelKey[jsonAttrModelKeyStorageKey])
}
storageConfig, err := s.PullerConfig.GetStorageConfiguration(storageKey, s.Log)
if err != nil {
return "", err
}

bucketName, ok := modelKey[jsonAttrModelKeyBucket].(string)
func (s *Puller) DownloadFromCOS(modelID string, objPath string, schemaPath string, storageKey string, storageConfig *StorageConfiguration, s3Params map[string]interface{}) (string, error) {
bucketName, ok := s3Params[jsonAttrModelKeyBucket].(string)
if !ok {
if modelKey[jsonAttrModelKeyBucket] != nil {
return "", fmt.Errorf("Invalid modelKey in LoadModelRequest, '%s' attribute must have a string value. Found value %v", jsonAttrModelKeyBucket, modelKey[jsonAttrModelKeyBucket])
if s3Params[jsonAttrModelKeyBucket] != nil {
return "", fmt.Errorf("Invalid storageParams in LoadModelRequest, '%s' attribute must have a string value. Found value %v", jsonAttrModelKeyBucket, s3Params[jsonAttrModelKeyBucket])
}
// no bucket attribute specified, fall down to the default
bucketName = ""
Expand All @@ -155,15 +136,6 @@ func (s *Puller) DownloadFromCOS(req *mmesh.LoadModelRequest) (string, error) {
return "", fmt.Errorf("Storage bucket was not specified in the LoadModel request and there is no default bucket in the storage configuration")
}

schemaPath, ok := modelKey[jsonAtrrModelSchemaPath].(string)
if !ok {
if modelKey[jsonAtrrModelSchemaPath] != nil {
return "", fmt.Errorf("Invalid schemaPath in LoadModelRequest, '%s' attribute must have a string value. Found value %v", jsonAtrrModelSchemaPath, modelKey[jsonAtrrModelSchemaPath])
}
// no schemaPath attribute specified, fall down to the default
schemaPath = ""
}

downloader, err := s.getS3Downloader(storageKey, storageConfig)
if err != nil {
return "", err
Expand Down Expand Up @@ -204,7 +176,7 @@ func (s *Puller) DownloadFromCOS(req *mmesh.LoadModelRequest) (string, error) {
}

if len(schemaToDownload) == 1 {
schemaStorage := storage{storagekey: storageKey, bucketname: bucketName, path: modelKey[jsonAtrrModelSchemaPath].(string)}
schemaStorage := storage{storagekey: storageKey, bucketname: bucketName, path: schemaPath}
p, serr := s.DownloadObjectsfromCOS(modelID, &schemaStorage, schemaToDownload, downloader)

if serr != nil {
Expand All @@ -230,7 +202,7 @@ func (s *Puller) DownloadFromCOS(req *mmesh.LoadModelRequest) (string, error) {
}
}

modelStorage := storage{storagekey: storageKey, bucketname: bucketName, path: req.ModelPath}
modelStorage := storage{storagekey: storageKey, bucketname: bucketName, path: objPath}
p, err := s.DownloadObjectsfromCOS(modelID, &modelStorage, objectsToDownload, downloader)

if err != nil {
Expand Down

0 comments on commit 459a917

Please sign in to comment.