
feat(pymllm, gemma3n): Add Gemma3n text-only native server path#672

Open
Grape203 wants to merge 12 commits into UbiquitousLearning:main from Grape203:pr/gemma3n-text-only-native-20260501_135528

Conversation


@Grape203 Grape203 commented May 1, 2026

Summary

This PR adds an initial Gemma3n text-only native path to pymllm.

The implementation focuses on the simplest text-only LLM path first, following the staged direction discussed with the maintainer. It supports loading Gemma3n E2B text weights, native text forward/generation, and basic pymllm-server decode with RadixCache disabled.

Main changes

  • Add pymllm.models.gemma3n with a text-only native Gemma3n implementation.
  • Register Gemma3n model classes in pymllm.models.
  • Add model-specific CPU-first weight loading support for Gemma3n.
  • Add Gemma3n text weight streaming loader.
  • Add text-only support for per-layer embedding, AltUp flow, sliding/full attention layer types, and KV sharing.
  • Fix the disable_radix_cache path so ChunkCache does not enter RadixCache-specific insertion logic.
  • Add a minimal batch=1 full-context recompute path for native server decode correctness.
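One of the changes above is a text weight streaming loader. As a rough illustration of the streaming idea (a sketch with hypothetical names, not the PR's actual code; the shard reader is injected so the example stays library-agnostic), weights can be yielded shard by shard so only one shard is resident at a time:

```python
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Tuple

def iter_safetensors_weights(
    model_path: Path,
    read_shard: Callable[[Path], Iterable[Tuple[str, Any]]],
) -> Iterator[Tuple[str, Any]]:
    """Yield (name, tensor) pairs one shard at a time, so peak memory
    stays near one shard rather than the whole checkpoint."""
    shard_files = sorted(model_path.glob("*.safetensors"))
    if not shard_files:
        # Fail fast instead of silently loading nothing.
        raise FileNotFoundError(
            f"No safetensors checkpoint shards found under {model_path}"
        )
    for shard in shard_files:
        yield from read_shard(shard)
```

In a real loader, `read_shard` would wrap the safetensors API; here it is a parameter so the control flow (sorted shards, fail-fast on an empty directory, lazy yielding) is the only claim being made.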

Verification

Tested locally with Gemma3n E2B text weights.

  • Native text weight loading succeeds:
    • loaded=732
    • skipped=825
    • missing_in_ckpt=0
  • Native direct generation works.
  • pymllm-server works with RadixCache disabled.
  • /v1/completions multi-token decode returns:
The capital of France is **Paris**.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Added Gemma 3n model support (causal and conditional runtime paths).

* **Improvements**
  * Optional CPU-first model instantiation to improve loading stability.
  * Model-aware streaming weight loading with safe fallback and post-load placement on CUDA.
  * Attention backend now honors sliding-window size.
  * Model registry updated to recognize Gemma3n architectures.
  * Radix-cache logic refined to skip radix-specific ops for non-radix caches.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->


coderabbitai Bot commented May 1, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a full Gemma 3n text-model implementation and registry entries, updates layer exports, changes model instantiation to optionally perform CPU-first instantiation and prefer model-provided weight loaders, passes sliding-window size into the FlashInfer attention backend, and restricts radix-cache-specific logic to caches that expose radix semantics.

Changes

Gemma 3n model & registry

  • Data / config helpers — pymllm/models/gemma3n.py: adds _get_text_config, _get_layer_types, _get_hidden_act_fn, _get_intermediate_size.
  • Primitives & RoPE — pymllm/models/gemma3n.py, pymllm/layers/gemma3n.py: adds RoPE helpers (_rotate_half, _build_rope_cos_sin, _apply_rope) and Gemma3n RMSNorm primitives (Gemma3nRMSNorm, Gemma3nRMSNormNoWeight).
  • Core modules — pymllm/models/gemma3n.py: adds Gemma3nScaledWordEmbedding, Gemma3nLaurelBlock, Gemma3nAltUp, and Gemma3nMLP wiring.
  • Attention & decoder — pymllm/models/gemma3n.py: adds Gemma3nAttention (layer RoPE, causal/sliding masks, optional RadixAttention routing, KV sharing/GQA handling) and Gemma3nDecoderLayer.
  • Model / runner wrappers — pymllm/models/gemma3n.py: adds Gemma3nModel (shared KV cache, move_compute_modules_to_device) and Gemma3nForCausalLM (requires_cpu_first_weight_loading flag, recompute decode path, load_weights_from_model_path, load_weights, lm_head tying), plus Gemma3nForConditionalGeneration.
  • Registry — pymllm/models/__init__.py: registers HF architectures values to resolve lazily to Gemma3nForCausalLM / Gemma3nForConditionalGeneration.
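As a hedged sketch of what helpers like _rotate_half, _build_rope_cos_sin, and _apply_rope conventionally compute, here is the rotary-embedding math in plain Python (an illustration of the standard half-rotation convention, not the PR's tensor code):

```python
import math
from typing import List, Tuple

def rotate_half(x: List[float]) -> List[float]:
    # Swap halves with a sign flip: [x1, x2] -> [-x2, x1].
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def build_rope_cos_sin(pos: int, dim: int, base: float = 10000.0) -> Tuple[List[float], List[float]]:
    # One frequency per dimension pair, duplicated across both halves.
    inv_freq = [base ** (-2.0 * i / dim) for i in range(dim // 2)]
    angles = [pos * f for f in inv_freq] * 2
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]

def apply_rope(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    # Elementwise x * cos + rotate_half(x) * sin, i.e. a 2D rotation
    # of each (x[i], x[i + dim/2]) pair by its position-dependent angle.
    rot = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]
```

Because each pair is rotated, position 0 is the identity and the vector norm is preserved at every position.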

Model instantiation, weight loading & attention wiring

  • Instantiation decision — pymllm/executor/model_runner.py: load_model inspects model_cls.requires_cpu_first_weight_loading and, when true, instantiates the model on CPU first; it temporarily sets torch.set_default_dtype(self.dtype) during instantiation and restores it afterward.
  • Weight loading strategy — pymllm/executor/model_runner.py: after instantiation, prefers model.load_weights_from_model_path(model_path) when present and use_model_path_weight_loader (bool or callable) resolves truthy; otherwise falls back to model.load_weights(self._iter_weights(model_path)).
  • Post-load device move — pymllm/executor/model_runner.py: after weights load, if the model defines a callable move_compute_modules_to_device and the runtime device is CUDA, calls it to relocate compute modules to the runtime device.
  • Attention backend wiring — pymllm/executor/model_runner.py: init_attention_backend now passes sliding_window_size=self.sliding_window_size into FlashInferAttnBackend.
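The loader-selection rule described above (the flag may be a bool or a callable, and the model-path loader is preferred only when it actually exists) can be sketched like this; the function name and stub shapes are assumptions, not the repository's code:

```python
from typing import Any, Callable, Iterable, Tuple, Union

def select_and_run_weight_loader(
    model: Any,
    model_path: str,
    use_model_path_weight_loader: Union[bool, Callable[[Any], bool]],
    iter_weights: Callable[[str], Iterable[Tuple[str, Any]]],
) -> str:
    """Prefer model.load_weights_from_model_path when the flag resolves
    truthy and the method exists; otherwise stream weights to load_weights."""
    flag = use_model_path_weight_loader
    if callable(flag):
        flag = flag(model)  # the flag may itself be a predicate on the model
    path_loader = getattr(model, "load_weights_from_model_path", None)
    if flag and callable(path_loader):
        path_loader(model_path)
        return "model_path_loader"
    # Fallback: pass a lazy iterator, never a materialized list, so large
    # checkpoints are not held in memory all at once.
    model.load_weights(iter_weights(model_path))
    return "iterator_loader"
```

Returning which branch ran is only for illustration; the point is that both the truthiness resolution and the hasattr-style check must pass before the path-based loader is used.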

Cache handling guards

  • Radix guard checks — pymllm/orchestrator/model_runner_process.py: radix-cache-specific insert/cleanup code paths now run only when the cache implementation exposes radix semantics (hasattr(cache, "page_size")); _insert_into_radix_cache is a no-op for non-radix caches.
  • Free / unlock logic — pymllm/orchestrator/model_runner_process.py: _free_rid_resources updates the cache_enabled condition and avoids radix-only unlock/rematch branches for non-radix caches while preserving shared _radix_cache handling.
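The guard reduces to a duck-typing predicate. A minimal sketch, assuming page_size is the distinguishing attribute as described above (class and function names here are illustrative):

```python
from typing import Any, Optional, Sequence

def is_radix_cache(cache: Optional[Any]) -> bool:
    # ChunkCache-style caches lack page_size, so they fail this check
    # and never enter radix-specific insert/cleanup paths.
    return cache is not None and hasattr(cache, "page_size")

def insert_into_radix_cache(cache: Optional[Any], rid: str, token_ids: Sequence[int]) -> bool:
    """Return True only when a real radix insertion happened, so callers
    can reuse the same value for did_insert and later KV-free decisions."""
    if not is_radix_cache(cache):
        return False  # no-op for non-radix caches
    cache.insert(rid, list(token_ids))
    return True
```

Using one predicate for both insertion and cleanup keeps the two paths consistent, which is exactly the mismatch the review comments flagged.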

Sequence Diagram(s)

sequenceDiagram
    participant Runner as ModelRunner
    participant Model as ModelClass
    participant Storage as CheckpointStorage
    participant Device as DeviceMgr

    Runner->>Model: inspect requires_cpu_first_weight_loading
    alt CPU-first required
        Runner->>Device: set instantiate_device="cpu"
        Runner-->>Runner: temporarily set default dtype
    else
        Runner->>Device: set instantiate_device=runtime_device
    end

    Runner->>Model: instantiate model on instantiate_device
    Runner->>Model: evaluate use_model_path_weight_loader (bool/callable)
    alt model provides load_weights_from_model_path and resolver true
        Runner->>Storage: request model_path streaming
        Storage-->>Model: stream checkpoint path/data
        Model->>Model: load_weights_from_model_path(model_path)
    else
        Runner->>Storage: iterate weight tensors
        Storage-->>Runner: weight iterator
        Runner->>Model: load_weights(weight_iterator)
    end

    alt runtime device is CUDA and model has move_compute_modules_to_device
        Runner->>Model: move_compute_modules_to_device(runtime_cuda_device)
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • UbiquitousLearning/mllm#640: Modifies similar internal package files including pymllm/layers/__init__.py and orchestrator/model_runner_process.py, indicating overlapping code-level concerns.

Suggested reviewers

  • oreomaker
  • yirongjie
  • chenghuaWang
  • xumengwei

Poem

🐰 I hopped through RoPE and ALTUP vines,
CPU-first seeds and streaming lines.
Weights find burrows, caches mind the trails,
Gemma grows green in tokened vales.
🥕 Hop, compile, and shipping scales.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 21.28%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check ✅ — The title clearly and concisely summarizes the main objective: adding a Gemma3n text-only native server path to pymllm.
  • Description check ✅ — The description provides a comprehensive summary, main changes, and verification results aligned with the PR objectives, though it deviates from the minimal template provided.
  • Linked Issues check ✅ — Skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ — Skipped because no linked issues were found for this pull request.



Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-727: The code infers prefill vs decode from sequence length
and cache emptiness which lets a fresh 1-token request reuse stale context;
update the logic in the native path (around input_ids_hf/position_ids_hf
handling and the is_prefill decision) to detect request boundaries and reset
cached state when a new decode begins: if forward_batch.forward_mode indicates
"decode" (or if position_ids_hf[...,0] == 0) then clear
self._native_cached_input_ids, self._native_cached_positions and related cached
tensors such as self._hf_past_key_values before computing is_prefill so a
single-token new request does not append to prior context (apply same change to
the other similar branch handling caches).
- Around line 1041-1046: The weights handling in ModelRunner.load_model()
materializes iterables with list(weights), which will drain generator-like
inputs such as self._iter_weights and cause OOM for large Gemma3n checkpoints;
change the logic to treat non-dict weights as an iterable and iterate over it
directly (e.g., use a for-loop over weights or assign weight_items = weights
when it's already an iterable) rather than calling list(weights); if you keep
the TypeError branch for non-iterables, re-raise or raise a new error using
"from err" to preserve exception chaining.
- Around line 939-946: The current Gemma3n streaming loader returns
self.load_weights([]) when no .safetensors are found, which silently yields an
uninitialized model; instead, change the behavior in the st_files check inside
the Gemma3n loader: either raise a clear exception (e.g., FileNotFoundError or
ValueError) with a message that no .safetensors were found for the given
model_path so callers (including ModelRunner.load_model when
use_model_path_weight_loader is enabled) fail fast, or implement a proper .bin
fallback by detecting "*.bin" files and delegating to the existing .bin loading
code (call the appropriate bin-weight loader method) rather than returning an
empty load_weights list; update references to load_weights and the loader
selection logic accordingly.

In `@pymllm/orchestrator/model_runner_process.py`:
- Around line 501-504: The cleanup path incorrectly uses a simple "cache is not
None" check so ChunkCache (which lacks page_size) can still run radix cleanup
and set did_insert=True, preventing KV frees; update the cleanup logic to mirror
the insertion guard by checking for radix-capable cache (e.g., "if cache is not
None and hasattr(cache, 'page_size')" or equivalent) before calling
_free_rid_resources or setting did_insert, and ensure _free_rid_resources only
treats did_insert as true when a real radix insertion occurred (adjust
_free_rid_resources and caller logic in model_runner_process.py to use the same
hasattr(cache, 'page_size') predicate to decide insertion/cleanup and never mark
did_insert for ChunkCache).
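The prefill/decode boundary reset suggested in the first inline comment above could look roughly like the following state machine; attribute and method names are illustrative, and the real code operates on tensors and a KV cache rather than ints and a placeholder object:

```python
class NativeDecodeState:
    """Track per-request cached context for a native recompute decode path."""

    def __init__(self) -> None:
        self.cached_input_ids = None
        self.past_key_values = None

    def classify(self, first_position_id: int, seq_len: int) -> str:
        # A first position id of 0 marks a new request boundary: clear
        # stale cached context so a fresh 1-token prompt is not treated
        # as a decode step appended to the previous request.
        if first_position_id == 0:
            self.cached_input_ids = None
            self.past_key_values = None
        if seq_len > 1 or self.past_key_values is None:
            self.past_key_values = object()  # placeholder for a real KV cache
            return "prefill"
        return "decode"
```

The third assertion in the usage below is the failure mode the comment describes: without the boundary reset, a new 1-token request would be misclassified as decode against stale context.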

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 63329a33-6882-4fdc-bc8e-b9cf3b71fcd7

📥 Commits

Reviewing files that changed from the base of the PR and between 729ca4c and 2126cae.

📒 Files selected for processing (4)
  • pymllm/executor/model_runner.py
  • pymllm/models/__init__.py
  • pymllm/models/gemma3n.py
  • pymllm/orchestrator/model_runner_process.py

@Grape203 Grape203 force-pushed the pr/gemma3n-text-only-native-20260501_135528 branch from 2126cae to b400087 on May 1, 2026 at 09:03
@Grape203
Contributor Author

Grape203 commented May 1, 2026

Updated the PR to address the CodeRabbit comments:

  • reset native decode cache based on extend/prefill boundary
  • raise FileNotFoundError when safetensors weights are missing
  • avoid materializing streaming weight iterators in load_weights
  • avoid treating ChunkCache as RadixCache during cleanup

Also re-ran the local pymllm-server verification with Gemma3n E2B text weights and /v1/completions still returns:
"The capital of France is Paris."

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
pymllm/models/gemma3n.py (2)

1062-1068: ⚡ Quick win

Chain the exception for better debugging context.

The TypeError raised when weights is not iterable should be chained with the original exception using from err to preserve the exception context and aid debugging.

♻️ Proposed fix
         else:
             try:
                 weight_items = iter(weights)
-            except TypeError:
+            except TypeError as err:
                 raise TypeError(
                     f"weights must be a dict-like state_dict, a module with state_dict(), "
                     f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+                ) from err
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 1062 - 1068, The TypeError raised when
attempting to call iter(weights) loses the original exception context; update
the try/except around iter(weights) to capture the original exception (e.g.,
except TypeError as err) and re-raise the new TypeError using "from err" so the
stack trace is chained and debugging shows the original error; locate the
try/except that wraps iter(weights) in the weights handling logic and modify
that raise to include the exception chaining.

356-359: 💤 Low value

Minor style: getattr with constant attribute after hasattr check.

Line 359 uses getattr(forward_batch, "kv_shared_cache") immediately after checking hasattr(forward_batch, "kv_shared_cache") on line 358. This can be simplified to direct attribute access since the attribute's existence is already verified.

♻️ Suggested simplification
-        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-            shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
+            shared_kv_cache = forward_batch.kv_shared_cache
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 356 - 359, The code checks
hasattr(forward_batch, "kv_shared_cache") then uses getattr(forward_batch,
"kv_shared_cache"); replace the getattr call with direct attribute access
(forward_batch.kv_shared_cache) to simplify the style while preserving
behavior—update the block that assigns shared_kv_cache from forward_batch in the
method where forward_batch is used (the snippet using isinstance(forward_batch,
dict) / hasattr(forward_batch, "kv_shared_cache")).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-742: The delete of forward_batch at the top removes the
variable so locals().get("forward_batch", None) can never find it; restore
extend-mode detection by moving or removing the del forward_batch so that
forward_batch is still present when creating forward_batch_obj, or alternatively
capture forward_batch into a local variable (e.g., forward_batch_obj =
forward_batch or via locals().get before the del) before deleting; update the
logic around forward_batch_obj / is_extend_mode (used in the is_prefill
computation) to use that captured value so extend-mode detection
(is_extend_mode) works for 1-token prompts.

---

Nitpick comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1062-1068: The TypeError raised when attempting to call
iter(weights) loses the original exception context; update the try/except around
iter(weights) to capture the original exception (e.g., except TypeError as err)
and re-raise the new TypeError using "from err" so the stack trace is chained
and debugging shows the original error; locate the try/except that wraps
iter(weights) in the weights handling logic and modify that raise to include the
exception chaining.
- Around line 356-359: The code checks hasattr(forward_batch, "kv_shared_cache")
then uses getattr(forward_batch, "kv_shared_cache"); replace the getattr call
with direct attribute access (forward_batch.kv_shared_cache) to simplify the
style while preserving behavior—update the block that assigns shared_kv_cache
from forward_batch in the method where forward_batch is used (the snippet using
isinstance(forward_batch, dict) / hasattr(forward_batch, "kv_shared_cache")).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c588324c-685b-4a73-ae39-477690450de5

📥 Commits

Reviewing files that changed from the base of the PR and between 2126cae and b400087.

📒 Files selected for processing (4)
  • pymllm/executor/model_runner.py
  • pymllm/models/__init__.py
  • pymllm/models/gemma3n.py
  • pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
  • pymllm/models/__init__.py

@Grape203 Grape203 force-pushed the pr/gemma3n-text-only-native-20260501_135528 branch from b400087 to fe694cb on May 1, 2026 at 09:27
@Grape203 Grape203 force-pushed the pr/gemma3n-text-only-native-20260501_135528 branch from fe694cb to 35dce02 on May 1, 2026 at 09:28
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)

1061-1067: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Preserve exception context when re-raising TypeError.

At Line 1064, re-raising without from err drops the original failure context and makes triage harder.

💡 Proposed fix
-        else:
-            try:
-                weight_items = iter(weights)
-            except TypeError:
-                raise TypeError(
-                    f"weights must be a dict-like state_dict, a module with state_dict(), "
-                    f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+        else:
+            try:
+                weight_items = iter(weights)
+            except TypeError as err:
+                raise TypeError(
+                    f"weights must be a dict-like state_dict, a module with state_dict(), "
+                    f"or an iterable of (name, tensor), got {type(weights)}"
+                ) from err
#!/bin/bash
# Verify non-chained re-raises for this code path.
rg -n -C2 'except TypeError( as \w+)?:' pymllm/models/gemma3n.py
rg -n -C2 'raise TypeError\(' pymllm/models/gemma3n.py
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pymllm/models/gemma3n.py` around lines 1061 - 1067, The except TypeError
block that attempts to validate weights should preserve the original exception
context: capture the caught exception (e.g., "except TypeError as e:") and
re-raise the new TypeError with "from e" so the original traceback is chained;
update the except block around the weights iterator creation (the try that sets
weight_items = iter(weights)) to use "except TypeError as e" and then "raise
TypeError(... ) from e" referencing the same error message.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 792-797: is_prefill detection currently uses only
input_ids_hf.shape[1] and self._hf_past_key_values, which lets a one-token new
request be misclassified as decode and reuse stale _hf_past_key_values; change
the logic so prefill is also true when the request boundary indicates a new
prompt (e.g., compare a stored request id/turn counter or track whether a
prefill has been initialized for the current request), and reset/clear
self._hf_past_key_values when a new request begins; update the
is_prefill/past_key_values decision (the variables input_ids_hf, is_prefill,
past_key_values and self._hf_past_key_values) to consult that explicit
per-request boundary flag instead of relying only on token count.

---

Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1061-1067: The except TypeError block that attempts to validate
weights should preserve the original exception context: capture the caught
exception (e.g., "except TypeError as e:") and re-raise the new TypeError with
"from e" so the original traceback is chained; update the except block around
the weights iterator creation (the try that sets weight_items = iter(weights))
to use "except TypeError as e" and then "raise TypeError(... ) from e"
referencing the same error message.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 77bef586-be68-4847-a4a0-57e7bc0bdd10

📥 Commits

Reviewing files that changed from the base of the PR and between b400087 and 35dce02.

📒 Files selected for processing (4)
  • pymllm/executor/model_runner.py
  • pymllm/models/__init__.py
  • pymllm/models/gemma3n.py
  • pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
  • pymllm/models/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • pymllm/orchestrator/model_runner_process.py
  • pymllm/executor/model_runner.py

@chenghuaWang chenghuaWang changed the title from "Add Gemma3n text-only native server path" to "feat(pymllm, gemma3n): Add Gemma3n text-only native server path" on May 5, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)

1110-1117: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Chain TypeError with from err for proper exception context.

Ruff B904 still flags this raise inside except TypeError: — the previous fix switched to iter(weights) but didn't preserve cause chaining. Without from err, the original TypeError from iter(...) is hidden behind a "During handling of the above exception" trace, which obscures the real failure when callers pass an unsupported weights type.

🩹 Proposed fix
         else:
             try:
                 weight_items = iter(weights)
-            except TypeError:
+            except TypeError as err:
                 raise TypeError(
                     f"weights must be a dict-like state_dict, a module with state_dict(), "
                     f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+                ) from err
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pymllm/models/gemma3n.py` around lines 1110 - 1117, The except TypeError
block currently re-raises a new TypeError without chaining the original one,
losing context; modify the handler to capture the original exception (e.g.,
except TypeError as err) and re-raise the new TypeError with "from err" so the
original iter(weights) error is preserved; update the block around the
iter(weights) call that raises the user-facing TypeError to use "except
TypeError as err: raise TypeError(... ) from err".
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1003-1013: Remove the misleading log call that says "falling back
to regular load_weights path." in the Gemma3n streaming loader branch: when
st_files (from model_path.glob("*.safetensors")) is empty, delete or stop
emitting the logger.info(...) message and let the subsequent raise
FileNotFoundError(...) for the missing safetensors shards stand alone; keep the
FileNotFoundError text as-is so callers see the actual failure (references:
st_files, model_path, logger.info, FileNotFoundError).

---

Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1110-1117: The except TypeError block currently re-raises a new
TypeError without chaining the original one, losing context; modify the handler
to capture the original exception (e.g., except TypeError as err) and re-raise
the new TypeError with "from err" so the original iter(weights) error is
preserved; update the block around the iter(weights) call that raises the
user-facing TypeError to use "except TypeError as err: raise TypeError(... )
from err".

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 760e9807-8fba-4372-b9be-534771a651b2

📥 Commits

Reviewing files that changed from the base of the PR and between 35dce02 and 71a6f5b.

📒 Files selected for processing (2)
  • pymllm/executor/model_runner.py
  • pymllm/models/gemma3n.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • pymllm/executor/model_runner.py

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)

928-936: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Misleading "falling back" log right before raise FileNotFoundError.

The logger.info("...falling back to regular load_weights path.", ...) immediately precedes raise FileNotFoundError(...), so the message contradicts what actually happens and will mislead anyone debugging a missing-checkpoint scenario. Drop the stale log and let the exception speak for itself.

🩹 Proposed fix
             st_files = sorted(model_path.glob("*.safetensors"))
             if not st_files:
-                logger.info(
-                    "Gemma3n streaming loader found no safetensors under %s; "
-                    "falling back to regular load_weights path.",
-                    model_path,
-                )
                 raise FileNotFoundError(
                     f"No safetensors checkpoint shards found under {model_path}. "
                     "Gemma3n native loading currently expects safetensors weights."
                 )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pymllm/models/gemma3n.py` around lines 928 - 936, Remove the misleading
logger.info call that says "falling back to regular load_weights path"
immediately before the FileNotFoundError in the Gemma3n streaming loader; locate
the logger.info(...) that references model_path in pymllm/models/gemma3n.py (the
same block that then raises FileNotFoundError) and delete that log line so only
the FileNotFoundError with its clear message remains.
🧹 Nitpick comments (3)
pymllm/models/gemma3n.py (2)

1034-1040: ⚡ Quick win

Chain the re-raised TypeError to preserve the original cause.

Inside the except TypeError, re-raising without from (or from None) loses exception chaining context and trips Ruff B904. Re-raise with from err so debugging unexpected weights types preserves the underlying iterator failure.

♻️ Proposed fix
         else:
             try:
                 weight_items = iter(weights)
-            except TypeError:
+            except TypeError as err:
                 raise TypeError(
                     f"weights must be a dict-like state_dict, a module with state_dict(), "
                     f"or an iterable of (name, tensor), got {type(weights)}"
-                )
+                ) from err

As per coding guidelines: "Adhere to language-specific best practices and idioms (e.g., PEP 8 for Python, Google C++ Style Guide for C++)."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pymllm/models/gemma3n.py` around lines 1034 - 1040, The except block catching
TypeError when calling iter(weights) should preserve exception chaining: change
"except TypeError:" to "except TypeError as err" and re-raise the new TypeError
with "from err" (keeping the existing message) so the original iterator error is
preserved; target the code that assigns weight_items = iter(weights) and the
subsequent raise of TypeError for weights.

337-337: 💤 Low value

Replace constant getattr/setattr with direct attribute access.

Ruff flags B009 at Line 337 (getattr(forward_batch, "kv_shared_cache") after a hasattr check) and B010 at Line 641 (setattr(native_forward_batch, "kv_shared_cache", {})). Since both attribute names are string literals, prefer plain attribute access for readability and a small perf win.

♻️ Proposed fix
-        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-            shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
+            shared_kv_cache = forward_batch.kv_shared_cache
         else:
             native_forward_batch = forward_batch
-            setattr(native_forward_batch, "kv_shared_cache", {})
+            native_forward_batch.kv_shared_cache = {}

Also applies to: 641-641
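A self-contained illustration of the B009/B010 point, using a hypothetical `Batch` class rather than pymllm's actual `ForwardBatch`: with a constant attribute name, `getattr`/`setattr` are semantically identical to plain attribute access, just slower and noisier.

```python
class Batch:
    """Stand-in for a forward-batch object with dynamic attributes."""
    pass


batch = Batch()

# Flagged style (B010/B009): constant string attribute names.
setattr(batch, "kv_shared_cache", {})
cache = getattr(batch, "kv_shared_cache")
assert cache == {}

# Preferred style: identical semantics, clearer, no per-call name lookup.
batch.kv_shared_cache = {}
assert batch.kv_shared_cache == {}
assert hasattr(batch, "kv_shared_cache")
```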

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pymllm/models/gemma3n.py` at line 337, The getattr/setattr calls using the
literal "kv_shared_cache" should be replaced with direct attribute access for
clarity and performance: in the code that currently does getattr(forward_batch,
"kv_shared_cache") (e.g., inside the forward_batch handling) use
forward_batch.kv_shared_cache directly, and where you do
setattr(native_forward_batch, "kv_shared_cache", {}) assign
native_forward_batch.kv_shared_cache = {} instead; update any related existence
checks (hasattr) to use direct attribute access or attribute existence checks as
needed.
pymllm/layers/gemma3n.py (1)

108-119: ⚡ Quick win

Precompute std_multiplier to avoid per-forward object creation.

activation_sparsity is a constant set at construction time, but _gaussian_topk instantiates a fresh torch.tensor and a torch.distributions.normal.Normal(0, 1) on every forward and recomputes icdf each call. Cache the multiplier once in __init__ so the hot path only does the mean/std/ReLU work.

♻️ Proposed refactor
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
         activation: str,
         activation_sparsity: float = 0.0,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.activation_name = activation
         self.act = _get_gemma3n_hidden_act_fn(activation)
         self.activation_sparsity = float(activation_sparsity)
+        if self.activation_sparsity > 0.0:
+            normal_dist = torch.distributions.normal.Normal(0.0, 1.0)
+            std_multiplier = normal_dist.icdf(
+                torch.tensor(self.activation_sparsity, dtype=torch.float32)
+            )
+            self.register_buffer(
+                "_std_multiplier", std_multiplier, persistent=False
+            )
@@
     def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
-        target_sparsity_tensor = torch.tensor(
-            self.activation_sparsity,
-            dtype=torch.float32,
-            device=inputs.device,
-        )
-        normal_dist = torch.distributions.normal.Normal(0, 1)
-        std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype)
+        std_multiplier = self._std_multiplier.to(
+            device=inputs.device, dtype=inputs.dtype
+        )
         inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
         inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
         cutoff_x = inputs_mean + inputs_std * std_multiplier
         return F.relu(inputs - cutoff_x)

As per coding guidelines: "Avoid unnecessary object creation in loops or hot paths."
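A minimal sketch of the precompute idea, using the stdlib `statistics.NormalDist` in place of `torch.distributions` (class and method names here are illustrative, not pymllm's): the inverse-CDF multiplier depends only on the fixed sparsity, so it can be computed once in `__init__` and reused on every call.

```python
from statistics import NormalDist, fmean, pstdev


class GaussianTopK:
    def __init__(self, activation_sparsity: float):
        self.activation_sparsity = float(activation_sparsity)
        # Computed once at construction; the torch version would store this
        # as a non-persistent buffer via register_buffer.
        self.std_multiplier = NormalDist(0.0, 1.0).inv_cdf(self.activation_sparsity)

    def cutoff(self, inputs: list[float]) -> float:
        # Hot path only does mean/std work; no distribution objects created.
        mean = fmean(inputs)
        std = pstdev(inputs)  # population std, matching unbiased=False
        return mean + std * self.std_multiplier


layer = GaussianTopK(activation_sparsity=0.95)
# The old per-forward computation yields the same multiplier every call.
per_call = NormalDist(0.0, 1.0).inv_cdf(0.95)
assert abs(layer.std_multiplier - per_call) < 1e-12
```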

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pymllm/layers/gemma3n.py` around lines 108 - 119, The _gaussian_topk method
recreates torch.tensor and torch.distributions.normal.Normal and recomputes icdf
on every forward; move that work to construction: compute the std_multiplier
from self.activation_sparsity inside __init__ (e.g., compute normal.icdf once,
convert to the desired dtype and store as a tensor/buffer on the module) and
then in _gaussian_topk use the precomputed self.std_multiplier (or a registered
buffer name you choose) so the forward only computes inputs_mean, inputs_std,
cutoff_x and the final F.relu.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7fed4568-f3ae-4072-a96e-d776f03ded72

📥 Commits

Reviewing files that changed from the base of the PR and between 71a6f5b and f309835.

📒 Files selected for processing (3)
  • pymllm/layers/__init__.py
  • pymllm/layers/gemma3n.py
  • pymllm/models/gemma3n.py
