From 730bc21916470254a2b13ee98cb9a5855bfbc265 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Fri, 19 Sep 2025 20:30:17 +0530 Subject: [PATCH 01/19] Add STU layer --- keras_rs/src/layers/common.py | 62 +++++ keras_rs/src/layers/hstu_compute_output.py | 125 ++++++++++ keras_rs/src/layers/hstu_mha_attention.py | 109 +++++++++ .../src/layers/hstu_preprocess_attention.py | 44 ++++ keras_rs/src/layers/hstu_uqvk_output.py | 81 +++++++ keras_rs/src/layers/jagged_tensors.py | 112 +++++++++ keras_rs/src/layers/stu.py | 222 ++++++++++++++++++ 7 files changed, 755 insertions(+) create mode 100644 keras_rs/src/layers/common.py create mode 100644 keras_rs/src/layers/hstu_compute_output.py create mode 100644 keras_rs/src/layers/hstu_mha_attention.py create mode 100644 keras_rs/src/layers/hstu_preprocess_attention.py create mode 100644 keras_rs/src/layers/hstu_uqvk_output.py create mode 100644 keras_rs/src/layers/jagged_tensors.py create mode 100644 keras_rs/src/layers/stu.py diff --git a/keras_rs/src/layers/common.py b/keras_rs/src/layers/common.py new file mode 100644 index 00000000..4e5b9f5d --- /dev/null +++ b/keras_rs/src/layers/common.py @@ -0,0 +1,62 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +def fx_unwrap_optional_tensor(optional: Optional[keras.KerasTensor]) -> keras.KerasTensor: + """Helper to unwrap optional tensors, returning a zero-tensor for uninitialized cache.""" + if optional is None: + # Returning a zero-tensor is necessary for graph tracing when the cache is uninitialized. + return ops.zeros((0,), dtype='float32') + return optional + +def get_valid_attn_mask_keras( + causal: bool, + N: int, + seq_lengths: keras.KerasTensor, + num_targets: Optional[keras.KerasTensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +): + """ + Keras implementation of the valid attention mask generation, combining + causality, sequence lengths, and target awareness. 
+ """ + ids = ops.reshape(ops.arange(0, N, dtype="int32"), (1, N)) + max_ids = ops.reshape(seq_lengths, (-1, 1, 1)) + B = ops.shape(seq_lengths)[0] + + if contextual_seq_len > 0: + ids = ids - contextual_seq_len + 1 + ids = ops.maximum(ids, 0) + max_ids = max_ids - contextual_seq_len + 1 + + if num_targets is not None: + max_ids = max_ids - ops.reshape(num_targets, (-1, 1, 1)) + ids = ops.minimum(ids, max_ids) + row_ids = ops.broadcast_to(ops.reshape(ids, (-1, N, 1)), (B, N, N)) + col_ids = ops.broadcast_to(ops.reshape(ids, (-1, 1, N)), (B, N, N)) + else: + row_ids = ops.broadcast_to(ops.reshape(ids, (N, 1)), (N, N)) + col_ids = ops.transpose(row_ids) + row_ids = ops.reshape(row_ids, (1, N, N)) + col_ids = ops.reshape(col_ids, (1, N, N)) + max_ids = None + + row_col_dist = row_ids - col_ids + valid_attn_mask = ops.reshape(ops.eye(N, dtype="bool"), (1, N, N)) + + if not causal: + row_col_dist = ops.where(row_col_dist > 0, row_col_dist, -row_col_dist) + + valid_attn_mask = ops.logical_or(valid_attn_mask, row_col_dist > 0) + + if max_attn_len > 0: + valid_attn_mask = ops.logical_and(valid_attn_mask, row_col_dist <= max_attn_len) + + if contextual_seq_len > 0 and max_ids is not None: + valid_attn_mask = ops.logical_or( + valid_attn_mask, ops.logical_and(row_ids == 0, col_ids < max_ids) + ) + + return valid_attn_mask diff --git a/keras_rs/src/layers/hstu_compute_output.py b/keras_rs/src/layers/hstu_compute_output.py new file mode 100644 index 00000000..9d35959c --- /dev/null +++ b/keras_rs/src/layers/hstu_compute_output.py @@ -0,0 +1,125 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +def keras_norm_mul_dropout( + x: keras.KerasTensor, + u: keras.KerasTensor, + weight: keras.KerasTensor, + bias: keras.KerasTensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> keras.KerasTensor: + """ + Keras 3 equivalent of pytorch_norm_mul_dropout. + Applies normalization, element-wise multiplication with u, and dropout. + Assumes keras_layer_norm is available (though the logic is inlined here). 
+ """ + x = ops.convert_to_tensor(x, dtype='float32') + u = ops.convert_to_tensor(u, dtype='float32') + + if silu_u: + u = ops.silu(u) + + if group_norm: + raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.") + else: + # Functional Layer Normalization (Simulated keras_layer_norm) + mean = ops.mean(x, axis=-1, keepdims=True) + variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) + x_norm = (x - mean) / ops.sqrt(variance + eps) + + # Apply weight and bias (Gamma * x_norm + Beta) + y_norm = x_norm * weight + bias + + # Apply u multiplication (Element-wise gating) + y = u * y_norm + + if concat_ux: + y = ops.concatenate([u, x, y], axis=1) + + # Dropout (using Keras layer for correct training=True/False behavior) + y = keras.layers.Dropout(dropout_ratio)(y, training=training) + + return ops.cast(y, dtype=x.dtype) + +def keras_hstu_compute_output( + attn: keras.KerasTensor, + u: keras.KerasTensor, + x: keras.KerasTensor, + norm_weight: keras.KerasTensor, + norm_bias: keras.KerasTensor, + output_weight: keras.KerasTensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> keras.KerasTensor: + """ + Core kernel for the final residual block calculation (Attn Output -> Norm/Dropout -> MatMul -> Residual Add). + """ + y = keras_norm_mul_dropout( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + + # Final output: Residual addition of input (x) and transformed attention output (y @ output_weight) + output = ops.add(x, ops.matmul(y, output_weight)) + + return output + +def hstu_compute_output( + attn: keras.KerasTensor, + u: keras.KerasTensor, + x: keras.KerasTensor, + norm_weight: keras.KerasTensor, + norm_bias: keras.KerasTensor, + norm_eps: float, + output_weight: keras.KerasTensor, + num_heads: int, + linear_dim: int, + dropout_ratio: float, + training: bool, + concat_ux: bool, + group_norm: bool, + recompute_y_in_backward: bool, +) -> keras.KerasTensor: + """ + Top-level wrapper for the output computation, delegates to the core Keras kernel. + """ + return keras_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=False, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) diff --git a/keras_rs/src/layers/hstu_mha_attention.py b/keras_rs/src/layers/hstu_mha_attention.py new file mode 100644 index 00000000..d7b146dc --- /dev/null +++ b/keras_rs/src/layers/hstu_mha_attention.py @@ -0,0 +1,109 @@ +import keras +from keras import ops +from typing import Tuple, Optional +from keras import layers + +# --- Assumed Imports --- +# Assumes keras_jagged_to_padded_dense, keras_dense_to_jagged, and get_valid_attn_mask_keras are available from other modules. + +def keras_pad_qkv( + q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, N: int, +) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]: + """ + Helper to pad Q, K, V from jagged to dense format for MHA. + Assumes keras_jagged_to_padded_dense is available globally. 
+ """ + L, H, D = ops.shape(q); V_dim = ops.shape(v)[2] + values_q_k = ops.reshape(q, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) + + # Pad Q, K, V + padded_q_k = keras_jagged_to_padded_dense(values=values_q_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + + B = ops.shape(padded_q_k)[0]; padded_q_k = ops.reshape(padded_q_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) + padded_q = ops.transpose(padded_q_k, [0, 2, 1, 3]); padded_k = ops.transpose(padded_q_k, [0, 2, 1, 3]) + padded_v = ops.transpose(padded_v, [0, 2, 1, 3]) + return padded_q, padded_k, padded_v + + +def keras_hstu_mha( + max_seq_len: int, alpha: float, q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, causal: bool = True, dropout_pr: float = 0.0, training: bool = True, attn_scale: Optional[keras.KerasTensor] = None, **kwargs +) -> keras.KerasTensor: + """Core Keras implementation of the full Multi-Head Attention kernel (Non-Cached).""" + L, H, _ = ops.shape(q); V_dim = ops.shape(v)[2] + q, k, v = keras_pad_qkv(q, k, v, seq_offsets, max_seq_len) + qk_attn = ops.einsum("bhxa,bhya->bhxy", q, k) * alpha + + # Activation and Scaling + if attn_scale is not None: + if ops.ndim(attn_scale) > 0: + attn_scale_padded = keras_jagged_to_padded_dense(values=ops.expand_dims(attn_scale, axis=-1), offsets=[seq_offsets], max_lengths=[max_seq_len], padding_value=0.0) + attn_scale_padded = ops.expand_dims(ops.cast(attn_scale_padded, qk_attn.dtype), axis=1) + qk_attn = ops.silu(qk_attn) * attn_scale_padded + else: + qk_attn = ops.silu(qk_attn) / max_seq_len + + # Masking + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + valid_attn_mask = get_valid_attn_mask_keras(causal=causal, N=max_seq_len, seq_lengths=seq_lengths, **kwargs) + qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask, qk_attn.dtype), axis=1) + + # Dropout + if dropout_pr > 0.0 and training: + qk_attn = keras.layers.Dropout(dropout_pr)(qk_attn, training=training) + + # Output (Weighted Sum) + attn_dense = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v) + flat_attn_dense = ops.reshape(ops.transpose(attn_dense, [0, 2, 1, 3]), [-1, max_seq_len, H * V_dim]) + + # Convert back to jagged + jagged_output = keras_dense_to_jagged(flat_attn_dense, [seq_offsets]) + L_out = ops.shape(jagged_output)[0] + return ops.reshape(jagged_output, [L_out, H, V_dim]) + + +def keras_cached_hstu_mha( + max_seq_len: int, alpha: float, delta_q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, num_targets: Optional[keras.KerasTensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, enable_tma: bool = False, +) -> keras.KerasTensor: + """Core Keras implementation of the cached attention kernel (Delta Q attends to Full K/V).""" + L_delta, H, D = ops.shape(delta_q); B = ops.shape(seq_offsets)[0] - 1; DeltaSize = L_delta // B; V_dim = ops.shape(v)[2] + + # 1. Reshape Delta Q + delta_q = ops.transpose(ops.reshape(delta_q, (B, DeltaSize, H, D)), perm=[0, 2, 1, 3]) + + # 2. Reshape Full K and V (Inputs k, v are already flattened/jagged-like) + N_full = max_seq_len + k_full = ops.transpose(ops.reshape(k, (B, N_full, H, D)), [0, 2, 1, 3]) + v_full = ops.transpose(ops.reshape(v, (B, N_full, H, V_dim)), [0, 2, 1, 3]) + + # 3. 
Attention Score and Activation + qk_attn = ops.einsum("bhxa,bhya->bhxy", delta_q, k_full) * alpha + qk_attn = ops.silu(qk_attn) / max_seq_len + + # 4. Masking (Slice the mask to select only the rows corresponding to the new queries) + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + full_valid_attn_mask = get_valid_attn_mask_keras(causal=True, N=max_seq_len, seq_lengths=seq_lengths, num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len) + valid_attn_mask_sliced = full_valid_attn_mask[:, -DeltaSize:, :] + + qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask_sliced, qk_attn.dtype), axis=1) + + # 5. Output (Weighted Sum) + attn_output = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v_full) + + # 6. Reshape and return [L_delta, H, V_dim] + attn_output = ops.transpose(attn_output, perm=[0, 2, 1, 3]) + return ops.reshape(attn_output, (-1, H, V_dim)) + + +def delta_hstu_mha( + max_seq_len: int, alpha: float, delta_q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, num_targets: Optional[keras.KerasTensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, kernel=None, enable_tma: bool = False, +) -> keras.KerasTensor: + """Top-level wrapper for cached inference MHA (delegates to core cached kernel).""" + + L_delta, H, D = ops.shape(delta_q) + # Assertions are maintained by the layer/framework where possible. + + return keras_cached_hstu_mha( + max_seq_len=max_seq_len, alpha=alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets, + num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len, + ) diff --git a/keras_rs/src/layers/hstu_preprocess_attention.py b/keras_rs/src/layers/hstu_preprocess_attention.py new file mode 100644 index 00000000..04f70d34 --- /dev/null +++ b/keras_rs/src/layers/hstu_preprocess_attention.py @@ -0,0 +1,44 @@ +import keras +from keras import ops +from typing import Tuple, List, Optional + + +def keras_hstu_preprocess_and_attention( + x: keras.KerasTensor, norm_weight: keras.KerasTensor, norm_bias: keras.KerasTensor, norm_eps: float, num_heads: int, attn_dim: int, hidden_dim: int, + uvqk_weight: keras.KerasTensor, uvqk_bias: keras.KerasTensor, max_seq_len: int, seq_offsets: keras.KerasTensor, attn_alpha: float, causal: bool, + num_targets: Optional[keras.KerasTensor], max_attn_len: int, contextual_seq_len: int, recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, sort_by_length: bool, prefill: bool = False, + kernel=None, **kwargs +) -> Tuple: + """ + Keras 3 implementation of the H-STU preprocess and attention workflow. + Orchestrates the conversion of input X into U, Q, K, V and subsequent MHA computation. + """ + + # --- Assertions (Skipped internal torch asserts, simplified to Keras asserts for context) --- + assert max_seq_len > 0, "max_seq_len must be larger than 0" + assert ops.ndim(x) == 2, "x must be 2-D" + assert causal is True, "only causal attention is supported." + + # 1. Compute U, Q, K, V + # Note: hstu_compute_uqvk handles the initial Norm, Linear Projection, and Split. + u, q, k, v = hstu_compute_uqvk( + x=x, norm_weight=norm_weight, norm_bias=norm_bias, norm_eps=norm_eps, + num_heads=num_heads, attn_dim=attn_dim, hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, uvqk_bias=uvqk_bias, kernel=kernel, + ) + + # 2. 
Compute Attention + attn_output = keras_hstu_mha( + max_seq_len=max_seq_len, alpha=attn_alpha, q=q, k=k, v=v, + seq_offsets=seq_offsets, causal=causal, dropout_pr=0.0, + training=False, num_targets=num_targets, max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, sort_by_length=sort_by_length, + kernel=kernel, **kwargs + ) + + # Reshape: [L, H, D] -> [L, H * D] (Flattening for the final hstu_compute_output block) + attn_output = ops.reshape(attn_output, [-1, hidden_dim * num_heads]) + + # Returns u (gating), attention output, k, and v (for caching) + return u, attn_output, k, v diff --git a/keras_rs/src/layers/hstu_uqvk_output.py b/keras_rs/src/layers/hstu_uqvk_output.py new file mode 100644 index 00000000..47b5a1aa --- /dev/null +++ b/keras_rs/src/layers/hstu_uqvk_output.py @@ -0,0 +1,81 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +def keras_layer_norm( + x: keras.KerasTensor, + weight: keras.KerasTensor, + bias: keras.KerasTensor, + eps: float, +) -> keras.KerasTensor: + """ + Keras 3 functional Layer Normalization implementation. + Simulates F.layer_norm where scale/bias is applied externally. + """ + # 1. Normalize x + mean = ops.mean(x, axis=-1, keepdims=True) + variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) + x_norm = (x - mean) / ops.sqrt(variance + eps) + + # 2. Apply weight and bias (Gamma * x_norm + Beta) + return x_norm * weight + bias + +def keras_addmm( + bias: keras.KerasTensor, + input: keras.KerasTensor, + mat2: keras.KerasTensor, +) -> keras.KerasTensor: + """Keras 3 equivalent of torch.addmm (bias + input @ mat2).""" + return ops.add(bias, ops.matmul(input, mat2)) + +def hstu_compute_uqvk( + x: keras.KerasTensor, + norm_weight: keras.KerasTensor, + norm_bias: keras.KerasTensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: keras.KerasTensor, + uvqk_bias: keras.KerasTensor, + kernel=None, +) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]: + """ + Computes the transformed tensors U, V, Q, and K from the input X. + """ + + # 1. Normalization + normed_x = keras_layer_norm( + x, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + ) + + # 2. Combined Linear Projection (uvqk = bias + normed_x @ uvqk_weight) + uvqk = keras_addmm(uvqk_bias, normed_x, uvqk_weight) + + # 3. Calculate split sizes and slice + u_size = hidden_dim * num_heads + v_size = hidden_dim * num_heads + q_size = attn_dim * num_heads + k_size = attn_dim * num_heads + + start_u = 0 + start_v = u_size + start_q = u_size + v_size + start_k = u_size + v_size + q_size + L_out = ops.shape(uvqk)[0] + + u = ops.slice(uvqk, start_indices=[0, start_u], shape=[L_out, u_size]) + v = ops.slice(uvqk, start_indices=[0, start_v], shape=[L_out, v_size]) + q = ops.slice(uvqk, start_indices=[0, start_q], shape=[L_out, q_size]) + k = ops.slice(uvqk, start_indices=[0, start_k], shape=[L_out, k_size]) + + # 4. 
Activation and Reshape + u = ops.silu(u) + q = ops.reshape(q, [-1, num_heads, attn_dim]) + k = ops.reshape(k, [-1, num_heads, attn_dim]) + v = ops.reshape(v, [-1, num_heads, hidden_dim]) + + return u, q, k, v diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py new file mode 100644 index 00000000..87e96e68 --- /dev/null +++ b/keras_rs/src/layers/jagged_tensors.py @@ -0,0 +1,112 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +# --- Core Jagged/Dense Conversion Functions --- + +def keras_jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0.0): + """ + Keras 3 implementation to convert jagged tensor (values) into a padded dense tensor [B, N, D_flat]. + Required by MHA kernel padding (keras_pad_qkv). + """ + offsets = offsets[0] if isinstance(offsets, list) else offsets + B = ops.shape(offsets)[0] - 1 + max_len = max_lengths[0] + D_flat = ops.shape(values)[-1] + if ops.shape(values)[0] == 0: + return ops.full((B, max_len, D_flat), padding_value, dtype=values.dtype) + + def pad_one(i): + start = offsets[i]; end = offsets[i+1] + seq_len = end - start + seq = ops.slice(values, [start, 0], [seq_len, D_flat]) + if ops.equal(seq_len, 0): + return ops.full((max_len, D_flat), padding_value, dtype=values.dtype) + if seq_len < max_len: + padding_shape = ops.stack([max_len - seq_len, D_flat]) + padding = ops.full(padding_shape, padding_value, dtype=values.dtype) + return ops.concatenate([seq, padding], axis=0) + else: + return seq[:max_len] + + idxs = ops.arange(B, dtype='int32') + return ops.map(pad_one, idxs) + +def keras_dense_to_jagged( + dense: keras.KerasTensor, + x_offsets: List[keras.KerasTensor], +) -> keras.KerasTensor: + """Keras 3 implementation to convert a padded dense tensor [B, N, D] back into a jagged tensor.""" + seq_offsets = x_offsets[0] + N = ops.shape(dense)[1] + D_flat = ops.shape(dense)[2] + token_range = ops.arange(N) + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + mask = ops.expand_dims(token_range, axis=0) < ops.expand_dims(seq_lengths, axis=1) + + flattened = ops.reshape(dense, [-1, D_flat]) + flattened_mask = ops.reshape(mask, [-1]) + + return flattened[flattened_mask] + +# --- Jagged Splitting and Concatenation Wrappers (Used by Caching Logic) --- + +def split_2D_jagged( + max_seq_len: int, values: keras.KerasTensor, total_len_left: Optional[int] = None, total_len_right: Optional[int] = None, max_len_left: Optional[int] = None, max_len_right: Optional[int] = None, offsets_left: Optional[keras.KerasTensor] = None, offsets_right: Optional[keras.KerasTensor] = None, kernel=None, +) -> Tuple[keras.KerasTensor, keras.KerasTensor]: + """Top-level wrapper for splitting a concatenated jagged tensor.""" + + def keras_split_2D_jagged_jagged(max_seq_len, values, offsets_left, offsets_right): + D_flat = ops.shape(values)[1]; offsets = offsets_left + offsets_right + padded_values_bnd = keras_jagged_to_padded_dense(values=values, offsets=[offsets], max_lengths=[max_seq_len], padding_value=0.0) + padded_values = ops.reshape(padded_values_bnd, [-1, D_flat]) + lengths_left = offsets_left[1:] - offsets_left[:-1]; lengths_right = offsets_right[1:] - offsets_right[:-1] + mask = ops.reshape(ops.arange(max_seq_len, dtype='int32'), [1, -1]) + lengths_left_broadcast = ops.reshape(lengths_left, [-1, 1]); lengths_right_combined = ops.reshape(lengths_left + lengths_right, [-1, 1]) + mask_left = mask < lengths_left_broadcast + mask_right = ops.logical_and(mask >= lengths_left_broadcast, mask < 
lengths_right_combined) + return padded_values[ops.reshape(mask_left, [-1])], padded_values[ops.reshape(mask_right, [-1])] + + def keras_split_2D_jagged_resolver(max_seq_len, values, max_len_left, max_len_right, offsets_left, offsets_right): + L_total = ops.shape(values)[0] + offsets_left_non_optional = offsets_left + if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') + offsets_right_non_optional = offsets_right + if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') + return keras_split_2D_jagged_jagged(max_seq_len=max_seq_len, values=values, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) + + return keras_split_2D_jagged_resolver(max_seq_len=max_seq_len, values=values, max_len_left=max_len_left, max_len_right=max_len_right, offsets_left=offsets_left, offsets_right=offsets_right) + + +def concat_2D_jagged( + max_seq_len: int, values_left: keras.KerasTensor, values_right: keras.KerasTensor, max_len_left: Optional[int] = None, max_len_right: Optional[int] = None, offsets_left: Optional[keras.KerasTensor] = None, offsets_right: Optional[keras.KerasTensor] = None, kernel=None, +) -> keras.KerasTensor: + """Top-level wrapper for concatenating 2D jagged tensors (used for KV cache construction).""" + + def keras_concat_2D_jagged_jagged(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): + max_seq_len = max_len_left + max_len_right + lengths_left = offsets_left[1:] - offsets_left[:-1]; lengths_right = offsets_right[1:] - offsets_right[:-1] + padded_left = keras_jagged_to_padded_dense(values=values_left, offsets=[offsets_left], max_lengths=[max_len_left], padding_value=0.0) + padded_right = keras_jagged_to_padded_dense(values=values_right, offsets=[offsets_right], max_lengths=[max_len_right], padding_value=0.0) + concatted_dense = ops.concatenate([padded_left, padded_right], axis=1) + + lengths_left_broadcast = ops.reshape(lengths_left, [-1, 1]); lengths_right_broadcast = ops.reshape(lengths_right, [-1, 1]) + mask = ops.reshape(ops.arange(max_seq_len, dtype='int32'), [1, -1]) + mask = ops.logical_or(mask < lengths_left_broadcast, ops.logical_and(mask >= max_len_left, mask < max_len_left + lengths_right_broadcast)) + return concatted_dense[ops.reshape(mask, [-1])] + + def pytorch_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): + L_total = ops.shape(values_left)[0] + offsets_left_non_optional = offsets_left + if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') + offsets_right_non_optional = offsets_right + if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') + + if max_len_left is None: max_len_left_final = ops.max(offsets_left_non_optional[1:] - offsets_left_non_optional[:-1]) + else: max_len_left_final = max_len_left + if max_len_right is None: max_len_right_final = ops.max(offsets_right_non_optional[1:] - offsets_right_non_optional[:-1]) + else: max_len_right_final = max_len_right + + return keras_concat_2D_jagged_jagged(values_left=values_left, values_right=values_right, max_len_left=max_len_left_final, max_len_right=max_len_right_final, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) + + return 
pytorch_concat_2D_jagged_resolver(values_left=values_left, values_right=values_right, max_len_left=max_len_left, max_len_right=max_len_right, offsets_left=offsets_left, offsets_right=offsets_right) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py new file mode 100644 index 00000000..b9ee8c70 --- /dev/null +++ b/keras_rs/src/layers/stu.py @@ -0,0 +1,222 @@ +import abc +from typing import List, Optional, Tuple +import keras +from keras import ops +from keras import layers + +from keras_rs.src.layers.common import fx_unwrap_optional_tensor +from keras_rs.src.layers.hstu_compute_output import hstu_compute_uqvk, hstu_compute_output +from keras_rs.src.layers.hstu_preprocess_attention import keras_hstu_preprocess_and_attention +from keras_rs.src.layers.hstu_mha_attention import delta_hstu_mha +from keras_rs.src.layers.jagged_tensors import split_2D_jagged, concat_2D_jagged + + +class STULayerConfig: + def __init__(self, embedding_dim: int, num_heads: int, hidden_dim: int, attention_dim: int, + output_dropout_ratio: float = 0.3, causal: bool = True, target_aware: bool = True, + max_attn_len: Optional[int] = None, attn_alpha: Optional[float] = None, + use_group_norm: bool = False, recompute_normed_x: bool = True, + recompute_uvqk: bool = True, recompute_y: bool = True, + sort_by_length: bool = True, contextual_seq_len: int = 0): + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.attention_dim = attention_dim + self.output_dropout_ratio = output_dropout_ratio + self.causal = causal + self.target_aware = target_aware + self.max_attn_len = max_attn_len + self.attn_alpha = attn_alpha + self.use_group_norm = use_group_norm + self.recompute_normed_x = recompute_normed_x + self.recompute_uvqk = recompute_uvqk + self.recompute_y = recompute_y + self.sort_by_length = sort_by_length + self.contextual_seq_len = contextual_seq_len + + +def _update_kv_cache( + max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], orig_k_cache: Optional[keras.KerasTensor], orig_v_cache: Optional[keras.KerasTensor], orig_max_kv_caching_len: int, orig_kv_caching_offsets: Optional[keras.KerasTensor], +) -> Tuple[Optional[keras.KerasTensor], Optional[keras.KerasTensor], int, Optional[keras.KerasTensor]]: + + if kv_caching_lengths is not None: + # Keras equivalent of asynchronous_complete_cumsum + kv_caching_offsets = ops.cast(ops.cumsum(kv_caching_lengths, exclusive=True), dtype="int32") + delta_offsets = seq_offsets - kv_caching_offsets + + # NOTE: split_2D_jagged is available from jagged_tensors.py + k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(k), [-1, ops.shape(k)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(v), [-1, ops.shape(v)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + + if max_kv_caching_len == 0: + max_kv_caching_len = ops.convert_to_numpy(ops.cast(ops.max(kv_caching_lengths), dtype="int32")).item() + return (k_cache, v_cache, max_kv_caching_len, kv_caching_offsets) + else: + return (orig_k_cache, orig_v_cache, orig_max_kv_caching_len, orig_kv_caching_offsets) + + +def _construct_full_kv( + delta_k: keras.KerasTensor, delta_v: 
keras.KerasTensor, k_cache: keras.KerasTensor, v_cache: keras.KerasTensor, max_kv_caching_len: int, kv_caching_offsets: keras.KerasTensor, +) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: + L = ops.shape(delta_k)[0] + B = ops.shape(kv_caching_offsets)[0] - 1 + delta_size = L // B + + # NOTE: concat_2D_jagged is available from jagged_tensors.py + full_k = concat_2D_jagged(max_seq_len=max_kv_caching_len + delta_size, values_left=k_cache, values_right=delta_k, max_len_left=max_kv_caching_len, max_len_right=delta_size, offsets_left=kv_caching_offsets, offsets_right=None) + full_v = concat_2D_jagged(max_seq_len=max_kv_caching_len + delta_size, values_left=v_cache, values_right=delta_v, max_len_left=max_kv_caching_len, max_len_right=delta_size, offsets_left=kv_caching_offsets, offsets_right=None) + + # Calculate new combined offsets + delta_size_broadcast = delta_size * ops.arange(B + 1, dtype=kv_caching_offsets.dtype) + full_kv_caching_offsets = kv_caching_offsets + delta_size_broadcast + + return (full_k, full_v, max_kv_caching_len + delta_size, full_kv_caching_offsets) + + +class STU(layers.Layer, abc.ABC): + """Abstract base class for STU layers.""" + @abc.abstractmethod + def cached_forward(self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None,) -> keras.KerasTensor: pass + @abc.abstractmethod + def call(self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None,) -> keras.KerasTensor: pass + + +class STULayer(layers.Layer): + # Initialize cache properties on the instance + max_kv_caching_len: int = 0 + k_cache: Optional[keras.KerasTensor] = None + v_cache: Optional[keras.KerasTensor] = None + kv_caching_offsets: Optional[keras.KerasTensor] = None + + def __init__(self, config: STULayerConfig, is_inference: bool = False, **kwargs): + super().__init__(**kwargs) + self._config = config + self._num_heads: int = config.num_heads + self._embedding_dim: int = config.embedding_dim + self._hidden_dim: int = config.hidden_dim + self._attention_dim: int = config.attention_dim + self._output_dropout_ratio: float = config.output_dropout_ratio + self._target_aware: bool = config.target_aware + self._causal: bool = config.causal + self._max_attn_len: int = config.max_attn_len or 0 + self._attn_alpha: float = config.attn_alpha or 1.0 / (self._attention_dim**0.5) + self._use_group_norm: bool = config.use_group_norm + self._recompute_normed_x: bool = config.recompute_normed_x + self._recompute_uvqk: bool = config.recompute_uvqk + self._recompute_y: bool = config.recompute_y + self._sort_by_length: bool = config.sort_by_length + self._contextual_seq_len: int = config.contextual_seq_len + self.reset_kv_cache() + + def build(self, input_shape): + D_in = input_shape[-1] + H = self._num_heads; A = self._attention_dim; V = self._hidden_dim + output_dim_total = (V * 2 + A * 2) * H + self._uvqk_weight = self.add_weight(shape=(D_in, output_dim_total), initializer='glorot_uniform', name='uvqk_weight') + self._uvqk_beta = self.add_weight(shape=(output_dim_total,), initializer='zeros', name='uvqk_beta') + self._input_norm_weight = self.add_weight(shape=(D_in,), initializer='ones', name='input_norm_weight') + self._input_norm_bias = self.add_weight(shape=(D_in,), 
initializer='zeros', name='input_norm_bias') + + self._output_weight = self.add_weight(shape=(V * H, self._embedding_dim), initializer='glorot_uniform', name='output_weight') + + output_norm_shape: int = (V * H if not self._use_group_norm else H) + self._output_norm_weight = self.add_weight(shape=(output_norm_shape,), initializer='ones', name='output_norm_weight') + self._output_norm_bias = self.add_weight(shape=(output_norm_shape,), initializer='zeros', name='output_norm_bias') + self.built = True + + def reset_kv_cache(self) -> None: + self.k_cache = None; self.v_cache = None + self.kv_caching_offsets = None; self.max_kv_caching_len = 0 + + def update_kv_cache( + self, max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], + ) -> None: + # NOTE: Assumes _update_kv_cache is available + self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = (_update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets,)) + + def construct_full_kv(self, delta_k: keras.KerasTensor, delta_v: keras.KerasTensor,) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: + # NOTE: Assumes _construct_full_kv is available + return _construct_full_kv(delta_k=delta_k, delta_v=delta_v, k_cache=fx_unwrap_optional_tensor(self.k_cache), v_cache=fx_unwrap_optional_tensor(self.v_cache), max_kv_caching_len=self.max_kv_caching_len, kv_caching_offsets=fx_unwrap_optional_tensor(self.kv_caching_offsets),) + + def call( # Standard Keras forward method + self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + + u, attn_output, k, v = keras_hstu_preprocess_and_attention( + x=x, norm_weight=self._input_norm_weight, norm_bias=self._input_norm_bias, norm_eps=1e-6, + num_heads=self._num_heads, attn_dim=self._attention_dim, hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight, uvqk_bias=self._uvqk_beta, + max_seq_len=max_seq_len, seq_offsets=x_offsets, attn_alpha=self._attn_alpha, + causal=self._causal, num_targets=num_targets if self._target_aware else None, + max_attn_len=self._max_attn_len, contextual_seq_len=self._contextual_seq_len, + recompute_uvqk_in_backward=self._recompute_uvqk, recompute_normed_x_in_backward=self._recompute_normed_x, + sort_by_length=self._sort_by_length, prefill=kv_caching_lengths is not None, + ) + + self.update_kv_cache(max_seq_len=max_seq_len, seq_offsets=x_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths) + + return hstu_compute_output( + attn=attn_output, u=u, x=x, norm_weight=self._output_norm_weight, norm_bias=self._output_norm_bias, + norm_eps=1e-6, dropout_ratio=self._output_dropout_ratio, output_weight=self._output_weight, + group_norm=self._use_group_norm, num_heads=self._num_heads, linear_dim=self._hidden_dim, + concat_ux=True, training=training, recompute_y_in_backward=self._recompute_y, + ) + + def cached_forward( # Called for token-by-token generation + self, delta_x: keras.KerasTensor, num_targets: 
keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + + delta_u, delta_q, delta_k, delta_v = hstu_compute_uqvk( + x=delta_x, norm_weight=self._input_norm_weight, norm_bias=self._input_norm_bias, norm_eps=1e-6, + num_heads=self._num_heads, attn_dim=self._attention_dim, hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight, uvqk_bias=self._uvqk_beta, + ) + + A = self._attention_dim; V = self._hidden_dim; H = self._num_heads + k_flat = ops.reshape(delta_k, [-1, H * A]) + v_flat = ops.reshape(delta_v, [-1, H * V]) + + k_full, v_full, max_seq_len, seq_offsets = self.construct_full_kv(delta_k=k_flat, delta_v=v_flat) + + self.update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k_full, v=v_full, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths) + + # Reshape K and V back to [L_full, H, D] for attention calculation + k = ops.reshape(k_full, [-1, H, A]) + v = ops.reshape(v_full, [-1, H, V]) + + + delta_attn_output = delta_hstu_mha( + max_seq_len=max_seq_len, alpha=self._attn_alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets, + num_targets=num_targets if self._target_aware else None, max_attn_len=self._max_attn_len, + contextual_seq_len=self._contextual_seq_len, + ) + + delta_attn_output = ops.reshape(delta_attn_output, [-1, V * H]) + + + return hstu_compute_output( + attn=delta_attn_output, u=delta_u, x=delta_x, norm_weight=self._output_norm_weight, norm_bias=self._output_norm_bias, + norm_eps=1e-6, dropout_ratio=self._output_dropout_ratio, output_weight=self._output_weight, + group_norm=self._use_group_norm, num_heads=self._num_heads, linear_dim=self._hidden_dim, + concat_ux=True, training=training, recompute_y_in_backward=self._recompute_y, + ) + + +class STUStack(layers.Layer): + def __init__(self, stu_layers: List[STULayer], is_inference: bool = False, **kwargs): + super().__init__(**kwargs) + self._stu_layers = stu_layers + + def call( + self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + for layer in self._stu_layers: + x = layer(x=x, x_lengths=x_lengths, x_offsets=x_offsets, max_seq_len=max_seq_len, num_targets=num_targets, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, training=training) + return x + + def cached_forward( + self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + for layer in self._stu_layers: + delta_x = layer.cached_forward(delta_x=delta_x, num_targets=num_targets, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, training=training) + return delta_x From 1906f2d0f2910ff2008c77992957ff4ac1f06f93 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:23:48 +0530 Subject: [PATCH 02/19] Update keras_rs/src/layers/stu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/stu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py index b9ee8c70..a3b46fef 
100644 --- a/keras_rs/src/layers/stu.py +++ b/keras_rs/src/layers/stu.py @@ -132,7 +132,7 @@ def update_kv_cache( self, max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], ) -> None: # NOTE: Assumes _update_kv_cache is available - self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = (_update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets,)) + self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = _update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets) def construct_full_kv(self, delta_k: keras.KerasTensor, delta_v: keras.KerasTensor,) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: # NOTE: Assumes _construct_full_kv is available From d0add4fdb5ef64372706dc8ba872367931c8bfef Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:24:17 +0530 Subject: [PATCH 03/19] Update keras_rs/src/layers/common.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/common.py b/keras_rs/src/layers/common.py index 4e5b9f5d..1ebea2c3 100644 --- a/keras_rs/src/layers/common.py +++ b/keras_rs/src/layers/common.py @@ -17,7 +17,7 @@ def get_valid_attn_mask_keras( max_attn_len: int = 0, contextual_seq_len: int = 0, min_full_attn_seq_len: int = 0, -): +) -> keras.KerasTensor: """ Keras implementation of the valid attention mask generation, combining causality, sequence lengths, and target awareness. 
From eb98ae178105aa72861af6e2f0094e77c3515b6d Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:24:34 +0530 Subject: [PATCH 04/19] Update keras_rs/src/layers/jagged_tensors.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/jagged_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py index 87e96e68..84e90aa8 100644 --- a/keras_rs/src/layers/jagged_tensors.py +++ b/keras_rs/src/layers/jagged_tensors.py @@ -95,7 +95,7 @@ def keras_concat_2D_jagged_jagged(values_left, values_right, max_len_left, max_l mask = ops.logical_or(mask < lengths_left_broadcast, ops.logical_and(mask >= max_len_left, mask < max_len_left + lengths_right_broadcast)) return concatted_dense[ops.reshape(mask, [-1])] - def pytorch_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): + def keras_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): L_total = ops.shape(values_left)[0] offsets_left_non_optional = offsets_left if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') From 0347b51ae8531934518fa315943f0001985985e6 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:54:52 +0530 Subject: [PATCH 05/19] Update keras_rs/src/layers/hstu_mha_attention.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/hstu_mha_attention.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/keras_rs/src/layers/hstu_mha_attention.py b/keras_rs/src/layers/hstu_mha_attention.py index d7b146dc..0986e402 100644 --- a/keras_rs/src/layers/hstu_mha_attention.py +++ b/keras_rs/src/layers/hstu_mha_attention.py @@ -14,14 +14,16 @@ def keras_pad_qkv( Assumes keras_jagged_to_padded_dense is available globally. 
""" L, H, D = ops.shape(q); V_dim = ops.shape(v)[2] - values_q_k = ops.reshape(q, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) + values_q = ops.reshape(q, [L, H * D]); values_k = ops.reshape(k, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) # Pad Q, K, V - padded_q_k = keras_jagged_to_padded_dense(values=values_q_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_q = keras_jagged_to_padded_dense(values=values_q, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_k = keras_jagged_to_padded_dense(values=values_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - B = ops.shape(padded_q_k)[0]; padded_q_k = ops.reshape(padded_q_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) - padded_q = ops.transpose(padded_q_k, [0, 2, 1, 3]); padded_k = ops.transpose(padded_q_k, [0, 2, 1, 3]) + B = ops.shape(padded_q)[0] + padded_q = ops.reshape(padded_q, [B, N, H, D]); padded_k = ops.reshape(padded_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) + padded_q = ops.transpose(padded_q, [0, 2, 1, 3]); padded_k = ops.transpose(padded_k, [0, 2, 1, 3]) padded_v = ops.transpose(padded_v, [0, 2, 1, 3]) return padded_q, padded_k, padded_v From 570df2b35ea162a77d014911f1f5382942edc042 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:55:35 +0530 Subject: [PATCH 06/19] Update keras_rs/src/layers/stu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/stu.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py index a3b46fef..79f6f52e 100644 --- a/keras_rs/src/layers/stu.py +++ b/keras_rs/src/layers/stu.py @@ -45,8 +45,16 @@ def _update_kv_cache( delta_offsets = seq_offsets - kv_caching_offsets # NOTE: split_2D_jagged is available from jagged_tensors.py - k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(k), [-1, ops.shape(k)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) - v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(v), [-1, ops.shape(v)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + if k is not None: + k_values = ops.reshape(k, [ops.shape(k)[0], -1]) + k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=k_values, max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + else: + k_cache = fx_unwrap_optional_tensor(k) + if v is not None: + v_values = ops.reshape(v, [ops.shape(v)[0], -1]) + v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=v_values, max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + else: + v_cache = fx_unwrap_optional_tensor(v) if max_kv_caching_len == 0: max_kv_caching_len = ops.convert_to_numpy(ops.cast(ops.max(kv_caching_lengths), dtype="int32")).item() From 6f3dcef1938fc9b19c8be1c96d3a83f437e11bfd Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali 
<149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:56:04 +0530 Subject: [PATCH 07/19] Update keras_rs/src/layers/stu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/stu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py index 79f6f52e..f395004c 100644 --- a/keras_rs/src/layers/stu.py +++ b/keras_rs/src/layers/stu.py @@ -5,7 +5,8 @@ from keras import layers from keras_rs.src.layers.common import fx_unwrap_optional_tensor -from keras_rs.src.layers.hstu_compute_output import hstu_compute_uqvk, hstu_compute_output +from keras_rs.src.layers.hstu_compute_output import hstu_compute_output +from keras_rs.src.layers.hstu_uqvk_output import hstu_compute_uqvk from keras_rs.src.layers.hstu_preprocess_attention import keras_hstu_preprocess_and_attention from keras_rs.src.layers.hstu_mha_attention import delta_hstu_mha from keras_rs.src.layers.jagged_tensors import split_2D_jagged, concat_2D_jagged From c80e598d9831e8e0bf418433350459718b2da652 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:56:53 +0530 Subject: [PATCH 08/19] Update keras_rs/src/layers/jagged_tensors.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/jagged_tensors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py index 84e90aa8..23551877 100644 --- a/keras_rs/src/layers/jagged_tensors.py +++ b/keras_rs/src/layers/jagged_tensors.py @@ -70,7 +70,10 @@ def keras_split_2D_jagged_jagged(max_seq_len, values, offsets_left, offsets_righ def keras_split_2D_jagged_resolver(max_seq_len, values, max_len_left, max_len_right, offsets_left, offsets_right): L_total = ops.shape(values)[0] offsets_left_non_optional = offsets_left - if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') + if offsets_left is None: + if max_len_left is None: + raise ValueError("Either offsets_left or max_len_left must be provided.") + offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') offsets_right_non_optional = offsets_right if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') return keras_split_2D_jagged_jagged(max_seq_len=max_seq_len, values=values, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) From 8073dcbd36847dedda7dffa86566236b8d179548 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:57:54 +0530 Subject: [PATCH 09/19] Update keras_rs/src/layers/hstu_compute_output.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/hstu_compute_output.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_rs/src/layers/hstu_compute_output.py b/keras_rs/src/layers/hstu_compute_output.py index 9d35959c..d40107c4 100644 --- a/keras_rs/src/layers/hstu_compute_output.py +++ b/keras_rs/src/layers/hstu_compute_output.py @@ -31,9 +31,7 @@ def keras_norm_mul_dropout( raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.") 
else: # Functional Layer Normalization (Simulated keras_layer_norm) - mean = ops.mean(x, axis=-1, keepdims=True) - variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) - x_norm = (x - mean) / ops.sqrt(variance + eps) + x_norm = ops.layer_norm(x, axis=-1, epsilon=eps) # Apply weight and bias (Gamma * x_norm + Beta) y_norm = x_norm * weight + bias From 8cb2d981dd2051ab7edd2d6a6e42682704dfc199 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Fri, 24 Oct 2025 13:40:51 +0530 Subject: [PATCH 10/19] Add list mle loss --- keras_rs/api/losses/__init__.py | 1 + keras_rs/src/layers/common.py | 62 ----- keras_rs/src/layers/hstu_compute_output.py | 123 ---------- keras_rs/src/layers/hstu_mha_attention.py | 111 --------- .../src/layers/hstu_preprocess_attention.py | 44 ---- keras_rs/src/layers/hstu_uqvk_output.py | 81 ------ keras_rs/src/layers/jagged_tensors.py | 115 --------- keras_rs/src/layers/stu.py | 231 ------------------ keras_rs/src/losses/list_mle_loss.py | 217 ++++++++++++++++ keras_rs/src/losses/list_mle_loss_test.py | 99 ++++++++ 10 files changed, 317 insertions(+), 767 deletions(-) delete mode 100644 keras_rs/src/layers/common.py delete mode 100644 keras_rs/src/layers/hstu_compute_output.py delete mode 100644 keras_rs/src/layers/hstu_mha_attention.py delete mode 100644 keras_rs/src/layers/hstu_preprocess_attention.py delete mode 100644 keras_rs/src/layers/hstu_uqvk_output.py delete mode 100644 keras_rs/src/layers/jagged_tensors.py delete mode 100644 keras_rs/src/layers/stu.py create mode 100644 keras_rs/src/losses/list_mle_loss.py create mode 100644 keras_rs/src/losses/list_mle_loss_test.py diff --git a/keras_rs/api/losses/__init__.py b/keras_rs/api/losses/__init__.py index 152b4496..f9110d58 100644 --- a/keras_rs/api/losses/__init__.py +++ b/keras_rs/api/losses/__init__.py @@ -4,6 +4,7 @@ since your modifications would be overwritten. """ +from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss from keras_rs.src.losses.pairwise_hinge_loss import ( PairwiseHingeLoss as PairwiseHingeLoss, ) diff --git a/keras_rs/src/layers/common.py b/keras_rs/src/layers/common.py deleted file mode 100644 index 1ebea2c3..00000000 --- a/keras_rs/src/layers/common.py +++ /dev/null @@ -1,62 +0,0 @@ -import keras -from keras import ops -from typing import List, Optional, Tuple - -def fx_unwrap_optional_tensor(optional: Optional[keras.KerasTensor]) -> keras.KerasTensor: - """Helper to unwrap optional tensors, returning a zero-tensor for uninitialized cache.""" - if optional is None: - # Returning a zero-tensor is necessary for graph tracing when the cache is uninitialized. - return ops.zeros((0,), dtype='float32') - return optional - -def get_valid_attn_mask_keras( - causal: bool, - N: int, - seq_lengths: keras.KerasTensor, - num_targets: Optional[keras.KerasTensor] = None, - max_attn_len: int = 0, - contextual_seq_len: int = 0, - min_full_attn_seq_len: int = 0, -) -> keras.KerasTensor: - """ - Keras implementation of the valid attention mask generation, combining - causality, sequence lengths, and target awareness. 
- """ - ids = ops.reshape(ops.arange(0, N, dtype="int32"), (1, N)) - max_ids = ops.reshape(seq_lengths, (-1, 1, 1)) - B = ops.shape(seq_lengths)[0] - - if contextual_seq_len > 0: - ids = ids - contextual_seq_len + 1 - ids = ops.maximum(ids, 0) - max_ids = max_ids - contextual_seq_len + 1 - - if num_targets is not None: - max_ids = max_ids - ops.reshape(num_targets, (-1, 1, 1)) - ids = ops.minimum(ids, max_ids) - row_ids = ops.broadcast_to(ops.reshape(ids, (-1, N, 1)), (B, N, N)) - col_ids = ops.broadcast_to(ops.reshape(ids, (-1, 1, N)), (B, N, N)) - else: - row_ids = ops.broadcast_to(ops.reshape(ids, (N, 1)), (N, N)) - col_ids = ops.transpose(row_ids) - row_ids = ops.reshape(row_ids, (1, N, N)) - col_ids = ops.reshape(col_ids, (1, N, N)) - max_ids = None - - row_col_dist = row_ids - col_ids - valid_attn_mask = ops.reshape(ops.eye(N, dtype="bool"), (1, N, N)) - - if not causal: - row_col_dist = ops.where(row_col_dist > 0, row_col_dist, -row_col_dist) - - valid_attn_mask = ops.logical_or(valid_attn_mask, row_col_dist > 0) - - if max_attn_len > 0: - valid_attn_mask = ops.logical_and(valid_attn_mask, row_col_dist <= max_attn_len) - - if contextual_seq_len > 0 and max_ids is not None: - valid_attn_mask = ops.logical_or( - valid_attn_mask, ops.logical_and(row_ids == 0, col_ids < max_ids) - ) - - return valid_attn_mask diff --git a/keras_rs/src/layers/hstu_compute_output.py b/keras_rs/src/layers/hstu_compute_output.py deleted file mode 100644 index d40107c4..00000000 --- a/keras_rs/src/layers/hstu_compute_output.py +++ /dev/null @@ -1,123 +0,0 @@ -import keras -from keras import ops -from typing import List, Optional, Tuple - -def keras_norm_mul_dropout( - x: keras.KerasTensor, - u: keras.KerasTensor, - weight: keras.KerasTensor, - bias: keras.KerasTensor, - eps: float, - dropout_ratio: float, - training: bool, - silu_u: bool = False, - concat_ux: bool = False, - group_norm: bool = False, - num_heads: int = 1, - linear_dim: int = -1, -) -> keras.KerasTensor: - """ - Keras 3 equivalent of pytorch_norm_mul_dropout. - Applies normalization, element-wise multiplication with u, and dropout. - Assumes keras_layer_norm is available (though the logic is inlined here). - """ - x = ops.convert_to_tensor(x, dtype='float32') - u = ops.convert_to_tensor(u, dtype='float32') - - if silu_u: - u = ops.silu(u) - - if group_norm: - raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.") - else: - # Functional Layer Normalization (Simulated keras_layer_norm) - x_norm = ops.layer_norm(x, axis=-1, epsilon=eps) - - # Apply weight and bias (Gamma * x_norm + Beta) - y_norm = x_norm * weight + bias - - # Apply u multiplication (Element-wise gating) - y = u * y_norm - - if concat_ux: - y = ops.concatenate([u, x, y], axis=1) - - # Dropout (using Keras layer for correct training=True/False behavior) - y = keras.layers.Dropout(dropout_ratio)(y, training=training) - - return ops.cast(y, dtype=x.dtype) - -def keras_hstu_compute_output( - attn: keras.KerasTensor, - u: keras.KerasTensor, - x: keras.KerasTensor, - norm_weight: keras.KerasTensor, - norm_bias: keras.KerasTensor, - output_weight: keras.KerasTensor, - eps: float, - dropout_ratio: float, - training: bool, - silu_u: bool = False, - concat_ux: bool = False, - group_norm: bool = False, - num_heads: int = 1, - linear_dim: int = -1, -) -> keras.KerasTensor: - """ - Core kernel for the final residual block calculation (Attn Output -> Norm/Dropout -> MatMul -> Residual Add). 
- """ - y = keras_norm_mul_dropout( - x=attn, - u=u, - weight=norm_weight, - bias=norm_bias, - eps=eps, - dropout_ratio=dropout_ratio, - training=training, - silu_u=silu_u, - concat_ux=concat_ux, - group_norm=group_norm, - num_heads=num_heads, - linear_dim=linear_dim, - ) - - # Final output: Residual addition of input (x) and transformed attention output (y @ output_weight) - output = ops.add(x, ops.matmul(y, output_weight)) - - return output - -def hstu_compute_output( - attn: keras.KerasTensor, - u: keras.KerasTensor, - x: keras.KerasTensor, - norm_weight: keras.KerasTensor, - norm_bias: keras.KerasTensor, - norm_eps: float, - output_weight: keras.KerasTensor, - num_heads: int, - linear_dim: int, - dropout_ratio: float, - training: bool, - concat_ux: bool, - group_norm: bool, - recompute_y_in_backward: bool, -) -> keras.KerasTensor: - """ - Top-level wrapper for the output computation, delegates to the core Keras kernel. - """ - return keras_hstu_compute_output( - attn=attn, - u=u, - x=x, - norm_weight=norm_weight, - norm_bias=norm_bias, - output_weight=output_weight, - eps=norm_eps, - dropout_ratio=dropout_ratio, - training=training, - silu_u=False, - concat_ux=concat_ux, - group_norm=group_norm, - num_heads=num_heads, - linear_dim=linear_dim, - ) diff --git a/keras_rs/src/layers/hstu_mha_attention.py b/keras_rs/src/layers/hstu_mha_attention.py deleted file mode 100644 index 0986e402..00000000 --- a/keras_rs/src/layers/hstu_mha_attention.py +++ /dev/null @@ -1,111 +0,0 @@ -import keras -from keras import ops -from typing import Tuple, Optional -from keras import layers - -# --- Assumed Imports --- -# Assumes keras_jagged_to_padded_dense, keras_dense_to_jagged, and get_valid_attn_mask_keras are available from other modules. - -def keras_pad_qkv( - q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, N: int, -) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]: - """ - Helper to pad Q, K, V from jagged to dense format for MHA. - Assumes keras_jagged_to_padded_dense is available globally. 
- """ - L, H, D = ops.shape(q); V_dim = ops.shape(v)[2] - values_q = ops.reshape(q, [L, H * D]); values_k = ops.reshape(k, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) - - # Pad Q, K, V - padded_q = keras_jagged_to_padded_dense(values=values_q, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - padded_k = keras_jagged_to_padded_dense(values=values_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - - B = ops.shape(padded_q)[0] - padded_q = ops.reshape(padded_q, [B, N, H, D]); padded_k = ops.reshape(padded_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) - padded_q = ops.transpose(padded_q, [0, 2, 1, 3]); padded_k = ops.transpose(padded_k, [0, 2, 1, 3]) - padded_v = ops.transpose(padded_v, [0, 2, 1, 3]) - return padded_q, padded_k, padded_v - - -def keras_hstu_mha( - max_seq_len: int, alpha: float, q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, causal: bool = True, dropout_pr: float = 0.0, training: bool = True, attn_scale: Optional[keras.KerasTensor] = None, **kwargs -) -> keras.KerasTensor: - """Core Keras implementation of the full Multi-Head Attention kernel (Non-Cached).""" - L, H, _ = ops.shape(q); V_dim = ops.shape(v)[2] - q, k, v = keras_pad_qkv(q, k, v, seq_offsets, max_seq_len) - qk_attn = ops.einsum("bhxa,bhya->bhxy", q, k) * alpha - - # Activation and Scaling - if attn_scale is not None: - if ops.ndim(attn_scale) > 0: - attn_scale_padded = keras_jagged_to_padded_dense(values=ops.expand_dims(attn_scale, axis=-1), offsets=[seq_offsets], max_lengths=[max_seq_len], padding_value=0.0) - attn_scale_padded = ops.expand_dims(ops.cast(attn_scale_padded, qk_attn.dtype), axis=1) - qk_attn = ops.silu(qk_attn) * attn_scale_padded - else: - qk_attn = ops.silu(qk_attn) / max_seq_len - - # Masking - seq_lengths = seq_offsets[1:] - seq_offsets[:-1] - valid_attn_mask = get_valid_attn_mask_keras(causal=causal, N=max_seq_len, seq_lengths=seq_lengths, **kwargs) - qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask, qk_attn.dtype), axis=1) - - # Dropout - if dropout_pr > 0.0 and training: - qk_attn = keras.layers.Dropout(dropout_pr)(qk_attn, training=training) - - # Output (Weighted Sum) - attn_dense = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v) - flat_attn_dense = ops.reshape(ops.transpose(attn_dense, [0, 2, 1, 3]), [-1, max_seq_len, H * V_dim]) - - # Convert back to jagged - jagged_output = keras_dense_to_jagged(flat_attn_dense, [seq_offsets]) - L_out = ops.shape(jagged_output)[0] - return ops.reshape(jagged_output, [L_out, H, V_dim]) - - -def keras_cached_hstu_mha( - max_seq_len: int, alpha: float, delta_q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, num_targets: Optional[keras.KerasTensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, enable_tma: bool = False, -) -> keras.KerasTensor: - """Core Keras implementation of the cached attention kernel (Delta Q attends to Full K/V).""" - L_delta, H, D = ops.shape(delta_q); B = ops.shape(seq_offsets)[0] - 1; DeltaSize = L_delta // B; V_dim = ops.shape(v)[2] - - # 1. Reshape Delta Q - delta_q = ops.transpose(ops.reshape(delta_q, (B, DeltaSize, H, D)), perm=[0, 2, 1, 3]) - - # 2. 
Reshape Full K and V (Inputs k, v are already flattened/jagged-like) - N_full = max_seq_len - k_full = ops.transpose(ops.reshape(k, (B, N_full, H, D)), [0, 2, 1, 3]) - v_full = ops.transpose(ops.reshape(v, (B, N_full, H, V_dim)), [0, 2, 1, 3]) - - # 3. Attention Score and Activation - qk_attn = ops.einsum("bhxa,bhya->bhxy", delta_q, k_full) * alpha - qk_attn = ops.silu(qk_attn) / max_seq_len - - # 4. Masking (Slice the mask to select only the rows corresponding to the new queries) - seq_lengths = seq_offsets[1:] - seq_offsets[:-1] - full_valid_attn_mask = get_valid_attn_mask_keras(causal=True, N=max_seq_len, seq_lengths=seq_lengths, num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len) - valid_attn_mask_sliced = full_valid_attn_mask[:, -DeltaSize:, :] - - qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask_sliced, qk_attn.dtype), axis=1) - - # 5. Output (Weighted Sum) - attn_output = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v_full) - - # 6. Reshape and return [L_delta, H, V_dim] - attn_output = ops.transpose(attn_output, perm=[0, 2, 1, 3]) - return ops.reshape(attn_output, (-1, H, V_dim)) - - -def delta_hstu_mha( - max_seq_len: int, alpha: float, delta_q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, num_targets: Optional[keras.KerasTensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, kernel=None, enable_tma: bool = False, -) -> keras.KerasTensor: - """Top-level wrapper for cached inference MHA (delegates to core cached kernel).""" - - L_delta, H, D = ops.shape(delta_q) - # Assertions are maintained by the layer/framework where possible. - - return keras_cached_hstu_mha( - max_seq_len=max_seq_len, alpha=alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets, - num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len, - ) diff --git a/keras_rs/src/layers/hstu_preprocess_attention.py b/keras_rs/src/layers/hstu_preprocess_attention.py deleted file mode 100644 index 04f70d34..00000000 --- a/keras_rs/src/layers/hstu_preprocess_attention.py +++ /dev/null @@ -1,44 +0,0 @@ -import keras -from keras import ops -from typing import Tuple, List, Optional - - -def keras_hstu_preprocess_and_attention( - x: keras.KerasTensor, norm_weight: keras.KerasTensor, norm_bias: keras.KerasTensor, norm_eps: float, num_heads: int, attn_dim: int, hidden_dim: int, - uvqk_weight: keras.KerasTensor, uvqk_bias: keras.KerasTensor, max_seq_len: int, seq_offsets: keras.KerasTensor, attn_alpha: float, causal: bool, - num_targets: Optional[keras.KerasTensor], max_attn_len: int, contextual_seq_len: int, recompute_uvqk_in_backward: bool, - recompute_normed_x_in_backward: bool, sort_by_length: bool, prefill: bool = False, - kernel=None, **kwargs -) -> Tuple: - """ - Keras 3 implementation of the H-STU preprocess and attention workflow. - Orchestrates the conversion of input X into U, Q, K, V and subsequent MHA computation. - """ - - # --- Assertions (Skipped internal torch asserts, simplified to Keras asserts for context) --- - assert max_seq_len > 0, "max_seq_len must be larger than 0" - assert ops.ndim(x) == 2, "x must be 2-D" - assert causal is True, "only causal attention is supported." - - # 1. Compute U, Q, K, V - # Note: hstu_compute_uqvk handles the initial Norm, Linear Projection, and Split. 
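    #     (For reference, from hstu_compute_uqvk and STULayer.build elsewhere in this
    #     series: the fused projection width is (2*hidden_dim + 2*attn_dim) * num_heads,
    #     sliced into u, v, q, k in that order before the per-head reshape.)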
- u, q, k, v = hstu_compute_uqvk( - x=x, norm_weight=norm_weight, norm_bias=norm_bias, norm_eps=norm_eps, - num_heads=num_heads, attn_dim=attn_dim, hidden_dim=hidden_dim, - uvqk_weight=uvqk_weight, uvqk_bias=uvqk_bias, kernel=kernel, - ) - - # 2. Compute Attention - attn_output = keras_hstu_mha( - max_seq_len=max_seq_len, alpha=attn_alpha, q=q, k=k, v=v, - seq_offsets=seq_offsets, causal=causal, dropout_pr=0.0, - training=False, num_targets=num_targets, max_attn_len=max_attn_len, - contextual_seq_len=contextual_seq_len, sort_by_length=sort_by_length, - kernel=kernel, **kwargs - ) - - # Reshape: [L, H, D] -> [L, H * D] (Flattening for the final hstu_compute_output block) - attn_output = ops.reshape(attn_output, [-1, hidden_dim * num_heads]) - - # Returns u (gating), attention output, k, and v (for caching) - return u, attn_output, k, v diff --git a/keras_rs/src/layers/hstu_uqvk_output.py b/keras_rs/src/layers/hstu_uqvk_output.py deleted file mode 100644 index 47b5a1aa..00000000 --- a/keras_rs/src/layers/hstu_uqvk_output.py +++ /dev/null @@ -1,81 +0,0 @@ -import keras -from keras import ops -from typing import List, Optional, Tuple - -def keras_layer_norm( - x: keras.KerasTensor, - weight: keras.KerasTensor, - bias: keras.KerasTensor, - eps: float, -) -> keras.KerasTensor: - """ - Keras 3 functional Layer Normalization implementation. - Simulates F.layer_norm where scale/bias is applied externally. - """ - # 1. Normalize x - mean = ops.mean(x, axis=-1, keepdims=True) - variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) - x_norm = (x - mean) / ops.sqrt(variance + eps) - - # 2. Apply weight and bias (Gamma * x_norm + Beta) - return x_norm * weight + bias - -def keras_addmm( - bias: keras.KerasTensor, - input: keras.KerasTensor, - mat2: keras.KerasTensor, -) -> keras.KerasTensor: - """Keras 3 equivalent of torch.addmm (bias + input @ mat2).""" - return ops.add(bias, ops.matmul(input, mat2)) - -def hstu_compute_uqvk( - x: keras.KerasTensor, - norm_weight: keras.KerasTensor, - norm_bias: keras.KerasTensor, - norm_eps: float, - num_heads: int, - attn_dim: int, - hidden_dim: int, - uvqk_weight: keras.KerasTensor, - uvqk_bias: keras.KerasTensor, - kernel=None, -) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]: - """ - Computes the transformed tensors U, V, Q, and K from the input X. - """ - - # 1. Normalization - normed_x = keras_layer_norm( - x, - weight=norm_weight, - bias=norm_bias, - eps=norm_eps, - ) - - # 2. Combined Linear Projection (uvqk = bias + normed_x @ uvqk_weight) - uvqk = keras_addmm(uvqk_bias, normed_x, uvqk_weight) - - # 3. Calculate split sizes and slice - u_size = hidden_dim * num_heads - v_size = hidden_dim * num_heads - q_size = attn_dim * num_heads - k_size = attn_dim * num_heads - - start_u = 0 - start_v = u_size - start_q = u_size + v_size - start_k = u_size + v_size + q_size - L_out = ops.shape(uvqk)[0] - - u = ops.slice(uvqk, start_indices=[0, start_u], shape=[L_out, u_size]) - v = ops.slice(uvqk, start_indices=[0, start_v], shape=[L_out, v_size]) - q = ops.slice(uvqk, start_indices=[0, start_q], shape=[L_out, q_size]) - k = ops.slice(uvqk, start_indices=[0, start_k], shape=[L_out, k_size]) - - # 4. 
Activation and Reshape - u = ops.silu(u) - q = ops.reshape(q, [-1, num_heads, attn_dim]) - k = ops.reshape(k, [-1, num_heads, attn_dim]) - v = ops.reshape(v, [-1, num_heads, hidden_dim]) - - return u, q, k, v diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py deleted file mode 100644 index 23551877..00000000 --- a/keras_rs/src/layers/jagged_tensors.py +++ /dev/null @@ -1,115 +0,0 @@ -import keras -from keras import ops -from typing import List, Optional, Tuple - -# --- Core Jagged/Dense Conversion Functions --- - -def keras_jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0.0): - """ - Keras 3 implementation to convert jagged tensor (values) into a padded dense tensor [B, N, D_flat]. - Required by MHA kernel padding (keras_pad_qkv). - """ - offsets = offsets[0] if isinstance(offsets, list) else offsets - B = ops.shape(offsets)[0] - 1 - max_len = max_lengths[0] - D_flat = ops.shape(values)[-1] - if ops.shape(values)[0] == 0: - return ops.full((B, max_len, D_flat), padding_value, dtype=values.dtype) - - def pad_one(i): - start = offsets[i]; end = offsets[i+1] - seq_len = end - start - seq = ops.slice(values, [start, 0], [seq_len, D_flat]) - if ops.equal(seq_len, 0): - return ops.full((max_len, D_flat), padding_value, dtype=values.dtype) - if seq_len < max_len: - padding_shape = ops.stack([max_len - seq_len, D_flat]) - padding = ops.full(padding_shape, padding_value, dtype=values.dtype) - return ops.concatenate([seq, padding], axis=0) - else: - return seq[:max_len] - - idxs = ops.arange(B, dtype='int32') - return ops.map(pad_one, idxs) - -def keras_dense_to_jagged( - dense: keras.KerasTensor, - x_offsets: List[keras.KerasTensor], -) -> keras.KerasTensor: - """Keras 3 implementation to convert a padded dense tensor [B, N, D] back into a jagged tensor.""" - seq_offsets = x_offsets[0] - N = ops.shape(dense)[1] - D_flat = ops.shape(dense)[2] - token_range = ops.arange(N) - seq_lengths = seq_offsets[1:] - seq_offsets[:-1] - mask = ops.expand_dims(token_range, axis=0) < ops.expand_dims(seq_lengths, axis=1) - - flattened = ops.reshape(dense, [-1, D_flat]) - flattened_mask = ops.reshape(mask, [-1]) - - return flattened[flattened_mask] - -# --- Jagged Splitting and Concatenation Wrappers (Used by Caching Logic) --- - -def split_2D_jagged( - max_seq_len: int, values: keras.KerasTensor, total_len_left: Optional[int] = None, total_len_right: Optional[int] = None, max_len_left: Optional[int] = None, max_len_right: Optional[int] = None, offsets_left: Optional[keras.KerasTensor] = None, offsets_right: Optional[keras.KerasTensor] = None, kernel=None, -) -> Tuple[keras.KerasTensor, keras.KerasTensor]: - """Top-level wrapper for splitting a concatenated jagged tensor.""" - - def keras_split_2D_jagged_jagged(max_seq_len, values, offsets_left, offsets_right): - D_flat = ops.shape(values)[1]; offsets = offsets_left + offsets_right - padded_values_bnd = keras_jagged_to_padded_dense(values=values, offsets=[offsets], max_lengths=[max_seq_len], padding_value=0.0) - padded_values = ops.reshape(padded_values_bnd, [-1, D_flat]) - lengths_left = offsets_left[1:] - offsets_left[:-1]; lengths_right = offsets_right[1:] - offsets_right[:-1] - mask = ops.reshape(ops.arange(max_seq_len, dtype='int32'), [1, -1]) - lengths_left_broadcast = ops.reshape(lengths_left, [-1, 1]); lengths_right_combined = ops.reshape(lengths_left + lengths_right, [-1, 1]) - mask_left = mask < lengths_left_broadcast - mask_right = ops.logical_and(mask >= lengths_left_broadcast, mask < 
lengths_right_combined) - return padded_values[ops.reshape(mask_left, [-1])], padded_values[ops.reshape(mask_right, [-1])] - - def keras_split_2D_jagged_resolver(max_seq_len, values, max_len_left, max_len_right, offsets_left, offsets_right): - L_total = ops.shape(values)[0] - offsets_left_non_optional = offsets_left - if offsets_left is None: - if max_len_left is None: - raise ValueError("Either offsets_left or max_len_left must be provided.") - offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') - offsets_right_non_optional = offsets_right - if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') - return keras_split_2D_jagged_jagged(max_seq_len=max_seq_len, values=values, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) - - return keras_split_2D_jagged_resolver(max_seq_len=max_seq_len, values=values, max_len_left=max_len_left, max_len_right=max_len_right, offsets_left=offsets_left, offsets_right=offsets_right) - - -def concat_2D_jagged( - max_seq_len: int, values_left: keras.KerasTensor, values_right: keras.KerasTensor, max_len_left: Optional[int] = None, max_len_right: Optional[int] = None, offsets_left: Optional[keras.KerasTensor] = None, offsets_right: Optional[keras.KerasTensor] = None, kernel=None, -) -> keras.KerasTensor: - """Top-level wrapper for concatenating 2D jagged tensors (used for KV cache construction).""" - - def keras_concat_2D_jagged_jagged(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): - max_seq_len = max_len_left + max_len_right - lengths_left = offsets_left[1:] - offsets_left[:-1]; lengths_right = offsets_right[1:] - offsets_right[:-1] - padded_left = keras_jagged_to_padded_dense(values=values_left, offsets=[offsets_left], max_lengths=[max_len_left], padding_value=0.0) - padded_right = keras_jagged_to_padded_dense(values=values_right, offsets=[offsets_right], max_lengths=[max_len_right], padding_value=0.0) - concatted_dense = ops.concatenate([padded_left, padded_right], axis=1) - - lengths_left_broadcast = ops.reshape(lengths_left, [-1, 1]); lengths_right_broadcast = ops.reshape(lengths_right, [-1, 1]) - mask = ops.reshape(ops.arange(max_seq_len, dtype='int32'), [1, -1]) - mask = ops.logical_or(mask < lengths_left_broadcast, ops.logical_and(mask >= max_len_left, mask < max_len_left + lengths_right_broadcast)) - return concatted_dense[ops.reshape(mask, [-1])] - - def keras_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): - L_total = ops.shape(values_left)[0] - offsets_left_non_optional = offsets_left - if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') - offsets_right_non_optional = offsets_right - if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') - - if max_len_left is None: max_len_left_final = ops.max(offsets_left_non_optional[1:] - offsets_left_non_optional[:-1]) - else: max_len_left_final = max_len_left - if max_len_right is None: max_len_right_final = ops.max(offsets_right_non_optional[1:] - offsets_right_non_optional[:-1]) - else: max_len_right_final = max_len_right - - return keras_concat_2D_jagged_jagged(values_left=values_left, values_right=values_right, max_len_left=max_len_left_final, max_len_right=max_len_right_final, offsets_left=offsets_left_non_optional, 
offsets_right=offsets_right_non_optional) - - return pytorch_concat_2D_jagged_resolver(values_left=values_left, values_right=values_right, max_len_left=max_len_left, max_len_right=max_len_right, offsets_left=offsets_left, offsets_right=offsets_right) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py deleted file mode 100644 index f395004c..00000000 --- a/keras_rs/src/layers/stu.py +++ /dev/null @@ -1,231 +0,0 @@ -import abc -from typing import List, Optional, Tuple -import keras -from keras import ops -from keras import layers - -from keras_rs.src.layers.common import fx_unwrap_optional_tensor -from keras_rs.src.layers.hstu_compute_output import hstu_compute_output -from keras_rs.src.layers.hstu_uqvk_output import hstu_compute_uqvk -from keras_rs.src.layers.hstu_preprocess_attention import keras_hstu_preprocess_and_attention -from keras_rs.src.layers.hstu_mha_attention import delta_hstu_mha -from keras_rs.src.layers.jagged_tensors import split_2D_jagged, concat_2D_jagged - - -class STULayerConfig: - def __init__(self, embedding_dim: int, num_heads: int, hidden_dim: int, attention_dim: int, - output_dropout_ratio: float = 0.3, causal: bool = True, target_aware: bool = True, - max_attn_len: Optional[int] = None, attn_alpha: Optional[float] = None, - use_group_norm: bool = False, recompute_normed_x: bool = True, - recompute_uvqk: bool = True, recompute_y: bool = True, - sort_by_length: bool = True, contextual_seq_len: int = 0): - self.embedding_dim = embedding_dim - self.num_heads = num_heads - self.hidden_dim = hidden_dim - self.attention_dim = attention_dim - self.output_dropout_ratio = output_dropout_ratio - self.causal = causal - self.target_aware = target_aware - self.max_attn_len = max_attn_len - self.attn_alpha = attn_alpha - self.use_group_norm = use_group_norm - self.recompute_normed_x = recompute_normed_x - self.recompute_uvqk = recompute_uvqk - self.recompute_y = recompute_y - self.sort_by_length = sort_by_length - self.contextual_seq_len = contextual_seq_len - - -def _update_kv_cache( - max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], orig_k_cache: Optional[keras.KerasTensor], orig_v_cache: Optional[keras.KerasTensor], orig_max_kv_caching_len: int, orig_kv_caching_offsets: Optional[keras.KerasTensor], -) -> Tuple[Optional[keras.KerasTensor], Optional[keras.KerasTensor], int, Optional[keras.KerasTensor]]: - - if kv_caching_lengths is not None: - # Keras equivalent of asynchronous_complete_cumsum - kv_caching_offsets = ops.cast(ops.cumsum(kv_caching_lengths, exclusive=True), dtype="int32") - delta_offsets = seq_offsets - kv_caching_offsets - - # NOTE: split_2D_jagged is available from jagged_tensors.py - if k is not None: - k_values = ops.reshape(k, [ops.shape(k)[0], -1]) - k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=k_values, max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) - else: - k_cache = fx_unwrap_optional_tensor(k) - if v is not None: - v_values = ops.reshape(v, [ops.shape(v)[0], -1]) - v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=v_values, max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) - else: - v_cache = fx_unwrap_optional_tensor(v) - - if max_kv_caching_len == 0: - max_kv_caching_len = ops.convert_to_numpy(ops.cast(ops.max(kv_caching_lengths), dtype="int32")).item() - return 
(k_cache, v_cache, max_kv_caching_len, kv_caching_offsets) - else: - return (orig_k_cache, orig_v_cache, orig_max_kv_caching_len, orig_kv_caching_offsets) - - -def _construct_full_kv( - delta_k: keras.KerasTensor, delta_v: keras.KerasTensor, k_cache: keras.KerasTensor, v_cache: keras.KerasTensor, max_kv_caching_len: int, kv_caching_offsets: keras.KerasTensor, -) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: - L = ops.shape(delta_k)[0] - B = ops.shape(kv_caching_offsets)[0] - 1 - delta_size = L // B - - # NOTE: concat_2D_jagged is available from jagged_tensors.py - full_k = concat_2D_jagged(max_seq_len=max_kv_caching_len + delta_size, values_left=k_cache, values_right=delta_k, max_len_left=max_kv_caching_len, max_len_right=delta_size, offsets_left=kv_caching_offsets, offsets_right=None) - full_v = concat_2D_jagged(max_seq_len=max_kv_caching_len + delta_size, values_left=v_cache, values_right=delta_v, max_len_left=max_kv_caching_len, max_len_right=delta_size, offsets_left=kv_caching_offsets, offsets_right=None) - - # Calculate new combined offsets - delta_size_broadcast = delta_size * ops.arange(B + 1, dtype=kv_caching_offsets.dtype) - full_kv_caching_offsets = kv_caching_offsets + delta_size_broadcast - - return (full_k, full_v, max_kv_caching_len + delta_size, full_kv_caching_offsets) - - -class STU(layers.Layer, abc.ABC): - """Abstract base class for STU layers.""" - @abc.abstractmethod - def cached_forward(self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None,) -> keras.KerasTensor: pass - @abc.abstractmethod - def call(self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None,) -> keras.KerasTensor: pass - - -class STULayer(layers.Layer): - # Initialize cache properties on the instance - max_kv_caching_len: int = 0 - k_cache: Optional[keras.KerasTensor] = None - v_cache: Optional[keras.KerasTensor] = None - kv_caching_offsets: Optional[keras.KerasTensor] = None - - def __init__(self, config: STULayerConfig, is_inference: bool = False, **kwargs): - super().__init__(**kwargs) - self._config = config - self._num_heads: int = config.num_heads - self._embedding_dim: int = config.embedding_dim - self._hidden_dim: int = config.hidden_dim - self._attention_dim: int = config.attention_dim - self._output_dropout_ratio: float = config.output_dropout_ratio - self._target_aware: bool = config.target_aware - self._causal: bool = config.causal - self._max_attn_len: int = config.max_attn_len or 0 - self._attn_alpha: float = config.attn_alpha or 1.0 / (self._attention_dim**0.5) - self._use_group_norm: bool = config.use_group_norm - self._recompute_normed_x: bool = config.recompute_normed_x - self._recompute_uvqk: bool = config.recompute_uvqk - self._recompute_y: bool = config.recompute_y - self._sort_by_length: bool = config.sort_by_length - self._contextual_seq_len: int = config.contextual_seq_len - self.reset_kv_cache() - - def build(self, input_shape): - D_in = input_shape[-1] - H = self._num_heads; A = self._attention_dim; V = self._hidden_dim - output_dim_total = (V * 2 + A * 2) * H - self._uvqk_weight = self.add_weight(shape=(D_in, output_dim_total), initializer='glorot_uniform', name='uvqk_weight') - self._uvqk_beta = 
self.add_weight(shape=(output_dim_total,), initializer='zeros', name='uvqk_beta') - self._input_norm_weight = self.add_weight(shape=(D_in,), initializer='ones', name='input_norm_weight') - self._input_norm_bias = self.add_weight(shape=(D_in,), initializer='zeros', name='input_norm_bias') - - self._output_weight = self.add_weight(shape=(V * H, self._embedding_dim), initializer='glorot_uniform', name='output_weight') - - output_norm_shape: int = (V * H if not self._use_group_norm else H) - self._output_norm_weight = self.add_weight(shape=(output_norm_shape,), initializer='ones', name='output_norm_weight') - self._output_norm_bias = self.add_weight(shape=(output_norm_shape,), initializer='zeros', name='output_norm_bias') - self.built = True - - def reset_kv_cache(self) -> None: - self.k_cache = None; self.v_cache = None - self.kv_caching_offsets = None; self.max_kv_caching_len = 0 - - def update_kv_cache( - self, max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], - ) -> None: - # NOTE: Assumes _update_kv_cache is available - self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = _update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets) - - def construct_full_kv(self, delta_k: keras.KerasTensor, delta_v: keras.KerasTensor,) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: - # NOTE: Assumes _construct_full_kv is available - return _construct_full_kv(delta_k=delta_k, delta_v=delta_v, k_cache=fx_unwrap_optional_tensor(self.k_cache), v_cache=fx_unwrap_optional_tensor(self.v_cache), max_kv_caching_len=self.max_kv_caching_len, kv_caching_offsets=fx_unwrap_optional_tensor(self.kv_caching_offsets),) - - def call( # Standard Keras forward method - self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, - ) -> keras.KerasTensor: - - u, attn_output, k, v = keras_hstu_preprocess_and_attention( - x=x, norm_weight=self._input_norm_weight, norm_bias=self._input_norm_bias, norm_eps=1e-6, - num_heads=self._num_heads, attn_dim=self._attention_dim, hidden_dim=self._hidden_dim, - uvqk_weight=self._uvqk_weight, uvqk_bias=self._uvqk_beta, - max_seq_len=max_seq_len, seq_offsets=x_offsets, attn_alpha=self._attn_alpha, - causal=self._causal, num_targets=num_targets if self._target_aware else None, - max_attn_len=self._max_attn_len, contextual_seq_len=self._contextual_seq_len, - recompute_uvqk_in_backward=self._recompute_uvqk, recompute_normed_x_in_backward=self._recompute_normed_x, - sort_by_length=self._sort_by_length, prefill=kv_caching_lengths is not None, - ) - - self.update_kv_cache(max_seq_len=max_seq_len, seq_offsets=x_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths) - - return hstu_compute_output( - attn=attn_output, u=u, x=x, norm_weight=self._output_norm_weight, norm_bias=self._output_norm_bias, - norm_eps=1e-6, dropout_ratio=self._output_dropout_ratio, output_weight=self._output_weight, - group_norm=self._use_group_norm, 
num_heads=self._num_heads, linear_dim=self._hidden_dim, - concat_ux=True, training=training, recompute_y_in_backward=self._recompute_y, - ) - - def cached_forward( # Called for token-by-token generation - self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, - ) -> keras.KerasTensor: - - delta_u, delta_q, delta_k, delta_v = hstu_compute_uqvk( - x=delta_x, norm_weight=self._input_norm_weight, norm_bias=self._input_norm_bias, norm_eps=1e-6, - num_heads=self._num_heads, attn_dim=self._attention_dim, hidden_dim=self._hidden_dim, - uvqk_weight=self._uvqk_weight, uvqk_bias=self._uvqk_beta, - ) - - A = self._attention_dim; V = self._hidden_dim; H = self._num_heads - k_flat = ops.reshape(delta_k, [-1, H * A]) - v_flat = ops.reshape(delta_v, [-1, H * V]) - - k_full, v_full, max_seq_len, seq_offsets = self.construct_full_kv(delta_k=k_flat, delta_v=v_flat) - - self.update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k_full, v=v_full, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths) - - # Reshape K and V back to [L_full, H, D] for attention calculation - k = ops.reshape(k_full, [-1, H, A]) - v = ops.reshape(v_full, [-1, H, V]) - - - delta_attn_output = delta_hstu_mha( - max_seq_len=max_seq_len, alpha=self._attn_alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets, - num_targets=num_targets if self._target_aware else None, max_attn_len=self._max_attn_len, - contextual_seq_len=self._contextual_seq_len, - ) - - delta_attn_output = ops.reshape(delta_attn_output, [-1, V * H]) - - - return hstu_compute_output( - attn=delta_attn_output, u=delta_u, x=delta_x, norm_weight=self._output_norm_weight, norm_bias=self._output_norm_bias, - norm_eps=1e-6, dropout_ratio=self._output_dropout_ratio, output_weight=self._output_weight, - group_norm=self._use_group_norm, num_heads=self._num_heads, linear_dim=self._hidden_dim, - concat_ux=True, training=training, recompute_y_in_backward=self._recompute_y, - ) - - -class STUStack(layers.Layer): - def __init__(self, stu_layers: List[STULayer], is_inference: bool = False, **kwargs): - super().__init__(**kwargs) - self._stu_layers = stu_layers - - def call( - self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, - ) -> keras.KerasTensor: - for layer in self._stu_layers: - x = layer(x=x, x_lengths=x_lengths, x_offsets=x_offsets, max_seq_len=max_seq_len, num_targets=num_targets, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, training=training) - return x - - def cached_forward( - self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, - ) -> keras.KerasTensor: - for layer in self._stu_layers: - delta_x = layer.cached_forward(delta_x=delta_x, num_targets=num_targets, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, training=training) - return delta_x diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py new file mode 100644 index 00000000..ed9db581 --- /dev/null +++ b/keras_rs/src/losses/list_mle_loss.py @@ -0,0 +1,217 @@ +from typing import Any + +import keras +from keras import ops + 
+from keras_rs.src import types +from keras_rs.src.api_export import keras_rs_export +from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores +from keras_rs.src.metrics.utils import standardize_call_inputs_ranks + + +@keras_rs_export("keras_rs.losses.ListMLELoss") +class ListMLELoss(keras.losses.Loss): + """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking. + + ListMLE loss is a listwise ranking loss that maximizes the likelihood of + the ground truth ranking. It works by: + 1. Sorting items by their relevance scores (labels) + 2. Computing the probability of observing this ranking given the + predicted scores + 3. Maximizing this likelihood (minimizing negative log-likelihood) + + The loss is computed as the negative log-likelihood of the ground truth + ranking given the predicted scores: + + ``` + loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i))) + ``` + + where s_i is the predicted score for item i in the sorted order. + + Args: + temperature: Temperature parameter for scaling logits. Higher values + make the probability distribution more uniform. Defaults to 1.0. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. Defaults to + `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`. + + Examples: + ```python + # Basic usage + loss_fn = ListMLELoss() + + # With temperature scaling + loss_fn = ListMLELoss(temperature=0.5) + + # Example with synthetic data + y_true = [[3, 2, 1, 0]] # Relevance scores + y_pred = [[0.8, 0.6, 0.4, 0.2]] # Predicted scores + loss = loss_fn(y_true, y_pred) + ``` + """ + + def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if temperature <= 0.0: + raise ValueError( + f"`temperature` should be a positive float. Received: " + f"`temperature` = {temperature}." + ) + + self.temperature = temperature + self._epsilon = 1e-10 + + def compute_unreduced_loss( + self, + labels: types.Tensor, + logits: types.Tensor, + mask: types.Tensor | None = None, + ) -> tuple[types.Tensor, types.Tensor]: + """Compute the unreduced ListMLE loss. + + Args: + labels: Ground truth relevance scores of + shape [batch_size,list_size]. + logits: Predicted scores of shape [batch_size, list_size]. + mask: Optional mask of shape [batch_size, list_size]. + + Returns: + Tuple of (losses, weights) where losses has shape [batch_size, 1] + and weights has the same shape. 
+ """ + + valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype)) + + if mask is not None: + valid_mask = ops.logical_and( + valid_mask, ops.cast(mask, dtype="bool") + ) + + num_valid_items = ops.sum( + ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True + ) + + batch_has_valid_items = ops.greater(num_valid_items, 0.0) + + labels_for_sorting = ops.where( + valid_mask, labels, ops.full_like(labels, -1e9) + ) + logits_masked = ops.where( + valid_mask, logits, ops.full_like(logits, -1e9) + ) + sorted_logits, sorted_valid_mask = sort_by_scores( + tensors_to_sort=[logits_masked, valid_mask], + scores=labels_for_sorting, + mask=None, + shuffle_ties=False, + seed=None, + ) + sorted_logits = ops.divide( + sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype) + ) + + valid_logits_for_max = ops.where( + sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9) + ) + raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True) + raw_max = ops.where( + batch_has_valid_items, raw_max, ops.zeros_like(raw_max) + ) + sorted_logits = ops.subtract(sorted_logits, raw_max) + + # Set invalid positions to very negative BEFORE exp + sorted_logits = ops.where( + sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9) + ) + exp_logits = ops.exp(sorted_logits) + + # reversed_exp = ops.flip(exp_logits, axis=1) + # reversed_cumsum = ops.cumsum(reversed_exp, axis=1) + # cumsum_from_right = ops.flip(reversed_cumsum, axis=1) + # cumsum_forward = ops.cumsum(exp_logits, axis=1) + # total_sum = ops.sum(exp_logits, axis=1, keepdims=True) + # cumsum_from_right = total_sum - cumsum_forward + exp_logits + reversed_exp = ops.flip(exp_logits, axis=1) + reversed_cumsum = ops.cumsum(reversed_exp, axis=1) + cumsum_from_right = ops.flip(reversed_cumsum, axis=1) + + log_normalizers = ops.log(cumsum_from_right + self._epsilon) + log_probs = ops.subtract(sorted_logits, log_normalizers) + + log_probs = ops.where( + sorted_valid_mask, log_probs, ops.zeros_like(log_probs) + ) + + negative_log_likelihood = ops.negative( + ops.sum(log_probs, axis=1, keepdims=True) + ) + + negative_log_likelihood = ops.where( + batch_has_valid_items, + negative_log_likelihood, + ops.zeros_like(negative_log_likelihood), + ) + + weights = ops.ones_like(negative_log_likelihood) + + return negative_log_likelihood, weights + + def call( + self, + y_true: types.Tensor, + y_pred: types.Tensor, + ) -> types.Tensor: + """Compute the ListMLE loss. + + Args: + y_true: tensor or dict. Ground truth values. If tensor, of shape + `(list_size)` for unbatched inputs or `(batch_size, list_size)` + for batched inputs. If an item has a label of -1, it is ignored + in loss computation. If it is a dictionary, it should have two + keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore + elements in loss computation. + y_pred: tensor. The predicted values, of shape `(list_size)` for + unbatched inputs or `(batch_size, list_size)` for batched + inputs. Should be of the same shape as `y_true`. + + Returns: + The loss tensor of shape [batch_size]. + """ + mask = None + if isinstance(y_true, dict): + if "labels" not in y_true: + raise ValueError( + '`"labels"` should be present in `y_true`. 
Received: ' + f"`y_true` = {y_true}" + ) + + mask = y_true.get("mask", None) + y_true = y_true["labels"] + + y_true = ops.convert_to_tensor(y_true) + y_pred = ops.convert_to_tensor(y_pred) + if mask is not None: + mask = ops.convert_to_tensor(mask) + + y_true, y_pred, mask, _ = standardize_call_inputs_ranks( + y_true, y_pred, mask + ) + + losses, weights = self.compute_unreduced_loss( + labels=y_true, logits=y_pred, mask=mask + ) + losses = ops.multiply(losses, weights) + losses = ops.squeeze(losses, axis=-1) + return losses + + # getting config + def get_config(self) -> dict[str, Any]: + config: dict[str, Any] = super().get_config() + config.update({"temperature": self.temperature}) + return config diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py new file mode 100644 index 00000000..3656354b --- /dev/null +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -0,0 +1,99 @@ +import keras +from absl.testing import parameterized +from keras import ops +from keras.losses import deserialize +from keras.losses import serialize + +from keras_rs.src import testing +from keras_rs.src.losses.list_mle_loss import ListMLELoss + + +class ListMLELossTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + self.unbatched_scores = ops.array( + [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" + ) + self.unbatched_labels = ops.array( + [1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32" + ) + self.batched_scores = ops.array( + [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]], + dtype="float32", + ) + self.batched_labels = ops.array( + [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]], + dtype="float32", + ) + self.expected_output = ops.array([6.865693, 3.088192], dtype="float32") + + def test_unbatched_input(self): + loss = ListMLELoss(reduction="none") + output = loss( + y_true=self.unbatched_labels, y_pred=self.unbatched_scores + ) + self.assertEqual(output.shape, (1,)) + self.assertTrue(ops.convert_to_numpy(output[0]) > 0) + self.assertAllClose(output, [self.expected_output[0]], atol=1e-5) + + def test_batched_input(self): + loss = ListMLELoss(reduction="none") + output = loss(y_true=self.batched_labels, y_pred=self.batched_scores) + self.assertEqual(output.shape, (2,)) + self.assertTrue(ops.convert_to_numpy(output[0]) > 0) + self.assertTrue(ops.convert_to_numpy(output[1]) > 0) + self.assertAllClose(output, self.expected_output, atol=1e-5) + + def test_temperature(self): + loss_temp = ListMLELoss(temperature=0.5, reduction="none") + output_temp = loss_temp( + y_true=self.batched_labels, y_pred=self.batched_scores + ) + self.assertAllClose( + output_temp, + [10.969891, 2.1283305], + atol=1e-5, + ) + + def test_invalid_input_rank(self): + rank_1_input = ops.ones((2, 3, 4)) + + loss = ListMLELoss() + with self.assertRaises(ValueError): + loss(y_true=rank_1_input, y_pred=rank_1_input) + + def test_loss_reduction(self): + loss = ListMLELoss(reduction="sum_over_batch_size") + output = loss(y_true=self.batched_labels, y_pred=self.batched_scores) + self.assertAlmostEqual( + ops.convert_to_numpy(output), 4.9769425, places=5 + ) + + def test_scalar_sample_weight(self): + sample_weight = ops.array(5.0) + loss = ListMLELoss(reduction="none") + + output = loss( + y_true=self.batched_labels, + y_pred=self.batched_scores, + sample_weight=sample_weight, + ) + + self.assertAllClose( + output, self.expected_output * sample_weight, atol=1e-5 + ) + + def test_model_fit(self): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) 
+ model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile(loss=ListMLELoss(), optimizer="adam") + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), + ) + + def test_serialization(self): + loss = ListMLELoss(temperature=0.8) + restored = deserialize(serialize(loss)) + self.assertDictEqual(loss.get_config(), restored.get_config()) From b64377808f90b1dfb828598a5fad2c65a5e433a6 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Mon, 27 Oct 2025 18:29:31 +0530 Subject: [PATCH 11/19] Debug statements added --- keras_rs/src/losses/list_mle_loss.py | 44 +++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py index ed9db581..914aba9e 100644 --- a/keras_rs/src/losses/list_mle_loss.py +++ b/keras_rs/src/losses/list_mle_loss.py @@ -55,7 +55,9 @@ class ListMLELoss(keras.losses.Loss): ``` """ - def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None: + def __init__( + self, temperature: float = 1.0, debug: bool = True, **kwargs: Any + ) -> None: super().__init__(**kwargs) if temperature <= 0.0: @@ -66,6 +68,7 @@ def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None: self.temperature = temperature self._epsilon = 1e-10 + self.debug = debug def compute_unreduced_loss( self, @@ -131,12 +134,6 @@ def compute_unreduced_loss( ) exp_logits = ops.exp(sorted_logits) - # reversed_exp = ops.flip(exp_logits, axis=1) - # reversed_cumsum = ops.cumsum(reversed_exp, axis=1) - # cumsum_from_right = ops.flip(reversed_cumsum, axis=1) - # cumsum_forward = ops.cumsum(exp_logits, axis=1) - # total_sum = ops.sum(exp_logits, axis=1, keepdims=True) - # cumsum_from_right = total_sum - cumsum_forward + exp_logits reversed_exp = ops.flip(exp_logits, axis=1) reversed_cumsum = ops.cumsum(reversed_exp, axis=1) cumsum_from_right = ops.flip(reversed_cumsum, axis=1) @@ -160,6 +157,39 @@ def compute_unreduced_loss( weights = ops.ones_like(negative_log_likelihood) + # Debug print statements for all intermediate values + if self.debug: + import sys + + def safe_print(label, value): + try: + # For TensorFlow, only print numpy if in eager mode + if hasattr(value, "numpy"): + print(label, value.numpy(), file=sys.stderr) + else: + print( + label, ops.convert_to_numpy(value), file=sys.stderr + ) + except Exception as e: + print(label, f"", file=sys.stderr) + + safe_print("valid_mask", valid_mask) + safe_print("num_valid_items", num_valid_items) + safe_print("batch_has_valid_items", batch_has_valid_items) + safe_print("labels_for_sorting", labels_for_sorting) + safe_print("logits_masked", logits_masked) + safe_print("sorted_logits", sorted_logits) + safe_print("sorted_valid_mask", sorted_valid_mask) + safe_print("raw_max", raw_max) + safe_print("exp_logits", exp_logits) + safe_print("reversed_exp", reversed_exp) + safe_print("reversed_cumsum", reversed_cumsum) + safe_print("cumsum_from_right", cumsum_from_right) + safe_print("log_normalizers", log_normalizers) + safe_print("log_probs", log_probs) + safe_print("negative_log_likelihood", negative_log_likelihood) + safe_print("weights", weights) + return negative_log_likelihood, weights def call( From 3f6418c0da6db2f1b05729c02ae7ccbe553efdf6 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Wed, 29 Oct 2025 18:13:55 +0530 Subject: [PATCH 12/19] pytest added --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml 
b/.github/workflows/actions.yml index 5088f571..43a8bce0 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -44,7 +44,7 @@ jobs: pip install --no-deps -e "." --progress-bar off - name: Test with pytest run: | - pytest keras_rs/ + pytest -s . run_tests_in_container: name: Test the code on TPU From 47f0251e3203dd837993df8ee19cbe09ebe97793 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Wed, 29 Oct 2025 20:11:02 +0530 Subject: [PATCH 13/19] Added_stable offset code to sorting_labels --- keras_rs/src/losses/list_mle_loss.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py index 914aba9e..62c6ce40 100644 --- a/keras_rs/src/losses/list_mle_loss.py +++ b/keras_rs/src/losses/list_mle_loss.py @@ -108,6 +108,19 @@ def compute_unreduced_loss( logits_masked = ops.where( valid_mask, logits, ops.full_like(logits, -1e9) ) + + # Here the issue is that pytorch is not stable in sorting, so + # added stable offset before calling sort_by_scores + list_size = ops.shape(labels_for_sorting)[1] + indices = ops.arange(list_size) + + indices = ops.expand_dims(indices, axis=0) + indices = ops.broadcast_to(indices, ops.shape(labels_for_sorting)) + + stable_offset = ops.cast(indices, labels_for_sorting.dtype) * 1e-6 + + labels_for_sorting = ops.subtract(labels_for_sorting, stable_offset) + sorted_logits, sorted_valid_mask = sort_by_scores( tensors_to_sort=[logits_masked, valid_mask], scores=labels_for_sorting, From a3679c619bada9d19d7efc60a235559763621b2b Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Wed, 29 Oct 2025 20:57:51 +0530 Subject: [PATCH 14/19] Stable_offset code to clear the ambiguity with similar labels --- .github/workflows/actions.yml | 2 +- keras_rs/src/losses/list_mle_loss.py | 38 ++-------------------------- 2 files changed, 3 insertions(+), 37 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 43a8bce0..1febc91e 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -44,7 +44,7 @@ jobs: pip install --no-deps -e "." --progress-bar off - name: Test with pytest run: | - pytest -s . 
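        # `-s` (i.e. `--capture=no`) disables pytest output capture, which is presumably
        # why it was used while the temporary debug prints from the earlier commit were
        # in place; the scoped `pytest keras_rs/` invocation is restored below.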
+ pytest keras_rs/ run_tests_in_container: name: Test the code on TPU diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py index 62c6ce40..68ee420a 100644 --- a/keras_rs/src/losses/list_mle_loss.py +++ b/keras_rs/src/losses/list_mle_loss.py @@ -56,7 +56,7 @@ class ListMLELoss(keras.losses.Loss): """ def __init__( - self, temperature: float = 1.0, debug: bool = True, **kwargs: Any + self, temperature: float = 1.0, **kwargs: Any ) -> None: super().__init__(**kwargs) @@ -68,7 +68,6 @@ def __init__( self.temperature = temperature self._epsilon = 1e-10 - self.debug = debug def compute_unreduced_loss( self, @@ -108,8 +107,6 @@ def compute_unreduced_loss( logits_masked = ops.where( valid_mask, logits, ops.full_like(logits, -1e9) ) - - # Here the issue is that pytorch is not stable in sorting, so # added stable offset before calling sort_by_scores list_size = ops.shape(labels_for_sorting)[1] indices = ops.arange(list_size) @@ -170,38 +167,7 @@ def compute_unreduced_loss( weights = ops.ones_like(negative_log_likelihood) - # Debug print statements for all intermediate values - if self.debug: - import sys - - def safe_print(label, value): - try: - # For TensorFlow, only print numpy if in eager mode - if hasattr(value, "numpy"): - print(label, value.numpy(), file=sys.stderr) - else: - print( - label, ops.convert_to_numpy(value), file=sys.stderr - ) - except Exception as e: - print(label, f"", file=sys.stderr) - - safe_print("valid_mask", valid_mask) - safe_print("num_valid_items", num_valid_items) - safe_print("batch_has_valid_items", batch_has_valid_items) - safe_print("labels_for_sorting", labels_for_sorting) - safe_print("logits_masked", logits_masked) - safe_print("sorted_logits", sorted_logits) - safe_print("sorted_valid_mask", sorted_valid_mask) - safe_print("raw_max", raw_max) - safe_print("exp_logits", exp_logits) - safe_print("reversed_exp", reversed_exp) - safe_print("reversed_cumsum", reversed_cumsum) - safe_print("cumsum_from_right", cumsum_from_right) - safe_print("log_normalizers", log_normalizers) - safe_print("log_probs", log_probs) - safe_print("negative_log_likelihood", negative_log_likelihood) - safe_print("weights", weights) + return negative_log_likelihood, weights From 22757fdd84c541a843b48d591446b8e8a08cced2 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Wed, 29 Oct 2025 21:57:18 +0530 Subject: [PATCH 15/19] Stable_offset code added to handle sorting_label --- keras_rs/src/losses/list_mle_loss.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py index 68ee420a..f9b147aa 100644 --- a/keras_rs/src/losses/list_mle_loss.py +++ b/keras_rs/src/losses/list_mle_loss.py @@ -55,9 +55,7 @@ class ListMLELoss(keras.losses.Loss): ``` """ - def __init__( - self, temperature: float = 1.0, **kwargs: Any - ) -> None: + def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None: super().__init__(**kwargs) if temperature <= 0.0: @@ -167,8 +165,6 @@ def compute_unreduced_loss( weights = ops.ones_like(negative_log_likelihood) - - return negative_log_likelihood, weights def call( From 35bfa0df4d2d56ea0b03e26dd9357b5d8ae26153 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Thu, 30 Oct 2025 12:55:23 +0530 Subject: [PATCH 16/19] Save local changes --- .github/workflows/actions.yml | 123 ------------------ keras_rs/src/losses/list_mle_loss.py | 10 -- keras_rs/src/metrics/ranking_metrics_utils.py | 20 ++- 3 files changed, 19 insertions(+), 134 
deletions(-) delete mode 100644 .github/workflows/actions.yml diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml deleted file mode 100644 index 1febc91e..00000000 --- a/.github/workflows/actions.yml +++ /dev/null @@ -1,123 +0,0 @@ -name: Tests - -on: - push: - pull_request: - workflow_call: - release: - types: [created] - -permissions: - contents: read - -jobs: - run_tests: - name: Test the code - strategy: - fail-fast: false - matrix: - backend: [tensorflow, jax, torch] - runs-on: ubuntu-latest - env: - KERAS_BACKEND: ${{ matrix.backend }} - steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.11 - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip setuptools - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@v4 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - name: Install dependencies - run: | - pip install -r requirements.txt --progress-bar off - pip install --no-deps -e "." --progress-bar off - - name: Test with pytest - run: | - pytest keras_rs/ - - run_tests_in_container: - name: Test the code on TPU - runs-on: linux-x86-ct6e-44-1tpu - - strategy: - fail-fast: false - matrix: - backend: [tensorflow, jax] - - container: - image: python:3.11-slim - options: --privileged --network host - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install Dependencies - run: | - pip install --no-cache-dir -U pip && \ - pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt - - - name: Set Keras Backend - run: | - echo "KERAS_BACKEND=${{ matrix.backend }}" >> $GITHUB_ENV - echo "TPU_NAME=local" >> $GITHUB_ENV - - - name: Set TF Specific Environment Variables - if: ${{ matrix.backend == 'tensorflow'}} - run: | - echo "PJRT_DEVICE=TPU" >> $GITHUB_ENV - echo "NEXT_PLUGGABLE_DEVICE_USE_C_API=true" >> $GITHUB_ENV - echo "TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true" >> $GITHUB_ENV - pip show libtpu | grep "^Location: " | sed "s/^Location: \(.*\)$/TF_PLUGGABLE_DEVICE_LIBRARY_PATH=\1\/libtpu\/libtpu.so/1" >> $GITHUB_ENV - - - name: Verify TF Installation - if: ${{ matrix.backend == 'tensorflow'}} - run: python3 -c "import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices())" - - - name: Verify JAX Installation - if: ${{ matrix.backend == 'jax'}} - run: python3 -c "import jax; print('JAX devices:', jax.devices())" - - - name: Test with pytest - run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py - - check_format: - name: Check the code format - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Set up Python 3.11 - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip setuptools - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - name: pip cache - uses: actions/cache@v4 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} - restore-keys: | - ${{ runner.os }}-pip- - - name: Install dependencies - run: | - pip install -r requirements.txt --progress-bar off - pip install --no-deps -e "." 
--progress-bar off - - name: Install pre-commit - run: pip install pre-commit && pre-commit install - - name: Run pre-commit - run: pre-commit run --all-files --hook-stage manual diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py index f9b147aa..5d348ea8 100644 --- a/keras_rs/src/losses/list_mle_loss.py +++ b/keras_rs/src/losses/list_mle_loss.py @@ -105,16 +105,6 @@ def compute_unreduced_loss( logits_masked = ops.where( valid_mask, logits, ops.full_like(logits, -1e9) ) - # added stable offset before calling sort_by_scores - list_size = ops.shape(labels_for_sorting)[1] - indices = ops.arange(list_size) - - indices = ops.expand_dims(indices, axis=0) - indices = ops.broadcast_to(indices, ops.shape(labels_for_sorting)) - - stable_offset = ops.cast(indices, labels_for_sorting.dtype) * 1e-6 - - labels_for_sorting = ops.subtract(labels_for_sorting, stable_offset) sorted_logits, sorted_valid_mask = sort_by_scores( tensors_to_sort=[logits_masked, valid_mask], diff --git a/keras_rs/src/metrics/ranking_metrics_utils.py b/keras_rs/src/metrics/ranking_metrics_utils.py index f969489a..dc519f75 100644 --- a/keras_rs/src/metrics/ranking_metrics_utils.py +++ b/keras_rs/src/metrics/ranking_metrics_utils.py @@ -84,7 +84,25 @@ def sort_by_scores( k = min(k, max_possible_k) else: k = ops.minimum(k, max_possible_k) - + + # --- Work around for PyTorch instability --- + # Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF. + # See: + # - https://github.com/pytorch/pytorch/issues/27542 + # - https://github.com/pytorch/pytorch/issues/88227 + # + # This small "stable offset" ensures deterministic tie-breaking for equal scores. + # We can remove this workaround once PyTorch adds a `stable=True` flag for topk. + + if backend.backend() == "torch" and not shuffle_ties: + list_size = ops.shape(scores)[1] + indices = ops.arange(list_size) + indices = ops.expand_dims(indices, axis=0) + indices = ops.broadcast_to(indices, ops.shape(scores)) + stable_offset = ops.cast(indices, scores.dtype) * 1e-6 + scores = ops.subtract(scores, stable_offset) + # --- End FIX --- + # Shuffle ties randomly, and push masked values to the beginning. 
shuffled_indices = None if shuffle_ties or mask is not None: From 84244162f55836b243131f40303f4ea0f44af680 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Thu, 30 Oct 2025 13:12:35 +0530 Subject: [PATCH 17/19] Updated few code changes --- .github/workflows/actions.yml | 123 ++++++++++++++++++ keras_rs/src/metrics/ranking_metrics_utils.py | 9 +- 2 files changed, 128 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/actions.yml diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml new file mode 100644 index 00000000..5088f571 --- /dev/null +++ b/.github/workflows/actions.yml @@ -0,0 +1,123 @@ +name: Tests + +on: + push: + pull_request: + workflow_call: + release: + types: [created] + +permissions: + contents: read + +jobs: + run_tests: + name: Test the code + strategy: + fail-fast: false + matrix: + backend: [tensorflow, jax, torch] + runs-on: ubuntu-latest + env: + KERAS_BACKEND: ${{ matrix.backend }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off + pip install --no-deps -e "." --progress-bar off + - name: Test with pytest + run: | + pytest keras_rs/ + + run_tests_in_container: + name: Test the code on TPU + runs-on: linux-x86-ct6e-44-1tpu + + strategy: + fail-fast: false + matrix: + backend: [tensorflow, jax] + + container: + image: python:3.11-slim + options: --privileged --network host + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Dependencies + run: | + pip install --no-cache-dir -U pip && \ + pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt + + - name: Set Keras Backend + run: | + echo "KERAS_BACKEND=${{ matrix.backend }}" >> $GITHUB_ENV + echo "TPU_NAME=local" >> $GITHUB_ENV + + - name: Set TF Specific Environment Variables + if: ${{ matrix.backend == 'tensorflow'}} + run: | + echo "PJRT_DEVICE=TPU" >> $GITHUB_ENV + echo "NEXT_PLUGGABLE_DEVICE_USE_C_API=true" >> $GITHUB_ENV + echo "TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true" >> $GITHUB_ENV + pip show libtpu | grep "^Location: " | sed "s/^Location: \(.*\)$/TF_PLUGGABLE_DEVICE_LIBRARY_PATH=\1\/libtpu\/libtpu.so/1" >> $GITHUB_ENV + + - name: Verify TF Installation + if: ${{ matrix.backend == 'tensorflow'}} + run: python3 -c "import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices())" + + - name: Verify JAX Installation + if: ${{ matrix.backend == 'jax'}} + run: python3 -c "import jax; print('JAX devices:', jax.devices())" + + - name: Test with pytest + run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py + + check_format: + name: Check the code format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: 
${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off + pip install --no-deps -e "." --progress-bar off + - name: Install pre-commit + run: pip install pre-commit && pre-commit install + - name: Run pre-commit + run: pre-commit run --all-files --hook-stage manual diff --git a/keras_rs/src/metrics/ranking_metrics_utils.py b/keras_rs/src/metrics/ranking_metrics_utils.py index dc519f75..200f8966 100644 --- a/keras_rs/src/metrics/ranking_metrics_utils.py +++ b/keras_rs/src/metrics/ranking_metrics_utils.py @@ -1,6 +1,7 @@ from typing import Callable import keras +from keras import backend as K from keras import ops from keras_rs.src import types @@ -84,7 +85,7 @@ def sort_by_scores( k = min(k, max_possible_k) else: k = ops.minimum(k, max_possible_k) - + # --- Work around for PyTorch instability --- # Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF. # See: @@ -93,8 +94,8 @@ def sort_by_scores( # # This small "stable offset" ensures deterministic tie-breaking for equal scores. # We can remove this workaround once PyTorch adds a `stable=True` flag for topk. - - if backend.backend() == "torch" and not shuffle_ties: + + if K.backend() == "torch" and not shuffle_ties: list_size = ops.shape(scores)[1] indices = ops.arange(list_size) indices = ops.expand_dims(indices, axis=0) @@ -102,7 +103,7 @@ def sort_by_scores( stable_offset = ops.cast(indices, scores.dtype) * 1e-6 scores = ops.subtract(scores, stable_offset) # --- End FIX --- - + # Shuffle ties randomly, and push masked values to the beginning. shuffled_indices = None if shuffle_ties or mask is not None: From 7017f352079b34d41730732b35ded55f777003e6 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Thu, 30 Oct 2025 13:43:08 +0530 Subject: [PATCH 18/19] lint errors corrected --- keras_rs/src/metrics/ranking_metrics_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_rs/src/metrics/ranking_metrics_utils.py b/keras_rs/src/metrics/ranking_metrics_utils.py index 200f8966..90d1c78a 100644 --- a/keras_rs/src/metrics/ranking_metrics_utils.py +++ b/keras_rs/src/metrics/ranking_metrics_utils.py @@ -92,8 +92,9 @@ def sort_by_scores( # - https://github.com/pytorch/pytorch/issues/27542 # - https://github.com/pytorch/pytorch/issues/88227 # - # This small "stable offset" ensures deterministic tie-breaking for equal scores. - # We can remove this workaround once PyTorch adds a `stable=True` flag for topk. + # This small "stable offset" ensures deterministic tie-breaking for + # equal scores. We can remove this workaround once PyTorch adds a + # `stable=True` flag for topk. 
if K.backend() == "torch" and not shuffle_ties: list_size = ops.shape(scores)[1] From 7044610ddf9cb6210483f20ef8a7f59e0abff523 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Thu, 30 Oct 2025 21:40:43 +0530 Subject: [PATCH 19/19] deleted keras backend import statement --- keras_rs/src/metrics/ranking_metrics_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras_rs/src/metrics/ranking_metrics_utils.py b/keras_rs/src/metrics/ranking_metrics_utils.py index 90d1c78a..ea3ddaaf 100644 --- a/keras_rs/src/metrics/ranking_metrics_utils.py +++ b/keras_rs/src/metrics/ranking_metrics_utils.py @@ -1,7 +1,6 @@ from typing import Callable import keras -from keras import backend as K from keras import ops from keras_rs.src import types @@ -96,7 +95,7 @@ def sort_by_scores( # equal scores. We can remove this workaround once PyTorch adds a # `stable=True` flag for topk. - if K.backend() == "torch" and not shuffle_ties: + if keras.backend.backend() == "torch" and not shuffle_ties: list_size = ops.shape(scores)[1] indices = ops.arange(list_size) indices = ops.expand_dims(indices, axis=0)
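For reference, a minimal standalone sketch of the tie-breaking trick the patches above converge on. Note this is a hypothetical helper written against the public Keras 3 `ops` API, not a function that exists in keras_rs, and it omits the masking and tie-shuffling that `sort_by_scores` also handles:

import keras
from keras import ops

def stable_top_k(scores, k):
    """Deterministic top-k over a (batch, list_size) score tensor.

    Torch's top_k does not guarantee a stable order for equal scores, so on
    that backend we subtract a tiny offset that grows with the list position:
    among ties, earlier positions keep a slightly higher effective score and
    are selected first, matching TF/JAX behaviour.
    """
    if keras.backend.backend() == "torch":
        list_size = ops.shape(scores)[1]
        indices = ops.arange(list_size)                       # shape (list_size,)
        indices = ops.expand_dims(indices, axis=0)            # shape (1, list_size)
        indices = ops.broadcast_to(indices, ops.shape(scores))
        scores = ops.subtract(scores, ops.cast(indices, scores.dtype) * 1e-6)
    return ops.top_k(scores, k=k, sorted=True)

With two equal scores at positions 1 and 3, the perturbed values differ by 2e-6, so top_k returns position 1 first on every run. The offset only affects exact (or near-exact) ties: scores that differ by more than roughly list_size * 1e-6 keep their original ordering.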
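The backend guard pairs with the workflow added in `.github/workflows/actions.yml`, which selects the backend solely through the `KERAS_BACKEND` environment variable. A small sketch of how a test can confirm which matrix entry is active, assuming standard Keras 3 behaviour where the variable must be set before `keras` is imported:

import os
os.environ.setdefault("KERAS_BACKEND", "torch")  # no effect if the CI job already exported it

import keras

# Reports "tensorflow", "jax", or "torch", matching the workflow's matrix.backend value.
print("Active backend:", keras.backend.backend())

Because the workaround is gated on `keras.backend.backend() == "torch"`, the TensorFlow and JAX jobs in the matrix exercise the unmodified top_k path while the torch job exercises the offset path.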