
Commit 26b7e51

vertex-sdk-bot authored and copybara-github committed

feat: Enable Vertex Model Garden Managed OSS Fine Tuning.

PiperOrigin-RevId: 834357700

1 parent ae1c2f2 commit 26b7e51

4 files changed: +139 -45 lines changed

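Taken together, the diffs below expose Model Garden managed OSS models through the supervised fine-tuning (SFT) surface: a public `SourceModel` class plus new `tuning_mode`, `learning_rate`, and `output_uri` arguments on `sft.train`. A rough end-to-end sketch of how this is meant to be used, assembled from the `SourceModel` docstring and the new `train()` signature in this commit; the project, region, bucket names, model identifier, and learning rate are placeholders, not values from the change itself:

```python
import time

import vertexai
from google.cloud import aiplatform
from vertexai.tuning import SourceModel
from vertexai.tuning import sft

vertexai.init(project="my-project", location="us-central1")  # placeholder project/region

# A Model Garden OSS base model, optionally paired with custom weights in GCS.
source_model = SourceModel(
    base_model="<publisher>/<model>@<version>",          # placeholder model identifier
    custom_base_model="gs://my-bucket/custom-weights",   # optional custom base weights
)

sft_tuning_job = sft.train(
    source_model=source_model,
    train_dataset="gs://my-bucket/train.jsonl",
    validation_dataset="gs://my-bucket/validation.jsonl",
    tuning_mode="FULL",   # new: "FULL" or "PEFT_ADAPTER"; OSS models only
    learning_rate=2e-5,   # new: OSS models only; mutually exclusive with learning_rate_multiplier
    epochs=4,
    tuned_model_display_name="my-tuned-model",
    output_uri="gs://my-bucket/tuned-model",  # new: where tuned OSS model artifacts are written
)

# Poll until the job finishes, then load the tuned model.
while not sft_tuning_job.has_ended:
    time.sleep(60)
    sft_tuning_job.refresh()

tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
```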
tests/unit/vertexai/tuning/test_tuning.py

Lines changed: 41 additions & 33 deletions
@@ -22,10 +22,12 @@
 import importlib
 from typing import Dict, Iterable
 from unittest import mock
+from unittest.mock import patch
 import uuid
 
 from google import auth
 from google.auth import credentials as auth_credentials
+from google.cloud import storage
 from google.cloud import aiplatform
 import vertexai
 from google.cloud.aiplatform import compat
@@ -34,26 +36,16 @@
 from google.cloud.aiplatform.metadata import experiment_resources
 from google.cloud.aiplatform_v1beta1.services import gen_ai_tuning_service
 from google.cloud.aiplatform_v1beta1.types import job_state
-from google.cloud.aiplatform_v1beta1.types import (
-    tuning_job as gca_tuning_job,
-)
+from google.cloud.aiplatform_v1beta1.types import tuning_job as gca_tuning_job
 from vertexai.preview import tuning
-from vertexai.preview.tuning import (
-    sft as preview_supervised_tuning,
-)
-from vertexai.preview.tuning._tuning import SourceModel
+from vertexai.preview.tuning import sft as preview_supervised_tuning
+from vertexai.preview.tuning._tuning import SourceModel as PreviewSourceModel
+from vertexai.preview.tuning._tuning import TuningJob as PreviewTuningJob
 from vertexai.tuning import _distillation
 from vertexai.tuning import sft as supervised_tuning
-from google.cloud import storage
-from vertexai.preview.tuning._tuning import (
-    TuningJob as PreviewTuningJob,
-)
-
-
+from vertexai.tuning._tuning import SourceModel
 import pytest
 
-from unittest.mock import patch
-
 from google.rpc import status_pb2
 
 
@@ -191,18 +183,18 @@ def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
 
     @mock.patch.object(
-        target=PreviewTuningJob,
+        target=tuning.TuningJob,
        attribute="client_class",
        new=MockTuningJobClientWithOverride,
    )
    @pytest.mark.parametrize(
        "supervised_tuning",
-        [preview_supervised_tuning],
+        [preview_supervised_tuning, supervised_tuning],
    )
    def test_genai_tuning_service_supervised_tuning_tune_model(
        self, supervised_tuning: supervised_tuning
    ):
-        sft_tuning_job = supervised_tuning.preview_train(
+        sft_tuning_job = supervised_tuning.train(
            source_model="gemini-1.0-pro-001",
            train_dataset="gs://some-bucket/some_dataset.jsonl",
            # Optional:
@@ -237,42 +229,42 @@ def test_genai_tuning_service_supervised_tuning_tune_model(
        assert sft_tuning_job.tuned_model_endpoint_name
 
    @mock.patch.object(
-        target=PreviewTuningJob,
+        target=tuning.TuningJob,
        attribute="client_class",
        new=MockTuningJobClientWithOverride,
    )
    @pytest.mark.parametrize(
        "supervised_tuning",
-        [preview_supervised_tuning],
+        [supervised_tuning],
    )
    def test_genai_tuning_service_encryption_spec(
        self, supervised_tuning: supervised_tuning
    ):
        """Test that the global encryption spec propagates to the tuning job."""
        vertexai.init(encryption_spec_key_name="test-key")
 
-        sft_tuning_job = supervised_tuning.preview_train(
+        sft_tuning_job = supervised_tuning.train(
            source_model="gemini-1.0-pro-001",
            train_dataset="gs://some-bucket/some_dataset.jsonl",
        )
        assert sft_tuning_job.encryption_spec.kms_key_name == "test-key"
 
    @mock.patch.object(
-        target=PreviewTuningJob,
+        target=tuning.TuningJob,
        attribute="client_class",
        new=MockTuningJobClientWithOverride,
    )
    @pytest.mark.parametrize(
        "supervised_tuning",
-        [preview_supervised_tuning],
+        [supervised_tuning],
    )
    def test_genai_tuning_service_service_account(
        self, supervised_tuning: supervised_tuning
    ):
        """Test that the service account propagates to the tuning job."""
        vertexai.init(service_account="[email protected]")
 
-        sft_tuning_job = supervised_tuning.preview_train(
+        sft_tuning_job = supervised_tuning.train(
            source_model="gemini-1.0-pro-002",
            train_dataset="gs://some-bucket/some_dataset.jsonl",
        )
@@ -331,19 +323,35 @@ def test_genai_tuning_service_distillation_distill_model(self):
        attribute="client_class",
        new=MockTuningJobClientWithOverride,
    )
+    @mock.patch.object(
+        target=tuning.TuningJob,
+        attribute="client_class",
+        new=MockTuningJobClientWithOverride,
+    )
    @pytest.mark.parametrize(
-        "supervised_tuning",
-        [preview_supervised_tuning],
+        "sft_train_method, source_model",
+        [
+            (
+                preview_supervised_tuning.preview_train,
+                PreviewSourceModel(
+                    base_model="meta/[email protected]",
+                    custom_base_model="gs://test-bucket/custom-weights",
+                ),
+            ),
+            (
+                supervised_tuning.train,
+                SourceModel(
+                    base_model="meta/[email protected]",
+                    custom_base_model="gs://test-bucket/custom-weights",
+                ),
+            ),
+        ],
    )
    def test_create_tuning_job_success(
-        self, supervised_tuning: preview_supervised_tuning
+        self, sft_train_method: supervised_tuning.train, source_model: SourceModel
    ):
-        model = SourceModel(
-            base_model="meta/[email protected]",
-            custom_base_model="gs://test-bucket/custom-weights",
-        )
-        sft_tuning_job = supervised_tuning.preview_train(
-            source_model=model,
+        sft_tuning_job = sft_train_method(
+            source_model=source_model,
            epochs=1,
            train_dataset="gs://test-bucket/test_train_dataset/",
            validation_dataset="gs://test-bucket/test_validation_dataset/",

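The main pattern in the updated tests is that a single test body is parametrized over both the preview and GA entry points rather than being written twice. A minimal, SDK-free illustration of that pytest pattern (the stand-in functions below are hypothetical and only mimic the shape of the real train methods):

```python
import pytest


def preview_train(x: int) -> int:
    """Hypothetical stand-in for preview_supervised_tuning.preview_train."""
    return x + 1


def train(x: int) -> int:
    """Hypothetical stand-in for supervised_tuning.train."""
    return x + 1


# One test body runs once per parametrized entry point, mirroring how
# test_create_tuning_job_success now covers both preview_train and train.
@pytest.mark.parametrize("train_method", [preview_train, train])
def test_both_entry_points(train_method):
    assert train_method(1) == 2
```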
vertexai/tuning/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,8 +16,10 @@
 
 # We just want to re-export certain classes
 # pylint: disable=g-multiple-import,g-importing-member
+from vertexai.tuning._tuning import SourceModel
 from vertexai.tuning._tuning import TuningJob
 
 __all__ = [
+    "SourceModel",
     "TuningJob",
 ]

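With the re-export above, `SourceModel` becomes importable from the public `vertexai.tuning` package instead of the private `_tuning` module. A quick sanity check, assuming the SDK at this commit is installed:

```python
import vertexai.tuning as tuning
from vertexai.tuning import _tuning

# The re-exported public name is the same class as the private definition.
assert tuning.SourceModel is _tuning.SourceModel
```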
vertexai/tuning/_supervised_tuning.py

Lines changed: 45 additions & 4 deletions
@@ -21,19 +21,25 @@
     tuning_job as gca_tuning_job_types,
 )
 from vertexai import generative_models
-from vertexai.tuning import _tuning
+from vertexai.tuning import (
+    SourceModel,
+    TuningJob,
+)
 
 
 def train(
     *,
-    source_model: Union[str, generative_models.GenerativeModel],
+    source_model: Union[str, generative_models.GenerativeModel, SourceModel],
     train_dataset: Union[str, datasets.MultimodalDataset],
     validation_dataset: Optional[Union[str, datasets.MultimodalDataset]] = None,
     tuned_model_display_name: Optional[str] = None,
+    tuning_mode: Optional[Literal["FULL", "PEFT_ADAPTER"]] = None,
     epochs: Optional[int] = None,
+    learning_rate: Optional[float] = None,
     learning_rate_multiplier: Optional[float] = None,
     adapter_size: Optional[Literal[1, 4, 8, 16, 32]] = None,
     labels: Optional[Dict[str, str]] = None,
+    output_uri: Optional[str] = None,
 ) -> "SupervisedTuningJob":
     """Tunes a model using supervised training.
 
@@ -44,14 +50,41 @@ def train(
         tuned_model_display_name: The display name of the
             [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to
             128 characters long and can consist of any UTF-8 characters.
+        tuning_mode: Tuning mode for this tuning job. Can only be used with OSS
+            models.
         epochs: Number of training epoches for this tuning job.
-        learning_rate_multiplier: Learning rate multiplier for tuning.
+        learning_rate: Learning rate for tuning. Can only be used with OSS
+            models. Mutually exclusive with `learning_rate_multiplier`.
+        learning_rate_multiplier: Learning rate multiplier for tuning. Can only
+            be used with 1P models. Mutually exclusive with `learning_rate`.
         adapter_size: Adapter size for tuning.
         labels: User-defined metadata to be associated with trained models
+        output_uri: The Google Cloud Storage URI to write the tuned model to.
+            Can only be used with OSS models.
 
     Returns:
         A `TuningJob` object.
     """
+    if tuning_mode is None:
+        tuning_mode_value = None
+    elif tuning_mode == "FULL":
+        tuning_mode_value = (
+            gca_tuning_job_types.SupervisedTuningSpec.TuningMode.TUNING_MODE_FULL
+        )
+    elif tuning_mode == "PEFT_ADAPTER":
+        tuning_mode_value = (
+            gca_tuning_job_types.SupervisedTuningSpec.TuningMode.TUNING_MODE_PEFT_ADAPTER
+        )
+    else:
+        raise ValueError(
+            f"Unsupported tuning mode: {tuning_mode}. The supported tuning modes are [FULL, PEFT_ADAPTER]"
+        )
+
+    if learning_rate and learning_rate_multiplier:
+        raise ValueError(
+            "Only one of `learning_rate` and `learning_rate_multiplier` can be set."
+        )
+
     if adapter_size is None:
         adapter_size_value = None
     elif adapter_size == 1:
@@ -83,10 +116,12 @@
     if isinstance(validation_dataset, datasets.MultimodalDataset):
         validation_dataset = validation_dataset.resource_name
     supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec(
+        tuning_mode=tuning_mode_value,
         training_dataset_uri=train_dataset,
         validation_dataset_uri=validation_dataset,
         hyper_parameters=gca_tuning_job_types.SupervisedHyperParameters(
             epoch_count=epochs,
+            learning_rate=learning_rate,
             learning_rate_multiplier=learning_rate_multiplier,
             adapter_size=adapter_size_value,
         ),
@@ -95,20 +130,26 @@
     if isinstance(source_model, generative_models.GenerativeModel):
         source_model = source_model._prediction_resource_name.rpartition("/")[-1]
 
+    if labels is None:
+        labels = {}
+    if "mg-source" not in labels and output_uri:
+        labels["mg-source"] = "sdk"
+
     supervised_tuning_job = (
         SupervisedTuningJob._create(  # pylint: disable=protected-access
             base_model=source_model,
             tuning_spec=supervised_tuning_spec,
             tuned_model_display_name=tuned_model_display_name,
             labels=labels,
+            output_uri=output_uri,
         )
     )
     _ipython_utils.display_model_tuning_button(supervised_tuning_job)
 
     return supervised_tuning_job
 
 
-class SupervisedTuningJob(_tuning.TuningJob):
+class SupervisedTuningJob(TuningJob):
     def __init__(self, tuning_job_name: str):
         super().__init__(tuning_job_name=tuning_job_name)
         _ipython_utils.display_model_tuning_button(self)

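The new argument handling in `train()` boils down to two checks: map the `tuning_mode` string onto the proto enum, and reject setting both `learning_rate` and `learning_rate_multiplier`. A standalone sketch of that logic; the `TuningMode` enum below is a stand-in for `gca_tuning_job_types.SupervisedTuningSpec.TuningMode`, not the real proto type:

```python
import enum
from typing import Optional


class TuningMode(enum.Enum):
    """Stand-in for gca_tuning_job_types.SupervisedTuningSpec.TuningMode."""
    TUNING_MODE_FULL = 1
    TUNING_MODE_PEFT_ADAPTER = 2


def resolve_tuning_args(
    tuning_mode: Optional[str],
    learning_rate: Optional[float],
    learning_rate_multiplier: Optional[float],
) -> Optional[TuningMode]:
    """Mirrors the validation added to sft.train() in this commit."""
    if learning_rate and learning_rate_multiplier:
        raise ValueError(
            "Only one of `learning_rate` and `learning_rate_multiplier` can be set."
        )
    if tuning_mode is None:
        return None
    if tuning_mode == "FULL":
        return TuningMode.TUNING_MODE_FULL
    if tuning_mode == "PEFT_ADAPTER":
        return TuningMode.TUNING_MODE_PEFT_ADAPTER
    raise ValueError(
        f"Unsupported tuning mode: {tuning_mode}. "
        "The supported tuning modes are [FULL, PEFT_ADAPTER]"
    )


# Example: full fine-tuning of an OSS model with an explicit learning rate.
assert resolve_tuning_args("FULL", 2e-5, None) is TuningMode.TUNING_MODE_FULL
```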
vertexai/tuning/_tuning.py

Lines changed: 51 additions & 8 deletions
@@ -43,6 +43,42 @@
 _LOGGER = aiplatform_base.Logger(__name__)
 
 
+class SourceModel:
+    r"""A model that is used in managed OSS supervised tuning.
+
+    Usage:
+    ```
+    model = SourceModel(
+        base_model="meta/[email protected]",
+        custom_base_model="gs://user-bucket/custom-weights",
+    )
+    sft_tuning_job = sft.train(
+        source_model=model,
+        train_dataset="gs://my-bucket/train.jsonl",
+        validation_dataset="gs://my-bucket/validation.jsonl",
+        epochs=4,
+        tuned_model_display_name="my-tuned-model",
+        output_uri="gs://user-bucket/tuned-model"
+    )
+
+    while not sft_tuning_job.has_ended:
+        time.sleep(60)
+        sft_tuning_job.refresh()
+
+    tuned_model = aiplatform.Model(sft_tuning_job.tuned_model_name)
+    ```
+    """
+
+    def __init__(
+        self,
+        base_model: str,
+        custom_base_model: str = "",
+    ):
+        r"""Initializes SourceModel."""
+        self.base_model = base_model
+        self.custom_base_model = custom_base_model
+
+
 class TuningJobClientWithOverride(aiplatform_utils.ClientWithOverride):
     _is_temporary = True
     _default_version = compat.V1BETA1
@@ -133,7 +169,7 @@ def tuning_data_statistics(self) -> gca_tuning_job_types.TuningDataStats:
     def _create(
         cls,
         *,
-        base_model: str,
+        base_model: Union[str, SourceModel],
         tuning_spec: Union[
             gca_tuning_job_types.SupervisedTuningSpec,
             gca_tuning_job_types.DistillationSpec,
@@ -144,15 +180,13 @@ def _create(
         project: Optional[str] = None,
         location: Optional[str] = None,
         credentials: Optional[auth_credentials.Credentials] = None,
+        output_uri: Optional[str] = None,
     ) -> "TuningJob":
         r"""Submits TuningJob.
 
         Args:
-            base_model (str):
-                Model name for tuning, e.g., "gemini-1.0-pro"
-                or "gemini-1.0-pro-001".
-
-                This field is a member of `oneof`_ ``source_model``.
+            base_model: Model for tuning.
+                Supported types: str, SourceModel.
             tuning_spec: Tuning Spec for Fine Tuning.
                 Supported types: SupervisedTuningSpec, DistillationSpec.
             tuned_model_display_name: The display name of the
@@ -179,6 +213,7 @@
                 Overrides location set in aiplatform.init.
             credentials: Custom credentials to use to call tuning job service.
                 Overrides credentials set in aiplatform.init.
+            output_uri: The Google Cloud Storage location to write the artifacts. This is only used for OSS models.
 
         Returns:
             Submitted TuningJob.
@@ -192,17 +227,25 @@
         tuned_model_display_name = cls._generate_display_name()
 
         gca_tuning_job = gca_tuning_job_types.TuningJob(
-            base_model=base_model,
             tuned_model_display_name=tuned_model_display_name,
             description=description,
             labels=labels,
-            # The tuning_spec one_of is set later
+            # The tuning_spec one_of is set later.
+            output_uri=output_uri,
         )
 
         if isinstance(tuning_spec, gca_tuning_job_types.SupervisedTuningSpec):
             gca_tuning_job.supervised_tuning_spec = tuning_spec
+            if isinstance(base_model, SourceModel):
+                gca_tuning_job.base_model = base_model.base_model
+                gca_tuning_job.custom_base_model = base_model.custom_base_model
+            else:
+                gca_tuning_job.base_model = base_model
         elif isinstance(tuning_spec, gca_tuning_job_types.DistillationSpec):
             gca_tuning_job.distillation_spec = tuning_spec
+            if isinstance(base_model, SourceModel):
+                raise RuntimeError("Distillation is not supported for custom models.")
+            gca_tuning_job.base_model = base_model
         else:
             raise RuntimeError(f"Unsupported tuning_spec kind: {tuning_spec}")

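The `_create()` change routes `base_model` onto the `TuningJob` message differently depending on its type: a `SourceModel` populates both `base_model` and `custom_base_model` (supervised tuning only), while distillation still requires a plain string. A simplified, SDK-free sketch of that dispatch; `TuningJobMessage` below is a hypothetical dataclass standing in for the real proto message:

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class SourceModel:
    base_model: str
    custom_base_model: str = ""


@dataclass
class TuningJobMessage:
    """Stand-in for gca_tuning_job_types.TuningJob; only the fields used here."""
    base_model: str = ""
    custom_base_model: str = ""


def set_base_model(
    job: TuningJobMessage,
    base_model: Union[str, SourceModel],
    is_distillation: bool = False,
) -> None:
    """Mirrors how _create() assigns the source-model fields in this commit."""
    if is_distillation:
        if isinstance(base_model, SourceModel):
            raise RuntimeError("Distillation is not supported for custom models.")
        job.base_model = base_model
        return
    if isinstance(base_model, SourceModel):
        job.base_model = base_model.base_model
        job.custom_base_model = base_model.custom_base_model
    else:
        job.base_model = base_model


# Placeholder model ID and bucket, for illustration only.
job = TuningJobMessage()
set_base_model(job, SourceModel("some/oss-model@001", "gs://bucket/weights"))
assert job.custom_base_model == "gs://bucket/weights"
```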