diff --git a/keras_hub/src/models/llama3/llama3_vision_backbone.py b/keras_hub/src/models/llama3/llama3_vision_backbone.py
new file mode 100644
index 0000000000..de2e847a1a
--- /dev/null
+++ b/keras_hub/src/models/llama3/llama3_vision_backbone.py
@@ -0,0 +1,49 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+@keras_hub_export("keras_hub.models.Llama3VisionBackbone")
+class Llama3VisionBackbone(Backbone):
+    """Llama 3.2 Vision Backbone model.
+
+    This model combines a Vision Encoder (ViT) with a Llama 3 Text
+    Decoder interleaved with Gated Cross-Attention layers.
+
+    Args:
+        config: `Llama3VisionConfig` instance.
+    """
+
+    def __init__(self, config, **kwargs):
+        # TODO(Vivek1106-04): Implement the Vision Encoder integration.
+        # This will initialize the vision tower and the text backbone.
+
+        # Placeholder for input validation.
+        if config.vision_encoder_config is None:
+            raise ValueError("`vision_encoder_config` must be provided.")
+
+        super().__init__(**kwargs)
+        self.config = config
+
+    def call(self, inputs):
+        # TODO(Vivek1106-04): Implement the forward pass.
+        # 1. Process images through the Vision Encoder.
+        # 2. Process text through the token embedding.
+        # 3. Pass through Decoder layers with Cross-Attention.
+        return inputs
+
+    def get_config(self):
+        # serialization_lib requires a Python dict, not a custom object.
+        config = super().get_config()
+        config.update({"config": self.config.get_config()})
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        # Manually deserialize the nested config object.
+        from keras_hub.src.models.llama3.llama3_vision_config import (
+            Llama3VisionConfig,
+        )
+
+        config_data = config.pop("config")
+        vision_config = Llama3VisionConfig(**config_data)
+        return cls(config=vision_config, **config)
diff --git a/keras_hub/src/models/llama3/llama3_vision_config.py b/keras_hub/src/models/llama3/llama3_vision_config.py
new file mode 100644
index 0000000000..075494560b
--- /dev/null
+++ b/keras_hub/src/models/llama3/llama3_vision_config.py
@@ -0,0 +1,52 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
+
+
+@keras_hub_export("keras_hub.models.Llama3VisionConfig")
+class Llama3VisionConfig(Llama3Backbone):
+    """Configuration for the Llama 3.2 Vision Backbone.
+
+    This class extends `Llama3Backbone` to include parameters for the
+    vision encoder and the gated cross-attention mechanism used in
+    Llama 3.2 multimodal models (11B and 90B).
+
+    Args:
+        vision_encoder_config: dict or config instance. The configuration
+            for the vision encoder (ViT-like architecture).
+        vision_projection_dim: int. The dimension of the projection layer
+            that maps vision features to the text embedding space.
+        cross_attention_layers: list of int. The indices of the transformer
+            layers that should include gated cross-attention blocks.
+            For Llama 3.2 11B, this is typically every 4th layer.
+        **kwargs: Arguments for the parent `Llama3Backbone`.
+    """
+
+    def __init__(
+        self,
+        vision_encoder_config=None,
+        vision_projection_dim=4096,
+        cross_attention_layers=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vision_encoder_config = vision_encoder_config
+        self.vision_projection_dim = vision_projection_dim
+        # Default to an empty list if no cross-attention layers are given.
+        self.cross_attention_layers = cross_attention_layers or []
+
+    def get_config(self):
+        config = super().get_config()
+
+        # Serialize nested configs so get_config() stays JSON-serializable.
+        vision_config_val = self.vision_encoder_config
+        if hasattr(vision_config_val, "get_config"):
+            vision_config_val = vision_config_val.get_config()
+
+        config.update(
+            {
+                "vision_encoder_config": vision_config_val,
+                "vision_projection_dim": self.vision_projection_dim,
+                "cross_attention_layers": self.cross_attention_layers,
+            }
+        )
+        return config