Skip to content

Commit 0434a42

Browse files
tjohnson31415njhill
authored andcommitted
feat/refactor(puller): update the schema_path in the processed request
* feat/refactor(puller): update the schema_path in the processed request And some clean up around also adding the model disk size And some minor refactors * chore: review comment and make fmt
1 parent dba7e38 commit 0434a42

File tree

5 files changed

+118
-84
lines changed

5 files changed

+118
-84
lines changed

model-serving-puller/puller/puller.go

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/go-logr/logr"
2626
"google.golang.org/grpc/status"
2727

28+
"github.com/kserve/modelmesh-runtime-adapter/internal/modelschema"
2829
"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
2930
"github.com/kserve/modelmesh-runtime-adapter/internal/util"
3031
)
@@ -69,16 +70,17 @@ func NewPullerFromConfig(log logr.Logger, config *PullerConfiguration) *Puller {
6970
return s
7071
}
7172

72-
// processLoadModelRequest is for use in an mmesh ModelRuntimeServer that embeds the puller
73+
// ProcessLoadModelRequest is for use in an mmesh serving runtime that embeds the puller
7374
//
74-
// The input request is modified in place and also returned. The path is
75-
// rewritten to a local file path and the size of the model on disk is added to
76-
// the model metadata.
75+
// The input request is modified in place and also returned.
76+
// After pulling the model files, changes to the request are:
77+
// - rewrite ModelPath to a local filesystem path
78+
// - rewrite ModelKey["schema_path"] to a local filesystem path
79+
// - add the size of the model on disk to ModelKey["disk_size_bytes"]
7780
func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.LoadModelRequest, error) {
7881
// parse json
7982
var modelKey map[string]interface{}
80-
parseErr := json.Unmarshal([]byte(req.ModelKey), &modelKey)
81-
if parseErr != nil {
83+
if parseErr := json.Unmarshal([]byte(req.ModelKey), &modelKey); parseErr != nil {
8284
return nil, fmt.Errorf("Invalid modelKey in LoadModelRequest. ModelKey value '%s' is not valid JSON: %s", req.ModelKey, parseErr)
8385
}
8486
schemaPath, ok := modelKey[jsonAttrModelSchemaPath].(string)
@@ -118,9 +120,32 @@ func (s *Puller) ProcessLoadModelRequest(req *mmesh.LoadModelRequest) (*mmesh.Lo
118120
if pullerErr != nil {
119121
return nil, status.Errorf(status.Code(pullerErr), "Failed to pull model from storage due to error: %s", pullerErr)
120122
}
121-
// update the request
123+
124+
// update the model path
122125
req.ModelPath = localPath
123-
req = AddModelDiskSize(req, s.Log)
126+
127+
// update the model key to add the schema path
128+
if schemaPath != "" {
129+
schemaFullPath, joinErr := util.SecureJoin(s.PullerConfig.RootModelDir, req.ModelId, modelschema.ModelSchemaFile)
130+
if joinErr != nil {
131+
return nil, fmt.Errorf("Error joining paths '%s', '%s', and '%s': %w", s.PullerConfig.RootModelDir, req.ModelId, modelschema.ModelSchemaFile, joinErr)
132+
}
133+
modelKey[jsonAttrModelSchemaPath] = schemaFullPath
134+
}
135+
136+
// update the model key to add the disk size
137+
if size, err1 := getModelDiskSize(localPath); err1 != nil {
138+
s.Log.Info("Model disk size will not be included in the LoadModelRequest due to error", "model_key", modelKey, "error", err1)
139+
} else {
140+
modelKey[jsonAttrModelKeyDiskSizeBytes] = size
141+
}
142+
143+
// rewrite the ModelKey JSON with any updates that have been made
144+
modelKeyBytes, err := json.Marshal(modelKey)
145+
if err != nil {
146+
return nil, fmt.Errorf("Error serializing ModelKey back to JSON: %w", err)
147+
}
148+
req.ModelKey = string(modelKeyBytes)
124149

125150
return req, nil
126151
}
@@ -158,19 +183,12 @@ func (s *Puller) CleanCache() {
158183
}
159184
}
160185

161-
func AddModelDiskSize(req *mmesh.LoadModelRequest, log logr.Logger) *mmesh.LoadModelRequest {
162-
var modelKey map[string]interface{}
163-
err := json.Unmarshal([]byte(req.ModelKey), &modelKey)
164-
if err != nil {
165-
log.Info("ModelDiskSize will not be included in the LoadModelRequest as LoadModelRequest.ModelKey value is not valid JSON", "size", jsonAttrModelKeyDiskSizeBytes, "model_key", req.ModelKey, "error", err)
166-
return req
167-
}
168-
186+
func getModelDiskSize(modelPath string) (int64, error) {
169187
// This walks the local filesystem and accumulates the size of the model
170188
// It would be more efficient to accumulate the size as the files are downloaded,
171189
// but this would require refactoring because the s3 download iterator does not return a size.
172190
var size int64
173-
err = filepath.Walk(req.ModelPath, func(_ string, info os.FileInfo, err error) error {
191+
err := filepath.Walk(modelPath, func(_ string, info os.FileInfo, err error) error {
174192
if err != nil {
175193
return err
176194
}
@@ -180,19 +198,10 @@ func AddModelDiskSize(req *mmesh.LoadModelRequest, log logr.Logger) *mmesh.LoadM
180198
return nil
181199
})
182200
if err != nil {
183-
log.Info("ModelDiskSize will not be included in the LoadModelRequest due to error getting the disk size", "size", jsonAttrModelKeyDiskSizeBytes, "path", req.ModelPath, "error", err)
184-
return req
201+
return size, fmt.Errorf("Error computing model's disk size: %w", err)
185202
}
186203

187-
modelKey[jsonAttrModelKeyDiskSizeBytes] = size
188-
modelKeyBytes, err := json.Marshal(modelKey)
189-
if err != nil {
190-
log.Info("ModelDiskSize will not be included in the LoadModelRequest as failure in marshalling to JSON", "size", jsonAttrModelKeyDiskSizeBytes, "model_key", modelKey, "error", err)
191-
return req
192-
}
193-
req.ModelKey = string(modelKeyBytes)
194-
195-
return req
204+
return size, nil
196205
}
197206

198207
func (p *Puller) CleanupModel(modelID string) error {

model-serving-puller/puller/puller_test.go

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ func Test_DownloadFromCOS_Success(t *testing.T) {
226226
assert.Nil(t, err)
227227
}
228228

229-
func Test_ProcessLoadModelRequest_Success(t *testing.T) {
229+
func Test_ProcessLoadModelRequest_Success_SingleFileModel(t *testing.T) {
230230
p, mockDownloader, mockCtrl := newPullerWithMock(t)
231231
defer mockCtrl.Finish()
232232

@@ -241,7 +241,7 @@ func Test_ProcessLoadModelRequest_Success(t *testing.T) {
241241
ModelId: "testmodel",
242242
ModelPath: filepath.Join(p.PullerConfig.RootModelDir, "testmodel/model.zip"),
243243
ModelType: "tensorflow",
244-
ModelKey: `{"bucket":"bucket1","disk_size_bytes":0,"storage_key":"myStorage"}`,
244+
ModelKey: `{"bucket":"bucket1","disk_size_bytes":0,"storage_key":"myStorage","storage_params":{"bucket":"bucket1"}}`,
245245
}
246246

247247
mockDownloader.EXPECT().ListObjectsUnderPrefix("bucket1", "model.zip").Return([]string{"model.zip"}, nil).Times(1)
@@ -252,6 +252,64 @@ func Test_ProcessLoadModelRequest_Success(t *testing.T) {
252252
assert.Nil(t, err)
253253
}
254254

255+
func Test_ProcessLoadModelRequest_Success_MultiFileModel(t *testing.T) {
256+
p, mockDownloader, mockCtrl := newPullerWithMock(t)
257+
defer mockCtrl.Finish()
258+
259+
request := &mmesh.LoadModelRequest{
260+
ModelId: "testmodel",
261+
ModelPath: "path/to/model",
262+
ModelType: "tensorflow",
263+
ModelKey: `{"storage_key": "myStorage", "bucket": "bucket1"}`,
264+
}
265+
266+
expectedRequestRewrite := &mmesh.LoadModelRequest{
267+
ModelId: "testmodel",
268+
ModelPath: filepath.Join(p.PullerConfig.RootModelDir, "testmodel"),
269+
ModelType: "tensorflow",
270+
ModelKey: `{"bucket":"bucket1","disk_size_bytes":0,"storage_key":"myStorage","storage_params":{"bucket":"bucket1"}}`,
271+
}
272+
273+
mockDownloader.EXPECT().ListObjectsUnderPrefix("bucket1", "path/to/model").Return([]string{"path/to/model/model.zip", "path/to/model/metadata.txt", "path/to/model/model/data"}, nil).Times(1)
274+
mockDownloader.EXPECT().DownloadWithIterator(gomock.Any(), gomock.Any()).Return(nil).Times(1)
275+
276+
returnRequest, err := p.ProcessLoadModelRequest(request)
277+
assert.Equal(t, expectedRequestRewrite, returnRequest)
278+
assert.Nil(t, err)
279+
}
280+
281+
func Test_ProcessLoadModelRequest_SuccessWithSchema(t *testing.T) {
282+
p, mockDownloader, mockCtrl := newPullerWithMock(t)
283+
defer mockCtrl.Finish()
284+
285+
request := &mmesh.LoadModelRequest{
286+
ModelId: "testmodel",
287+
ModelPath: "model.zip",
288+
ModelType: "tensorflow",
289+
ModelKey: `{"storage_key": "myStorage", "bucket": "bucket1", "schema_path": "my_schema"}`,
290+
}
291+
292+
// expect updated schema_path in ModelKey
293+
expectedSchemaPath := filepath.Join(p.PullerConfig.RootModelDir, "testmodel/_schema.json")
294+
expectedRequestRewrite := &mmesh.LoadModelRequest{
295+
ModelId: "testmodel",
296+
ModelPath: filepath.Join(p.PullerConfig.RootModelDir, "testmodel/model.zip"),
297+
ModelType: "tensorflow",
298+
ModelKey: fmt.Sprintf(`{"bucket":"bucket1","disk_size_bytes":0,"schema_path":"%s","storage_key":"myStorage","storage_params":{"bucket":"bucket1"}}`, expectedSchemaPath),
299+
}
300+
301+
// model file
302+
mockDownloader.EXPECT().ListObjectsUnderPrefix("bucket1", "model.zip").Return([]string{"model.zip"}, nil).Times(1)
303+
mockDownloader.EXPECT().DownloadWithIterator(gomock.Any(), gomock.Any()).Return(nil).Times(1)
304+
// schema
305+
mockDownloader.EXPECT().ListObjectsUnderPrefix("bucket1", "my_schema").Return([]string{"my_schema"}, nil).Times(1)
306+
mockDownloader.EXPECT().DownloadWithIterator(gomock.Any(), gomock.Any()).Return(nil).Times(1)
307+
308+
returnRequest, err := p.ProcessLoadModelRequest(request)
309+
assert.Equal(t, expectedRequestRewrite, returnRequest)
310+
assert.Nil(t, err)
311+
}
312+
255313
func Test_ProcessLoadModelRequest_SuccessWithStorageParams(t *testing.T) {
256314
p, mockDownloader, mockCtrl := newPullerWithMock(t)
257315
defer mockCtrl.Finish()
@@ -328,3 +386,23 @@ func Test_ProcessLoadModelRequest_FailMissingStorageKey(t *testing.T) {
328386
assert.Nil(t, returnRequest)
329387
assert.EqualError(t, err, expectedError)
330388
}
389+
390+
func Test_getModelDiskSize(t *testing.T) {
391+
var diskSizeTests = []struct {
392+
modelPath string
393+
expectedSize int64
394+
}{
395+
{"testModelSize/1/airbnb.model.lr.zip", 15259},
396+
{"testModelSize/1", 15259},
397+
{"testModelSize/2", 39375276},
398+
}
399+
400+
for _, tt := range diskSizeTests {
401+
t.Run("", func(t *testing.T) {
402+
fullPath := filepath.Join(RootModelDir, tt.modelPath)
403+
diskSize, err := getModelDiskSize(fullPath)
404+
assert.NoError(t, err)
405+
assert.EqualValues(t, tt.expectedSize, diskSize)
406+
})
407+
}
408+
}

model-serving-puller/server/server.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,3 @@ func (s *PullerServer) RuntimeStatus(ctx context.Context, req *mmesh.RuntimeStat
230230
func (s *PullerServer) CleanCache() {
231231
s.puller.CleanCache()
232232
}
233-
234-
func addModelDiskSize(req *mmesh.LoadModelRequest, log logr.Logger) *mmesh.LoadModelRequest {
235-
return puller.AddModelDiskSize(req, log)
236-
}

model-serving-puller/server/server_test.go

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ package server
1515

1616
import (
1717
"context"
18-
"encoding/json"
1918
"path/filepath"
2019
"testing"
2120
"time"
@@ -27,7 +26,6 @@ import (
2726
"github.com/kserve/modelmesh-runtime-adapter/internal/proto/mmesh"
2827
"github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/generated/mocks"
2928
. "github.com/kserve/modelmesh-runtime-adapter/model-serving-puller/puller"
30-
"github.com/stretchr/testify/assert"
3129

3230
"sigs.k8s.io/controller-runtime/pkg/log/zap"
3331
)
@@ -101,7 +99,7 @@ func TestLoadModel(t *testing.T) {
10199
ModelId: tt.modelID,
102100
ModelPath: filepath.Join(s.puller.PullerConfig.RootModelDir, tt.outputModelPath),
103101
ModelType: "tensorflow",
104-
ModelKey: `{"bucket":"bucket1","disk_size_bytes":0,"storage_key":"myStorage"}`,
102+
ModelKey: `{"bucket":"bucket1","disk_size_bytes":0,"storage_key":"myStorage","storage_params":{"bucket":"bucket1"}}`,
105103
}
106104

107105
// Assert s.LoadModel calls the s3 Download and then the model runtime LoadModel rpc
@@ -124,50 +122,3 @@ func TestLoadModel(t *testing.T) {
124122
})
125123
}
126124
}
127-
128-
func TestAddModelDiskSize(t *testing.T) {
129-
var diskSizeTests = []struct {
130-
modelPath string
131-
expectedSize int64
132-
}{
133-
{"testModelSize/1/airbnb.model.lr.zip", 15259},
134-
{"testModelSize/1", 15259},
135-
{"testModelSize/2", 39375276},
136-
}
137-
138-
for _, tt := range diskSizeTests {
139-
t.Run("", func(t *testing.T) {
140-
requestBefore := &mmesh.LoadModelRequest{
141-
ModelId: filepath.Base(filepath.Dir(tt.modelPath)),
142-
ModelPath: filepath.Join(RootModelDir, tt.modelPath),
143-
ModelType: "tensorflow",
144-
ModelKey: `{"storage_key": "myStorage", "bucket": "bucket1", "modelType": "tensorflow"}`,
145-
}
146-
var modelKeyBefore map[string]interface{}
147-
err := json.Unmarshal([]byte(requestBefore.ModelKey), &modelKeyBefore)
148-
if err != nil {
149-
t.Fatal("Error unmarshalling modelKeyBefore JSON", err)
150-
}
151-
assert.Equal(t, "myStorage", modelKeyBefore["storage_key"])
152-
assert.Equal(t, "bucket1", modelKeyBefore["bucket"])
153-
assert.Equal(t, "tensorflow", modelKeyBefore["modelType"])
154-
log := zap.New(zap.UseDevMode(true))
155-
requestAfter := addModelDiskSize(requestBefore, log)
156-
157-
assert.Equal(t, requestBefore.ModelId, requestAfter.ModelId)
158-
assert.Equal(t, requestBefore.ModelPath, requestAfter.ModelPath)
159-
assert.Equal(t, requestBefore.ModelType, requestAfter.ModelType)
160-
161-
var modelKeyAfter map[string]interface{}
162-
err = json.Unmarshal([]byte(requestAfter.ModelKey), &modelKeyAfter)
163-
if err != nil {
164-
t.Fatal("Error unmarshalling modelKeyAfter JSON", err)
165-
}
166-
167-
assert.Equal(t, modelKeyBefore["storage_key"], modelKeyAfter["storage_key"])
168-
assert.Equal(t, modelKeyBefore["bucket"], modelKeyAfter["bucket"])
169-
assert.Equal(t, modelKeyBefore["modelType"], modelKeyAfter["modelType"])
170-
assert.EqualValues(t, tt.expectedSize, modelKeyAfter["disk_size_bytes"])
171-
})
172-
}
173-
}

0 commit comments

Comments
 (0)