49 changes: 49 additions & 0 deletions keras_hub/src/models/llama3/llama3_vision_backbone.py
@@ -0,0 +1,49 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.Llama3VisionBackbone")
class Llama3VisionBackbone(Backbone):
    """Llama 3.2 Vision Backbone model.

    This model combines a Vision Encoder (ViT) and a Llama 3 Text Decoder
    interleaved with Gated Cross-Attention layers.

    Args:
        config: `Llama3VisionConfig` instance.
    """

    def __init__(self, config, **kwargs):
        # TODO(Vivek1106-04): Implement the Vision Encoder integration.
        # This will initialize the vision tower and the text backbone.

        # Placeholder for input validation.
        if config.vision_encoder_config is None:
            raise ValueError("`vision_encoder_config` must be provided.")

        super().__init__(**kwargs)
        self.config = config

    def call(self, inputs):
        # TODO(Vivek1106-04): Implement the forward pass.
        # 1. Process images through the Vision Encoder.
        # 2. Process text through the token embedding.
        # 3. Pass through the Decoder layers with Cross-Attention.
        return inputs

    def get_config(self):
        # Keras serialization requires a plain Python dict, not a custom
        # config object, so store the nested config as a dict.
        config = super().get_config()
        config.update({"config": self.config.get_config()})
        return config

    @classmethod
    def from_config(cls, config):
        # Manually deserialize the nested config object.
        from keras_hub.src.models.llama3.llama3_vision_config import (
            Llama3VisionConfig,
        )

        config = dict(config)  # Avoid mutating the caller's dict.
        config_data = config.pop("config")
        vision_config = Llama3VisionConfig(**config_data)
        return cls(config=vision_config, **config)
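
For reviewers, a minimal usage sketch of the API this file introduces. It assumes the `keras_hub.models.Llama3VisionConfig` and `Llama3VisionBackbone` exports added in this PR, plus the existing `Llama3Backbone` constructor arguments (`vocabulary_size`, `num_layers`, `num_query_heads`, `num_key_value_heads`, `hidden_dim`, `intermediate_dim`); the toy sizes, the `vision_encoder_config` keys, and the cross-attention indices are placeholders, not real Llama 3.2 values.

```python
import keras_hub

# Toy-sized config: Llama3VisionConfig subclasses Llama3Backbone, so it takes
# the usual text-backbone arguments plus the vision-specific ones added here.
config = keras_hub.models.Llama3VisionConfig(
    vision_encoder_config={"image_size": 224, "patch_size": 14},  # placeholder keys
    vision_projection_dim=64,
    cross_attention_layers=[3, 7],  # illustrative indices only
    vocabulary_size=1000,
    num_layers=8,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
)

# The backbone currently only validates and stores the config; the vision
# tower and cross-attention wiring are still TODO in this PR.
backbone = keras_hub.models.Llama3VisionBackbone(config=config)
```
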
52 changes: 52 additions & 0 deletions keras_hub/src/models/llama3/llama3_vision_config.py
@@ -0,0 +1,52 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone


@keras_hub_export("keras_hub.models.Llama3VisionConfig")
class Llama3VisionConfig(Llama3Backbone):
    """Configuration for the Llama 3.2 Vision Backbone.

    This class extends `Llama3Backbone` to include parameters for the
    vision encoder and the gated cross-attention mechanism used in the
    Llama 3.2 multimodal models (11B and 90B).

    Args:
        vision_encoder_config: dict or config instance. The configuration
            for the vision encoder (ViT-like architecture).
        vision_projection_dim: int. The dimension of the projection layer
            that maps vision features to the text embedding space.
        cross_attention_layers: list of int. The indices of the transformer
            layers that should include gated cross-attention blocks.
            For Llama 3.2 11B, this is typically every 4th layer.
        **kwargs: Arguments for the parent `Llama3Backbone`.
    """

    def __init__(
        self,
        vision_encoder_config=None,
        vision_projection_dim=4096,
        cross_attention_layers=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_encoder_config = vision_encoder_config
        self.vision_projection_dim = vision_projection_dim
        # Default to an empty list if a generic Llama 3 model is initialized
        # without vision layers.
        self.cross_attention_layers = cross_attention_layers or []

    def get_config(self):
        config = super().get_config()

        # Convert a nested config object to a plain dict so the result stays
        # JSON serializable (fixes "vision_encoder_config is not JSON
        # serializable").
        vision_config_val = self.vision_encoder_config
        if hasattr(vision_config_val, "get_config"):
            vision_config_val = vision_config_val.get_config()

        config.update(
            {
                "vision_encoder_config": vision_config_val,
                "vision_projection_dim": self.vision_projection_dim,
                "cross_attention_layers": self.cross_attention_layers,
            }
        )
        return config
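
A short sketch of the serialization round trip these `get_config` / `from_config` overrides are meant to support, reusing the hypothetical toy `backbone` from the sketch above; this mirrors what Keras saving does when the model is serialized and reloaded.

```python
# The nested config is flattened to a plain dict, so the result is JSON
# serializable.
serialized = backbone.get_config()
assert isinstance(serialized["config"], dict)
assert "cross_attention_layers" in serialized["config"]

# from_config rebuilds the nested Llama3VisionConfig before re-instantiating
# the backbone.
restored = keras_hub.models.Llama3VisionBackbone.from_config(serialized)
assert restored.config.vision_projection_dim == 64
```
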