feat: Initial config and skeleton for llama 3.2 Vision #2472
base: master
Changes from 1 commit
@@ -0,0 +1,37 @@

```python
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone


@keras_hub_export("keras_hub.models.Llama3VisionBackbone")
class Llama3VisionBackbone(Backbone):
    """Llama 3.2 Vision Backbone model.

    This model combines a Vision Encoder (ViT) and a Llama 3 Text Decoder
    interleaved with Gated Cross-Attention layers.

    Args:
        config: `Llama3VisionConfig` instance.
    """

    def __init__(self, config, **kwargs):
        # TODO(YourGitHubUsername): Implement the Vision Encoder integration.
        # This will initialize the vision tower and the text backbone.

        # Placeholder for input validation
        if config.vision_encoder_config is None:
            raise ValueError("`vision_encoder_config` must be provided.")

        super().__init__(**kwargs)
        self.config = config

    def call(self, inputs):
        # TODO(YourGitHubUsername): Implement the forward pass.
        # 1. Process images through Vision Encoder.
        # 2. Process text through Embedding.
        # 3. Pass through Decoder layers with Cross-Attention.
        return inputs

    def get_config(self):
        return {"config": self.config}
```
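The interleaving the docstring describes can be illustrated with a minimal, framework-free sketch (NumPy, single head, hypothetical names, not the real keras_hub implementation): text tokens query the vision features, and a tanh gate blends the result back into the text stream. With the gate at zero, the block is an identity on the text pathway, which is why gated cross-attention can be bolted onto a pretrained text decoder.

```python
import numpy as np


def gated_cross_attention(text, vision, gate):
    """Single-head scaled dot-product cross-attention where text tokens
    attend to vision features, merged back through a tanh gate."""
    scores = text @ vision.T / np.sqrt(text.shape[-1])
    # Numerically stable softmax over the vision tokens.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    attended = weights @ vision  # (num_text_tokens, dim)
    # gate == 0 -> output equals the text stream (plain text-only behavior).
    return text + np.tanh(gate) * attended


text = np.random.rand(4, 8)    # 4 text tokens, dim 8
vision = np.random.rand(6, 8)  # 6 vision tokens, dim 8
out = gated_cross_attention(text, vision, gate=0.0)  # identical to `text`
```

This is only a sketch of the gating idea; the real block would use multi-head attention, layer norm, and a learned per-layer gate.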
@@ -0,0 +1,47 @@

```python
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.llama3.llama3_backbone import Llama3BackboneConfig


@keras_hub_export("keras_hub.models.Llama3VisionConfig")
class Llama3VisionConfig(Llama3BackboneConfig):
    """Configuration for the Llama 3.2 Vision Backbone.

    This class extends `Llama3BackboneConfig` to include parameters for the
    vision encoder and the gated cross-attention mechanism used in
    Llama 3.2 multimodal models (11B and 90B).

    Args:
        vision_encoder_config: dict or config instance. The configuration
            for the vision encoder (ViT-like architecture).
        vision_projection_dim: int. The dimension of the projection layer
            that maps vision features to the text embedding space.
        cross_attention_layers: list of int. The indices of the transformer
            layers that should include gated cross-attention blocks.
            For Llama 3.2 11B, this is typically every 4th layer.
        **kwargs: Arguments for the parent `Llama3BackboneConfig`.
    """

    def __init__(
        self,
        vision_encoder_config=None,
        vision_projection_dim=4096,
        cross_attention_layers=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_encoder_config = vision_encoder_config
        self.vision_projection_dim = vision_projection_dim
        # Default to empty list if generic Llama3 is initialized without vision
        self.cross_attention_layers = cross_attention_layers or []

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vision_encoder_config": self.vision_encoder_config,
                "vision_projection_dim": self.vision_projection_dim,
                "cross_attention_layers": self.cross_attention_layers,
            }
        )
        return config
```
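Assuming the every-4th-layer pattern the docstring mentions, the `cross_attention_layers` indices could be derived with a small hypothetical helper (not part of this PR) rather than hard-coded:

```python
def every_nth_layer(num_layers, interval=4):
    """Hypothetical helper: 0-based indices of decoder layers that get a
    gated cross-attention block, one per `interval` layers."""
    return [i for i in range(num_layers) if (i + 1) % interval == 0]


layers = every_nth_layer(32)  # [3, 7, 11, 15, 19, 23, 27, 31]
```

The result could then be passed as `cross_attention_layers` when constructing the config, keeping the layer count and the insertion pattern in one place.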
The `get_config` method is returning the `self.config` object directly. This object is not serializable, which will cause errors when trying to save the model (e.g., with `model.save()`). The `get_config` method must return a JSON-serializable dictionary.

To fix this and align with Keras serialization patterns, you should:

- In `get_config`, call `super().get_config()` to include base properties like `name` and `trainable`, as recommended by the style guide.
- Serialize the nested `config` object by calling its `get_config()` method.
- Add a `from_config` classmethod to correctly reconstruct the backbone from the serialized configuration. This is necessary because the default `from_config` from `Backbone` won't know how to handle the nested config object.

Additionally, the file is missing a final newline. Please run the code formatter to fix this.
References

- A layer's `get_config()` method should chain to its superclass to ensure base properties are preserved, as outlined in the layer implementation guidelines. (link)
- The codebase is formatted with `ruff`; a missing final newline is a formatting issue that the tool would fix. (link)
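The round trip the review asks for can be sketched with stand-in classes (plain Python, not the real keras_hub API; all names here are illustrative): `get_config` chains to the superclass and serializes the nested config to a dict, and `from_config` rebuilds the config object before calling the constructor.

```python
import json


class Backbone:
    """Stand-in for the keras_hub Backbone; only models the config flow."""

    def __init__(self, name="backbone"):
        self.name = name

    def get_config(self):
        return {"name": self.name}  # base properties live here


class VisionConfig:
    def __init__(self, vision_projection_dim=4096):
        self.vision_projection_dim = vision_projection_dim

    def get_config(self):
        return {"vision_projection_dim": self.vision_projection_dim}


class VisionBackbone(Backbone):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

    def get_config(self):
        config = super().get_config()  # chain to keep base properties
        # Serialize the nested object to a plain dict, not the object itself.
        config.update({"config": self.config.get_config()})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # Rebuild the nested config object before constructing the backbone.
        config["config"] = VisionConfig(**config["config"])
        return cls(**config)


model = VisionBackbone(VisionConfig(1024), name="demo")
blob = json.dumps(model.get_config())  # JSON-serializable now
restored = VisionBackbone.from_config(json.loads(blob))
```

With the original code, `json.dumps` would raise a `TypeError` because `self.config` is not a dict; this pattern is what makes `model.save()`-style serialization possible.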