diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py
index 85cde29b3c..a810f3430c 100644
--- a/keras_hub/src/models/backbone.py
+++ b/keras_hub/src/models/backbone.py
@@ -177,14 +177,17 @@ class like `keras_hub.models.Backbone.from_preset()`, or from
             )
         return loader.load_backbone(backbone_cls, load_weights, **kwargs)
 
-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save backbone to a preset directory.
 
         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_backbone(self)
+        saver.save_backbone(self, max_shard_size=max_shard_size)
 
     def get_lora_target_names(self):
         """Returns list of layer names which are to be LoRA-fied.
diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py
index 5920776232..d273759b46 100644
--- a/keras_hub/src/models/task.py
+++ b/keras_hub/src/models/task.py
@@ -236,14 +236,17 @@ def save_task_weights(self, filepath):
             objects_to_skip=backbone_layer_ids,
         )
 
-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save task to a preset directory.
 
         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_task(self)
+        saver.save_task(self, max_shard_size=max_shard_size)
 
     @property
     def layers(self):
diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index e1e40e489a..21607ffccb 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -1,3 +1,4 @@
+import inspect
 import sys
 
 import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
         ]
     else:
         return [""]
+
+
+def sharded_weights_available():
+    """Whether sharded weights serialization is available.
+
+    Returns:
+        `True` if sharded weights are available, `False` otherwise.
+    """
+    save_weights_signature = inspect.signature(keras.saving.save_weights)
+    return "max_shard_size" in save_weights_signature.parameters
diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py
index 7fa4b3bb00..8423238b5c 100644
--- a/keras_hub/src/utils/preset_utils.py
+++ b/keras_hub/src/utils/preset_utils.py
@@ -10,6 +10,8 @@
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
+from keras_hub.src.utils.keras_utils import sharded_weights_available
+from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
 
 try:
     import kagglehub
@@ -48,6 +50,7 @@
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
 
 # HuggingFace filenames.
 README_FILE = "README.md"
@@ -647,7 +650,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
                 task.load_task_weights(task_weights)
             else:
                 jax_memory_cleanup(task.backbone)
-            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-            task.backbone.load_weights(backbone_weights)
+            self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
@@ -726,18 +728,80 @@ def _load_serialized_object(self, config, **kwargs):
         config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        """Return the sorted unique shard filenames from a sharding config."""
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        """Load backbone weights from a single file or a sharded checkpoint.
+
+        Raises:
+            RuntimeError: If the preset has no `MODEL_WEIGHTS_FILE` and the
+                installed Keras does not support sharded weights.
+        """
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            if not sharded_weights_available():
+                raise RuntimeError(
+                    "Sharded weights loading is not supported in the current "
+                    f"Keras version {keras.__version__}. "
+                    "Please update to a newer version."
+                )
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=10):
+        """Save the backbone config, metadata and weights to the preset.
+
+        Args:
+            backbone: The backbone model to save.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
+        """
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # Shard only when enabled (`None` disables it), supported by the
+        # installed Keras, and the backbone exceeds `max_shard_size`.
+        can_shard = max_shard_size is not None and sharded_weights_available()
+        if can_shard and backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
@@ -755,7 +819,7 @@ def save_audio_converter(self, converter):
     def save_image_converter(self, converter):
         self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
 
-    def save_task(self, task):
+    def save_task(self, task, max_shard_size=10):
         # Save task specific config and weights.
         self._save_serialized_object(task, TASK_CONFIG_FILE)
         if task.has_task_weights():
@@ -763,10 +827,12 @@ def save_task(self, task):
             task.save_task_weights(task_weight_path)
         # Save backbone.
         if hasattr(task.backbone, "save_to_preset"):
-            task.backbone.save_to_preset(self.preset_dir)
+            task.backbone.save_to_preset(
+                self.preset_dir, max_shard_size=max_shard_size
+            )
         else:
             # Allow saving a `keras.Model` that is not a backbone subclass.
-            self.save_backbone(task.backbone)
+            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
         # Save preprocessor.
         if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
             task.preprocessor.save_to_preset(self.preset_dir)
@@ -823,3 +889,14 @@ def _save_metadata(self, layer):
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        """Return the total size in bytes of the unique `variables`."""
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += get_tensor_size_in_bits(shape, dtype)
+        return total_memory_size / 8
diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py
index 998dcadfa9..738682a286 100644
--- a/keras_hub/src/utils/preset_utils_test.py
+++ b/keras_hub/src/utils/preset_utils_test.py
@@ -10,12 +10,55 @@
 )
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
+from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.tests.test_case import TestCase
+from keras_hub.src.utils.keras_utils import sharded_weights_available
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
 from keras_hub.src.utils.preset_utils import upload_preset
 
 
 class PresetUtilsTest(TestCase):
+    @pytest.mark.large
+    def test_sharded_weights(self):
+        if not sharded_weights_available():
+            self.skipTest("Sharded weights are not available.")
+
+        init_kwargs = {
+            "vocabulary_size": 1024,
+            "num_layers": 12,
+            "num_query_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_dim": 32,
+            "intermediate_dim": 64,
+            "head_dim": 4,
+            "sliding_window_size": 5,
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)  # ~422KB
+
+        # Save the sharded weights.
+        preset_dir = self.get_temp_dir()
+        backbone.save_to_preset(preset_dir, max_shard_size=0.0002)
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model.weights.json"))
+        )
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
+        )
+
+        # Load the sharded weights.
+        revived_backbone = GemmaBackbone.from_preset(preset_dir)
+        for v1, v2 in zip(
+            backbone.trainable_variables, revived_backbone.trainable_variables
+        ):
+            self.assertAllClose(v1, v2)
+
     @pytest.mark.large
     def test_preset_errors(self):
         with self.assertRaisesRegex(ValueError, "must be a string"):
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index a602963cf0..5588328ead 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -1,6 +1,8 @@
 import contextlib
 import functools
 import inspect
+import math
+import re
 import threading
 
 import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)
 
 
+def get_dtype_size_in_bits(dtype):
+    """Get the size of a given dtype in bits."""
+    dtype = keras.backend.standardize_dtype(dtype)
+    # If dtype is bool, return 1 immediately.
+    if dtype == "bool":
+        return 1
+    # Else, we extract the bit size from the string.
+    return int(re.sub(r"bfloat|float|uint|int", "", dtype))
+
+
+def get_tensor_size_in_bits(shape, dtype):
+    """Calculate the size given dtype and shape in bits.
+
+    Args:
+        shape: Iterable of ints representing the shape of the tensor.
+        dtype: The dtype of the tensor.
+
+    Returns:
+        The size of the tensor in bits.
+    """
+    return math.prod(shape) * get_dtype_size_in_bits(dtype)
+
+
 def any_equal(inputs, values, padding_mask):
     """Return a mask that is True anywhere `inputs` has a value in `values`.
 
@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
     Returns:
         A tensor with `inputs` shape where each position is True if it contains
         a value from any `values`. Padding mask will be applied before
-        returning."""
+        returning.
+    """
     output = ops.equal(inputs, values[0])
     for value in values[1:]:
         value_equality = ops.equal(inputs, value)