ADD RWKV7 #2421

This PR adds the RWKV-7 backbone architecture to KerasHub as two new files: the backbone implementation and its unit tests.
**`keras_hub/src/models/rwkv7/rwkv7_backbone.py`** (new file, +185 lines):

````python
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block


def rwkv7_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
    """The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture.

    This network implements a modern RNN architecture based on linear
    attention mechanisms with recurrent processing, as described in the
    RWKV papers. It includes the embedding lookups and RWKV-7 blocks.

    The default constructor gives a fully customizable, randomly initialized
    RWKV-7 model with any number of layers, heads, and embedding dimensions.
    To load preset architectures and weights, use the `from_preset`
    constructor.

    Args:
        hidden_size: int. The size of the transformer encoding and pooling
            layers.
        head_size: int. The size of each attention head.
        num_layers: int. The number of transformer layers.
        vocabulary_size: int. The size of the token vocabulary.
        intermediate_dim: int. The output dimension of the first Dense layer
            in a two-layer feedforward network for each transformer.
        gate_lora: int. LoRA dimension for gating.
        mv_lora: int. LoRA dimension for value mixing.
        aaa_lora: int. LoRA dimension for alpha parameters.
        decay_lora: int. LoRA dimension for decay parameters.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
            use for model computations and weights. Note that some
            computations, such as softmax and layer normalization, will
            always be done at float32 precision regardless of dtype.
        dropout_rate: float. Dropout rate for the dropout layer.

    Examples:

    ```python
    input_data = np.ones(shape=(1, 12), dtype="int32")

    # Randomly initialized RWKV-7 decoder with custom config.
    model = keras_hub.models.RWKV7Backbone(
        vocabulary_size=10,
        hidden_size=512,
        num_layers=2,
        head_size=64,
        intermediate_dim=1024,
        dtype="float32",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        hidden_size,
        head_size,
        num_layers,
        vocabulary_size,
        intermediate_dim,
        gate_lora=128,
        mv_lora=32,
        aaa_lora=64,
        decay_lora=64,
        dtype=None,
        dropout_rate=0,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = keras.layers.Embedding(
            input_dim=vocabulary_size,
            output_dim=hidden_size,
            embeddings_initializer=rwkv7_kernel_initializer(),
            dtype=dtype,
            name="token_embedding",
        )
        self.token_embedding.build([None, None])

        self.output_layer_norm = keras.layers.LayerNormalization(
            epsilon=1e-5, name="output_norm"
        )
        self.output_layer_norm.build([None, None, hidden_size])

        self.dropout = keras.layers.Dropout(
            dropout_rate,
            dtype=dtype,
            name="dropout",
        )
        self.rwkv_layers = []
        for i in range(num_layers):
            layer = RWKV7_Block(
                hidden_size,
                head_size,
                intermediate_dim,
                gate_lora,
                mv_lora,
                aaa_lora,
                decay_lora,
                # Only the first block normalizes its input.
                use_initial_norm=(i == 0),
                kernel_initializer=rwkv7_kernel_initializer(),
                dtype=dtype,
                name=f"rwkv_layer_{i}",
            )
            self.rwkv_layers.append(layer)
        self.head = keras.layers.Dense(
            units=vocabulary_size,
            kernel_initializer=rwkv7_kernel_initializer(),
            use_bias=False,
            name="head",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        # Token id 0 is treated as padding.
        padding_mask = ops.not_equal(token_id_input, 0)

        x = self.token_embedding(token_id_input)
        padding_mask = ops.cast(padding_mask, dtype=x.dtype)

        # `v_first` threads the first block's value activations through all
        # subsequent blocks, as in the RWKV-7 reference implementation.
        v_first = None
        for rwkv_layer in self.rwkv_layers:
            x, v_first = rwkv_layer(x, v_first, padding_mask)
        x = self.dropout(x)
        sequence_output = self.output_layer_norm(x)
        sequence_output = self.head(sequence_output)
        super().__init__(
            inputs=token_id_input,
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # Initialize the graph to avoid potential errors in some cases.
        self.call(ops.ones([1, 16], "int32"))

        # === Config ===
        self.num_layers = num_layers
        self.head_size = head_size
        self.hidden_size = hidden_size
        self.gate_lora = gate_lora
        self.mv_lora = mv_lora
        self.aaa_lora = aaa_lora
        self.decay_lora = decay_lora
        self.vocabulary_size = vocabulary_size
        self.dropout_rate = dropout_rate
        self.intermediate_dim = intermediate_dim

    def get_config(self):
        config = {
            "hidden_size": self.hidden_size,
            "head_size": self.head_size,
            "gate_lora": self.gate_lora,
            "mv_lora": self.mv_lora,
            "aaa_lora": self.aaa_lora,
            "decay_lora": self.decay_lora,
            "vocabulary_size": self.vocabulary_size,
            "dropout_rate": self.dropout_rate,
            "intermediate_dim": self.intermediate_dim,
            "num_layers": self.num_layers,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
````
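For reviewers who want to exercise the new backbone directly, here is a minimal smoke-test sketch. It assumes this branch is installed (so the `keras_hub.models.RWKV7Backbone` export above is importable) and that `hidden_size` should be a multiple of `head_size`, which the small configuration below respects; the expected output shape mirrors the shape assertion in the test file.

```python
import numpy as np

import keras_hub

# Small, randomly initialized RWKV-7 backbone. Dimensions are chosen for
# speed; hidden_size (64) is a multiple of head_size (16), i.e. 4 heads.
model = keras_hub.models.RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=64,
    num_layers=2,
    head_size=16,
    intermediate_dim=128,
)

# Token id 0 is treated as padding by the internal `not_equal(ids, 0)` mask.
token_ids = np.array([[5, 3, 8, 1, 0, 0]], dtype="int32")
logits = model(token_ids)

# The `head` Dense layer projects back to the vocabulary, so the output is
# per-token logits of shape (batch, sequence, vocabulary_size).
print(logits.shape)  # (1, 6, 10)
print(model.count_params())
```

Because the functional graph is built and eagerly traced once in `__init__`, construction alone already checks that the block wiring and the `v_first` threading are consistent.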
**`keras_hub/src/models/rwkv7/rwkv7_backbone_test.py`** (new file, +37 lines):

```python
from keras import ops

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.tests.test_case import TestCase


class RWKV7BackboneTest(TestCase):
    def setUp(self):
        """Set up the test case with default arguments and input data."""
        self.init_kwargs = {
            "vocabulary_size": 10,
            "hidden_size": 16,
            "num_layers": 2,
            "head_size": 4,
            "intermediate_dim": 32,
            "gate_lora": 32,
            "mv_lora": 16,
            "aaa_lora": 16,
            "decay_lora": 16,
        }
        self.input_data = ops.ones((2, 5), dtype="int32")
        self.backbone = RWKV7Backbone(**self.init_kwargs)

    def test_backbone_basics(self):
        """Test basic functionality of the RWKV7 backbone."""
        y = self.backbone(self.input_data)
        self.assertEqual(y.shape, (2, 5, 10))

    def test_num_parameters(self):
        """Test that the model has the expected number of parameters."""
        self.assertEqual(self.backbone.count_params(), 10208)
```
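As a possible follow-up, keras-hub's shared `TestCase` exposes a `run_backbone_test` helper that most backbone tests use to additionally cover serialization and variable sequence lengths. A sketch of how the shape test could be folded into it, assuming the helper's current signature:

```python
def test_backbone_basics(self):
    # Sketch only: delegates shape, config, and saving checks to the
    # shared keras-hub backbone test harness.
    self.run_backbone_test(
        cls=RWKV7Backbone,
        init_kwargs=self.init_kwargs,
        input_data=self.input_data,
        expected_output_shape=(2, 5, 10),
    )
```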