from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.Llama3VisionBackbone")
class Llama3VisionBackbone(Backbone):
    """Llama 3.2 Vision Backbone model.

    This model combines a Vision Encoder (ViT) and a Llama 3 Text Decoder
    interleaved with Gated Cross-Attention layers.

    Args:
        config: `Llama3VisionConfig` instance. Must carry a non-`None`
            `vision_encoder_config`.

    Raises:
        ValueError: If `config.vision_encoder_config` is `None`.
    """

    def __init__(self, config, **kwargs):
        # TODO(Vivek1106-04): Implement the Vision Encoder integration.
        # This will initialize the vision tower and the text backbone.

        # Fail fast: a vision backbone without a vision encoder
        # configuration is unusable.
        if config.vision_encoder_config is None:
            raise ValueError("`vision_encoder_config` must be provided.")

        super().__init__(**kwargs)
        self.config = config

    def call(self, inputs):
        # TODO(Vivek1106-04): Implement the forward pass.
        # 1. Process images through Vision Encoder.
        # 2. Process text through Embedding.
        # 3. Pass through Decoder layers with Cross-Attention.
        return inputs

    def get_config(self):
        # Keras serialization requires a plain python dict, not a custom
        # config object, so flatten the nested config into dict form here.
        config = super().get_config()
        config.update({"config": self.config.get_config()})
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the nested `Llama3VisionConfig` from its dict form.
        # Imported lazily to avoid a circular import with the config module.
        from keras_hub.src.models.llama3.llama3_vision_config import (
            Llama3VisionConfig,
        )

        # Copy before popping: Keras may reuse the config dict it passes in,
        # and mutating it in place would corrupt a second deserialization.
        config = dict(config)
        config_data = config.pop("config")
        vision_config = Llama3VisionConfig(**config_data)
        return cls(config=vision_config, **config)
