40 commits
195ef79
add RWKV
pass-lin Sep 28, 2025
7bc36b5
fix
pass-lin Sep 28, 2025
7d4a7a1
fix
pass-lin Sep 28, 2025
e5bb446
add inference
pass-lin Oct 7, 2025
afcff31
add inference
pass-lin Oct 7, 2025
ec0baf3
add tokenizer doc
pass-lin Oct 7, 2025
bd6c618
add doc
pass-lin Oct 7, 2025
4201a7f
add test case
pass-lin Oct 7, 2025
897a64b
fix test
pass-lin Oct 8, 2025
ff11f94
fix doc
pass-lin Oct 8, 2025
ce13d54
fix gemini review.
pass-lin Oct 20, 2025
0e36b4a
format.
pass-lin Oct 20, 2025
7218888
format.
pass-lin Oct 20, 2025
cc5815b
save tokenizer
pass-lin Oct 29, 2025
dd80464
fix tokenizer load
pass-lin Oct 29, 2025
5e8723d
fix save
pass-lin Oct 29, 2025
f223002
renew preset
pass-lin Oct 29, 2025
b2b1573
renew perset.
pass-lin Nov 1, 2025
c5ebeec
debug for remat
pass-lin Nov 3, 2025
14111c8
modify by gemini review .
pass-lin Nov 5, 2025
a88ae01
modify
pass-lin Nov 6, 2025
7f8bda7
modify
pass-lin Nov 6, 2025
00200a8
modify
pass-lin Nov 6, 2025
e97b458
modify
pass-lin Nov 6, 2025
75a4415
modify
pass-lin Nov 6, 2025
8c3638b
modify
pass-lin Nov 6, 2025
468dce1
modify rwkv casual lm.
pass-lin Nov 6, 2025
637fdcb
modify tokenizer
pass-lin Nov 9, 2025
24e67ec
fix test bug
pass-lin Nov 9, 2025
4eb4845
fix test bug
pass-lin Nov 9, 2025
be4a649
fix test bug
pass-lin Nov 9, 2025
28700d9
fix test bug
pass-lin Nov 9, 2025
2e2d5c0
fix test bug
pass-lin Nov 9, 2025
97b39cf
fix test bug
pass-lin Nov 9, 2025
44e6476
fix test bug
pass-lin Nov 9, 2025
b7ed34b
fix test bug
pass-lin Nov 9, 2025
b3e33fd
fix test bug
pass-lin Nov 9, 2025
75c8a88
modify RWKV7CausalLMPreprocessor
pass-lin Nov 13, 2025
eac1505
modify RWKV7CausalLMPreprocessor
pass-lin Nov 13, 2025
06ec6c5
modify RWKV7CausalLMPreprocessor
pass-lin Nov 13, 2025
9 changes: 9 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -605,6 +605,15 @@
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
    RoformerV2Tokenizer as RoformerV2Tokenizer,
)
from keras_hub.src.models.rwkv7.rwkv7_backbone import (
    RWKV7Backbone as RWKV7Backbone,
)
from keras_hub.src.models.rwkv7.rwkv7_causal_lm import (
    RWKV7CausalLM as RWKV7CausalLM,
)
from keras_hub.src.models.rwkv7.rwkv7_causal_lm_preprocessor import (
    RWKV7CausalLMPreprocessor as RWKV7CausalLMPreprocessor,
)
from keras_hub.src.models.sam.sam_backbone import SAMBackbone as SAMBackbone
from keras_hub.src.models.sam.sam_image_segmenter import (
    SAMImageSegmenter as SAMImageSegmenter,
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -90,6 +90,9 @@
from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
    RoformerV2Tokenizer as RoformerV2Tokenizer,
)
from keras_hub.src.models.rwkv7.rwkv7_tokenizer import (
    RWKVTokenizer as RWKVTokenizer,
)
from keras_hub.src.models.siglip.siglip_tokenizer import (
    SigLIPTokenizer as SigLIPTokenizer,
)
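With these re-exports in place, the new classes become reachable from the public namespaces. A minimal import sketch (assuming the rwkv7 modules added later in this PR are on the path; the `keras_hub.models` / `keras_hub.tokenizers` mapping follows the usual keras_hub API layout):

```python
# Public API surface added by the two __init__.py changes above.
from keras_hub.models import RWKV7Backbone
from keras_hub.models import RWKV7CausalLM
from keras_hub.models import RWKV7CausalLMPreprocessor
from keras_hub.tokenizers import RWKVTokenizer
```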
185 changes: 185 additions & 0 deletions keras_hub/src/models/rwkv7/rwkv7_backbone.py
@@ -0,0 +1,185 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.rwkv7.rwkv7_layer import RWKV7_Block


def rwkv7_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export("keras_hub.models.RWKV7Backbone")
class RWKV7Backbone(Backbone):
"""The [RWKV-7](https://arxiv.org/abs/2503.14456) core architecture.

This network implements a Modern RNN architecture based on linear
attention mechanisms with recurrent processing, as described in the
RWKV papers. It includes the embedding lookups and RWKV-7 blocks.

The default constructor gives a fully customizable, randomly initialized
RWKV-7 model with any number of layers, heads, and embedding dimensions.
To load preset architectures and weights, use the `from_preset`
constructor.

Args:
hidden_size: int. The size of the transformer encoding and pooling
layers.
head_size: int. The size of each attention head.
num_layers: int. The number of transformer layers.
vocabulary_size: int. The size of the token vocabulary.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
gate_lora: int. LoRA dimension for gating.
mv_lora: int. LoRA dimension for value mixing.
aaa_lora: int. LoRA dimension for alpha parameters.
decay_lora: int. LoRA dimension for decay parameters.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.
dropout_rate: float. Dropout rate for the dropout layer.

Examples:

```python
input_data = np.ones(shape=(1, 12), dtype="int32")


# Randomly initialized RWKV-7 decoder with custom config.
model = keras_hub.models.RWKV7Backbone(
vocabulary_size=10,
hidden_size=512,
num_layers=2,
head_size=64,
intermediate_dim=1024,
dtype="float32"
)
model(input_data)
```
"""

def __init__(
self,
hidden_size,
head_size,
num_layers,
vocabulary_size,
intermediate_dim,
gate_lora=128,
mv_lora=32,
aaa_lora=64,
decay_lora=64,
dtype=None,
dropout_rate=0,
**kwargs,
):
"""Initialize RWKV7 backbone.

Args:
hidden_size: Hidden dimension size.
head_size: Attention head size.
num_layers: Number of RWKV blocks.
vocabulary_size: Size of vocabulary.
intermediate_dim: Intermediate dimension for FFN.
gate_lora: LoRA dimension for gating.
mv_lora: LoRA dimension for value mixing.
aaa_lora: LoRA dimension for alpha parameters.
decay_lora: LoRA dimension for decay parameters.
dtype: Data type for the layer.
dropout_rate: Dropout rate for regularization.
**kwargs: Additional arguments.
"""
# === Layers ===
self.token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_size,
embeddings_initializer=rwkv7_kernel_initializer(),
dtype=dtype,
name="token_embedding",
)
self.token_embedding.build([None, None])

self.output_layer_norm = keras.layers.LayerNormalization(
epsilon=1e-5, name="output_norm"
)
self.output_layer_norm.build([None, None, hidden_size])
self.dropout = keras.layers.Dropout(
dropout_rate,
dtype=dtype,
name="dropout",
)
self.rwkv_layers = []
for i in range(num_layers):
layer = RWKV7_Block(
hidden_size,
head_size,
intermediate_dim,
gate_lora,
mv_lora,
aaa_lora,
decay_lora,
use_initial_norm=i == 0,
kernel_initializer=rwkv7_kernel_initializer(),
dtype=dtype,
name=f"rwkv_layer_{i}",
)

self.rwkv_layers.append(layer)
self.head = keras.layers.Dense(
units=vocabulary_size,
kernel_initializer=rwkv7_kernel_initializer(),
use_bias=False,
name="head",
)
# === Functional Model ===
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

padding_mask = ops.not_equal(token_id_input, 0)

x = self.token_embedding(token_id_input)
padding_mask = ops.cast(padding_mask, dtype=x.dtype)
v_first = None
for rwkv_layer in self.rwkv_layers:
x, v_first = rwkv_layer(x, v_first, padding_mask)
x = self.dropout(x)
sequence_output = self.output_layer_norm(x)
sequence_output = self.head(sequence_output)
super().__init__(
inputs=token_id_input,
Collaborator review comment: Once padding_mask is changed to a keras.Input, make the inputs a dictionary, something like inputs={"token_ids": token_id_input, "padding_mask": padding_mask_input}.
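A minimal, self-contained sketch of the dict-inputs pattern being suggested (the toy embedding layer and `padding_mask_input` name are illustrative only; the real change would wire the mask through the RWKV blocks instead):

```python
import keras
from keras import ops

# Illustrative only: a functional model that takes both tensors as inputs
# and exposes them as a dictionary, as the review comment suggests.
token_id_input = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask_input = keras.Input(
    shape=(None,), dtype="int32", name="padding_mask"
)
x = keras.layers.Embedding(input_dim=10, output_dim=8)(token_id_input)
mask = ops.expand_dims(ops.cast(padding_mask_input, x.dtype), axis=-1)
x = x * mask
model = keras.Model(
    inputs={
        "token_ids": token_id_input,
        "padding_mask": padding_mask_input,
    },
    outputs=x,
)
# In RWKV7Backbone, the same dictionary would be passed as `inputs=` to
# `super().__init__(...)`.
```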

            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )
        # Initialize the graph to avoid potential errors in some cases
        self.call(ops.ones([1, 16], "int32"))
Collaborator review comment: This is redundant; the model is already built in super().__init__(...), and this call would need to be updated every time the input signature changes.


        self.num_layers = num_layers
        self.head_size = head_size
        self.hidden_size = hidden_size
        self.gate_lora = gate_lora
        self.mv_lora = mv_lora
        self.aaa_lora = aaa_lora
        self.decay_lora = decay_lora
        self.vocabulary_size = vocabulary_size
        self.dropout_rate = dropout_rate
        self.intermediate_dim = intermediate_dim

    def get_config(self):
        config = {
            "hidden_size": self.hidden_size,
            "head_size": self.head_size,
            "gate_lora": self.gate_lora,
            "mv_lora": self.mv_lora,
            "aaa_lora": self.aaa_lora,
            "decay_lora": self.decay_lora,
            "vocabulary_size": self.vocabulary_size,
            "dropout_rate": self.dropout_rate,
            "intermediate_dim": self.intermediate_dim,
            "num_layers": self.num_layers,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
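As a quick sanity check on the config above, the usual round-trip (a sketch relying only on standard Keras `get_config` / `from_config` behavior, not on anything specific to this PR) would look like:

```python
# Round-trip the backbone through its config; hyperparameters should survive.
backbone = RWKV7Backbone(
    vocabulary_size=10,
    hidden_size=16,
    num_layers=2,
    head_size=4,
    intermediate_dim=32,
)
config = backbone.get_config()
restored = RWKV7Backbone.from_config(config)
assert restored.hidden_size == backbone.hidden_size
assert restored.num_layers == backbone.num_layers
```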
37 changes: 37 additions & 0 deletions keras_hub/src/models/rwkv7/rwkv7_backbone_test.py
@@ -0,0 +1,37 @@
from keras import ops

from keras_hub.src.models.rwkv7.rwkv7_backbone import RWKV7Backbone
from keras_hub.src.tests.test_case import TestCase


class RWKV7BackboneTest(TestCase):
    def setUp(self):
        """
        Set up the test case with default arguments and input data.
        """
        self.init_kwargs = {
            "vocabulary_size": 10,
            "hidden_size": 16,
            "num_layers": 2,
            "head_size": 4,
            "intermediate_dim": 32,
            "gate_lora": 32,
            "mv_lora": 16,
            "aaa_lora": 16,
            "decay_lora": 16,
        }
        self.input_data = ops.ones((2, 5), dtype="int32")
        self.backbone = RWKV7Backbone(**self.init_kwargs)

    def test_backbone_basics(self):
        """
        Test basic functionality of the RWKV7 backbone.
        """
        y = self.backbone(self.input_data)
        self.assertEqual(y.shape, (2, 5, 10))

    def test_num_parameters(self):
        """
        Test that the model has the expected number of parameters.
        """
        self.assertEqual(self.backbone.count_params(), 10208)