37 changes: 37 additions & 0 deletions keras_hub/src/models/llama3/llama3_vision_backbone.py
@@ -0,0 +1,37 @@
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone


@keras_hub_export("keras_hub.models.Llama3VisionBackbone")
class Llama3VisionBackbone(Backbone):
"""Llama 3.2 Vision Backbone model.

This model combines a Vision Encoder (ViT) and a Llama 3 Text Decoder
interleaved with Gated Cross-Attention layers.

Args:
config: `Llama3VisionConfig` instance.
"""

def __init__(self, config, **kwargs):
# TODO(YourGitHubUsername): Implement the Vision Encoder integration.
# This will initialize the vision tower and the text backbone.

# Placeholder for input validation
if config.vision_encoder_config is None:
raise ValueError("`vision_encoder_config` must be provided.")

super().__init__(**kwargs)
self.config = config

def call(self, inputs):
# TODO(YourGitHubUsername): Implement the forward pass.
# 1. Process images through Vision Encoder.
# 2. Process text through Embedding.
# 3. Pass through Decoder layers with Cross-Attention.
return inputs

def get_config(self):
return {"config": self.config}
critical

The `get_config` method returns the `self.config` object directly. This object is not JSON-serializable, which will cause errors when saving the model (e.g., with `model.save()`). The `get_config` method must return a JSON-serializable dictionary.

To fix this and align with Keras serialization patterns, you should:

  1. In get_config, call super().get_config() to include base properties like name and trainable, as recommended by the style guide.
  2. Serialize the config object by calling its get_config() method.
  3. Implement the from_config classmethod to correctly reconstruct the backbone from the serialized configuration. This is necessary because the default from_config from Backbone won't know how to handle the nested config object.

Additionally, the file is missing a final newline. Please run the code formatter to fix this.

Suggested change
-    def get_config(self):
-        return {"config": self.config}
+    def get_config(self):
+        config = super().get_config()
+        config["config"] = self.config.get_config()
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        from .llama3_vision_config import Llama3VisionConfig
+
+        config_data = config.pop("config")
+        vision_config = Llama3VisionConfig(**config_data)
+        return cls(config=vision_config, **config)
References
  1. The get_config() method should chain to its superclass to ensure base properties are preserved, as outlined in the layer implementation guidelines. (link)
  2. All components must be serializable to support saving and loading, which is a core part of the validation process. (link)
  3. Code should be formatted with ruff. A missing final newline is a formatting issue that the tool would fix. (link)
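To make the round trip concrete, here is a minimal, framework-free sketch of the pattern the suggestion describes. The `VisionConfig` and `VisionBackbone` names are stand-ins for illustration, not the real keras_hub classes, and the `base` dict merely mimics what `super().get_config()` would contribute.

```python
import json


class VisionConfig:
    """Stand-in for a nested config object with a get_config() method."""

    def __init__(self, projection_dim=4096):
        self.projection_dim = projection_dim

    def get_config(self):
        # Return only JSON-serializable values.
        return {"projection_dim": self.projection_dim}


class VisionBackbone:
    """Stand-in for the backbone, showing the serialization round trip."""

    def __init__(self, config, name="backbone"):
        self.config = config
        self.name = name

    def get_config(self):
        # Chain to base properties (mimicked here), then serialize the
        # nested config object into a plain dict instead of storing the
        # object itself.
        base = {"name": self.name}
        base["config"] = self.config.get_config()
        return base

    @classmethod
    def from_config(cls, config):
        # Rebuild the nested config object before constructing the model.
        config_data = config.pop("config")
        return cls(config=VisionConfig(**config_data), **config)


model = VisionBackbone(VisionConfig(projection_dim=512))
serialized = json.dumps(model.get_config())  # works: the dict is JSON-safe
restored = VisionBackbone.from_config(json.loads(serialized))
print(restored.config.projection_dim)  # 512
```

Returning `self.config` directly would make the `json.dumps` call above raise a `TypeError`, which is exactly the failure mode `model.save()` would hit.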

47 changes: 47 additions & 0 deletions keras_hub/src/models/llama3/llama3_vision_config.py
@@ -0,0 +1,47 @@
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.llama3.llama3_backbone import Llama3BackboneConfig


@keras_hub_export("keras_hub.models.Llama3VisionConfig")
class Llama3VisionConfig(Llama3BackboneConfig):
"""Configuration for the Llama 3.2 Vision Backbone.

This class extends `Llama3BackboneConfig` to include parameters for the
vision encoder and the gated cross-attention mechanism used in
Llama 3.2 multimodal models (11B and 90B).

Args:
vision_encoder_config: dict or config instance. The configuration
for the vision encoder (ViT-like architecture).
vision_projection_dim: int. The dimension of the projection layer
that maps vision features to the text embedding space.
cross_attention_layers: list of int. The indices of the transformer
layers that should include gated cross-attention blocks.
For Llama 3.2 11B, this is typically every 4th layer.
**kwargs: Arguments for the parent `Llama3BackboneConfig`.
"""

def __init__(
self,
vision_encoder_config=None,
vision_projection_dim=4096,
cross_attention_layers=None,
**kwargs,
):
super().__init__(**kwargs)
self.vision_encoder_config = vision_encoder_config
self.vision_projection_dim = vision_projection_dim
# Default to empty list if generic Llama3 is initialized without vision
self.cross_attention_layers = cross_attention_layers or []

def get_config(self):
config = super().get_config()
config.update(
{
"vision_encoder_config": self.vision_encoder_config,
"vision_projection_dim": self.vision_projection_dim,
"cross_attention_layers": self.cross_attention_layers,
}
)
return config
high

The `get_config` method includes `self.vision_encoder_config` directly. According to the docstring, this can be a config instance, which is not JSON-serializable; that will break serialization when `Llama3VisionBackbone.get_config()` is called. To keep the entire configuration serializable, convert `self.vision_encoder_config` to a dictionary when it is an object, for instance by calling its `get_config()` method. This is crucial for model saving and loading.

Also, the file is missing a final newline. Please run the code formatter to add it.

References
  1. All components must be serializable to support saving and loading. Config objects must return a serializable dictionary from get_config. (link)
  2. Code should be formatted with ruff. A missing final newline is a formatting issue that the tool would fix. (link)
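A minimal sketch of the normalization this comment asks for — the helper name and the dummy class are hypothetical, introduced only to illustrate handling the "dict or config instance" contract from the docstring:

```python
# Hypothetical helper illustrating the suggested fix: since the docstring
# allows vision_encoder_config to be a dict or a config instance,
# normalize it to a plain dict before it enters the serialized config.
def serialize_encoder_config(vision_encoder_config):
    if vision_encoder_config is None or isinstance(vision_encoder_config, dict):
        # Already JSON-serializable (or absent); pass through unchanged.
        return vision_encoder_config
    # Config instances are expected to expose a get_config() method.
    return vision_encoder_config.get_config()


class DummyEncoderConfig:
    """Stand-in for a config object; field values are illustrative."""

    def get_config(self):
        return {"patch_size": 14, "hidden_dim": 1280}


print(serialize_encoder_config({"patch_size": 14}))    # dict passes through
print(serialize_encoder_config(DummyEncoderConfig()))  # object becomes a dict
```

With this in place, `get_config` can store `serialize_encoder_config(self.vision_encoder_config)` instead of the raw attribute, and the resulting dict survives a JSON round trip regardless of which form the caller supplied.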