+ """ + + def __init__( + self, + vision_encoder_config=None, + vision_projection_dim=4096, + cross_attention_layers=None, + **kwargs, + ): + super().__init__(**kwargs) + self.vision_encoder_config = vision_encoder_config + self.vision_projection_dim = vision_projection_dim + # Default to empty list if generic Llama3 is initialized without vision + self.cross_attention_layers = cross_attention_layers or [] + + def get_config(self): + config = super().get_config() + config.update( + { + "vision_encoder_config": self.vision_encoder_config, + "vision_projection_dim": self.vision_projection_dim, + "cross_attention_layers": self.cross_attention_layers, + } + ) + return config \ No newline at end of file From 20eaf561ba01c4467e66c8d59c8345d2f7ae5fde Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Tue, 16 Dec 2025 10:16:32 +0530 Subject: [PATCH 2/3] style: Apply linting and formatting fixes --- .../models/llama3/llama3_vision_backbone.py | 24 +++++++++++++++---- .../src/models/llama3/llama3_vision_config.py | 11 +++++++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/keras_hub/src/models/llama3/llama3_vision_backbone.py b/keras_hub/src/models/llama3/llama3_vision_backbone.py index 03446c9ce9..8851999834 100644 --- a/keras_hub/src/models/llama3/llama3_vision_backbone.py +++ b/keras_hub/src/models/llama3/llama3_vision_backbone.py @@ -1,4 +1,5 @@ import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone @@ -16,22 +17,35 @@ class Llama3VisionBackbone(Backbone): """ def __init__(self, config, **kwargs): - # TODO(YourGitHubUsername): Implement the Vision Encoder integration. + # TODO(Vivek1106-04): Implement the Vision Encoder integration. # This will initialize the vision tower and the text backbone. 
- + # Placeholder for input validation if config.vision_encoder_config is None: raise ValueError("`vision_encoder_config` must be provided.") super().__init__(**kwargs) self.config = config - + def call(self, inputs): - # TODO(YourGitHubUsername): Implement the forward pass. + # TODO(Vivek1106-04): Implement the forward pass. # 1. Process images through Vision Encoder. # 2. Process text through Embedding. # 3. Pass through Decoder layers with Cross-Attention. return inputs def get_config(self): - return {"config": self.config} \ No newline at end of file + # serialization_lib requires a python dict, not a custom object + config = super().get_config() + config.update({"config": self.config.get_config()}) + return config + + @classmethod + def from_config(cls, config): + # We must manually deserialize the nested config object + from keras_hub.src.models.llama3.llama3_vision_config import \ + Llama3VisionConfig + + config_data = config.pop("config") + vision_config = Llama3VisionConfig(**config_data) + return cls(config=vision_config, **config) diff --git a/keras_hub/src/models/llama3/llama3_vision_config.py b/keras_hub/src/models/llama3/llama3_vision_config.py index 38252604ce..cf505811ad 100644 --- a/keras_hub/src/models/llama3/llama3_vision_config.py +++ b/keras_hub/src/models/llama3/llama3_vision_config.py @@ -1,4 +1,5 @@ import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.llama3.llama3_backbone import Llama3BackboneConfig @@ -37,11 +38,17 @@ def __init__( def get_config(self): config = super().get_config() + + # specific fix for "vision_encoder_config is not JSON serializable" + vision_config_val = self.vision_encoder_config + if hasattr(vision_config_val, "get_config"): + vision_config_val = vision_config_val.get_config() + config.update( { - "vision_encoder_config": self.vision_encoder_config, + "vision_encoder_config": vision_config_val, "vision_projection_dim": self.vision_projection_dim, "cross_attention_layers": 
self.cross_attention_layers, } ) - return config \ No newline at end of file + return config From a211052a142de57fb7ccb6f7c5ee631817099b83 Mon Sep 17 00:00:00 2001 From: Vivek1106-04 Date: Tue, 16 Dec 2025 15:23:06 +0530 Subject: [PATCH 3/3] style: Apply linting and formatting fixes --- keras_hub/src/models/llama3/llama3_vision_backbone.py | 8 +++----- keras_hub/src/models/llama3/llama3_vision_config.py | 10 ++++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/keras_hub/src/models/llama3/llama3_vision_backbone.py b/keras_hub/src/models/llama3/llama3_vision_backbone.py index 8851999834..de2e847a1a 100644 --- a/keras_hub/src/models/llama3/llama3_vision_backbone.py +++ b/keras_hub/src/models/llama3/llama3_vision_backbone.py @@ -1,8 +1,5 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone @keras_hub_export("keras_hub.models.Llama3VisionBackbone") @@ -43,8 +40,9 @@ def get_config(self): @classmethod def from_config(cls, config): # We must manually deserialize the nested config object - from keras_hub.src.models.llama3.llama3_vision_config import \ - Llama3VisionConfig + from keras_hub.src.models.llama3.llama3_vision_config import ( + Llama3VisionConfig, + ) config_data = config.pop("config") vision_config = Llama3VisionConfig(**config_data) diff --git a/keras_hub/src/models/llama3/llama3_vision_config.py b/keras_hub/src/models/llama3/llama3_vision_config.py index cf505811ad..075494560b 100644 --- a/keras_hub/src/models/llama3/llama3_vision_config.py +++ b/keras_hub/src/models/llama3/llama3_vision_config.py @@ -1,14 +1,12 @@ -import keras - from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.llama3.llama3_backbone import Llama3BackboneConfig +from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone @keras_hub_export("keras_hub.models.Llama3VisionConfig") -class 
Llama3VisionConfig(Llama3BackboneConfig): +class Llama3VisionConfig(Llama3Backbone): """Configuration for the Llama 3.2 Vision Backbone. - This class extends `Llama3BackboneConfig` to include parameters for the + This class extends `Llama3Backbone` to include parameters for the vision encoder and the gated cross-attention mechanism used in Llama 3.2 multimodal models (11B and 90B). @@ -20,7 +18,7 @@ class Llama3VisionConfig(Llama3BackboneConfig): cross_attention_layers: list of int. The indices of the transformer layers that should include gated cross-attention blocks. For Llama 3.2 11B, this is typically every 4th layer. - **kwargs: Arguments for the parent `Llama3BackboneConfig`. + **kwargs: Arguments for the parent `Llama3Backbone`. """ def __init__(