feat(pymllm, gemma3n): Add Gemma3n text-only native server path #672
Grape203 wants to merge 12 commits into UbiquitousLearning:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds a full Gemma 3n text-model implementation and registry entries, updates layer exports, changes model instantiation to optionally perform CPU-first instantiation and prefer model-provided weight loaders, passes the sliding-window size into the FlashInfer attention backend, and restricts radix-cache-specific logic to caches that expose radix semantics.

Changes
- Gemma 3n model & registry
- Model instantiation, weight loading & attention wiring
- Cache handling guards
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Runner as ModelRunner
participant Model as ModelClass
participant Storage as CheckpointStorage
participant Device as DeviceMgr
Runner->>Model: inspect requires_cpu_first_weight_loading
alt CPU-first required
Runner->>Device: set instantiate_device="cpu"
Runner-->>Runner: temporarily set default dtype
else
Runner->>Device: set instantiate_device=runtime_device
end
Runner->>Model: instantiate model on instantiate_device
Runner->>Model: evaluate use_model_path_weight_loader (bool/callable)
alt model provides load_weights_from_model_path and resolver true
Runner->>Storage: request model_path streaming
Storage-->>Model: stream checkpoint path/data
Model->>Model: load_weights_from_model_path(model_path)
else
Runner->>Storage: iterate weight tensors
Storage-->>Runner: weight iterator
Runner->>Model: load_weights(weight_iterator)
end
alt runtime device is CUDA and model has move_compute_modules_to_device
Runner->>Model: move_compute_modules_to_device(runtime_cuda_device)
end
```
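To make the diagram concrete, here is a minimal Python sketch of the same decision flow. The hooks `requires_cpu_first_weight_loading`, `use_model_path_weight_loader`, `load_weights_from_model_path`, and `move_compute_modules_to_device` are the ones the walkthrough names; everything else (`load_model`, `iter_weight_tensors`) is a hypothetical stand-in, not the actual pymllm API.

```python
import torch

def iter_weight_tensors(model_path):
    """Hypothetical placeholder: yield (name, tensor) pairs from shards."""
    raise NotImplementedError

def load_model(model_cls, config, model_path, runtime_device: str = "cuda"):
    # CPU-first instantiation when the model class requests it.
    cpu_first = getattr(model_cls, "requires_cpu_first_weight_loading", False)
    instantiate_device = "cpu" if cpu_first else runtime_device
    with torch.device(instantiate_device):
        model = model_cls(config)

    # The opt-in flag may be a bool or a callable, per the walkthrough.
    use_path_loader = getattr(model, "use_model_path_weight_loader", False)
    if callable(use_path_loader):
        use_path_loader = use_path_loader()

    if use_path_loader and hasattr(model, "load_weights_from_model_path"):
        model.load_weights_from_model_path(model_path)
    else:
        model.load_weights(iter_weight_tensors(model_path))

    # Move compute modules onto the runtime CUDA device afterwards.
    if runtime_device.startswith("cuda") and hasattr(
        model, "move_compute_modules_to_device"
    ):
        model.move_compute_modules_to_device(runtime_device)
    return model
```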
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-727: The code infers prefill vs decode from sequence length
and cache emptiness which lets a fresh 1-token request reuse stale context;
update the logic in the native path (around input_ids_hf/position_ids_hf
handling and the is_prefill decision) to detect request boundaries and reset
cached state when a new decode begins: if forward_batch.forward_mode indicates
"decode" (or if position_ids_hf[...,0] == 0) then clear
self._native_cached_input_ids, self._native_cached_positions and related cached
tensors such as self._hf_past_key_values before computing is_prefill so a
single-token new request does not append to prior context (apply same change to
the other similar branch handling caches).
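A minimal sketch of the reset this comment asks for, under the assumption that a head position of 0 marks a request boundary; the attribute names are the ones cited above, the surrounding method is assumed:

```python
import torch

def _reset_native_cache_if_new_request(self, position_ids_hf: torch.Tensor) -> None:
    # Position 0 at the head of the batch signals a new prompt, so even a
    # 1-token request must not append to the previous request's context.
    if int(position_ids_hf.reshape(-1)[0].item()) == 0:
        self._native_cached_input_ids = None
        self._native_cached_positions = None
        self._hf_past_key_values = None
```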
- Around line 1041-1046: The weights handling in ModelRunner.load_model()
materializes iterables with list(weights), which will drain generator-like
inputs such as self._iter_weights and cause OOM for large Gemma3n checkpoints;
change the logic to treat non-dict weights as an iterable and iterate over it
directly (e.g., use a for-loop over weights or assign weight_items = weights
when it's already an iterable) rather than calling list(weights); if you keep
the TypeError branch for non-iterables, re-raise or raise a new error using
"from err" to preserve exception chaining.
- Around line 939-946: The current Gemma3n streaming loader returns
self.load_weights([]) when no .safetensors are found, which silently yields an
uninitialized model; instead, change the behavior in the st_files check inside
the Gemma3n loader: either raise a clear exception (e.g., FileNotFoundError or
ValueError) with a message that no .safetensors were found for the given
model_path so callers (including ModelRunner.load_model when
use_model_path_weight_loader is enabled) fail fast, or implement a proper .bin
fallback by detecting "*.bin" files and delegating to the existing .bin loading
code (call the appropriate bin-weight loader method) rather than returning an
empty load_weights list; update references to load_weights and the loader
selection logic accordingly.
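A sketch of the fail-fast/fallback shape described above; the two private loader helpers are hypothetical stand-ins for the existing safetensors and .bin loading code:

```python
from pathlib import Path

def load_weights_from_model_path(self, model_path) -> None:
    """Fail fast on missing checkpoints instead of silently loading nothing."""
    model_path = Path(model_path)
    st_files = sorted(model_path.glob("*.safetensors"))
    if st_files:
        self._load_safetensors_shards(st_files)  # hypothetical shard loader
        return
    bin_files = sorted(model_path.glob("*.bin"))
    if bin_files:
        self._load_bin_weights(bin_files)  # hypothetical .bin fallback
        return
    raise FileNotFoundError(
        f"No .safetensors or .bin checkpoint files found under {model_path}"
    )
```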
In `@pymllm/orchestrator/model_runner_process.py`:
- Around line 501-504: The cleanup path incorrectly uses a simple "cache is not
None" check so ChunkCache (which lacks page_size) can still run radix cleanup
and set did_insert=True, preventing KV frees; update the cleanup logic to mirror
the insertion guard by checking for radix-capable cache (e.g., "if cache is not
None and hasattr(cache, 'page_size')" or equivalent) before calling
_free_rid_resources or setting did_insert, and ensure _free_rid_resources only
treats did_insert as true when a real radix insertion occurred (adjust
_free_rid_resources and caller logic in model_runner_process.py to use the same
hasattr(cache, 'page_size') predicate to decide insertion/cleanup and never mark
did_insert for ChunkCache).
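A sketch of the shared predicate the comment proposes, with a hypothetical `_insert_into_radix` helper standing in for the real insertion call:

```python
def _is_radix_capable(cache) -> bool:
    # ChunkCache has no `page_size`; only radix-style caches expose it.
    return cache is not None and hasattr(cache, "page_size")

def _cleanup_request(self, cache, rid) -> None:
    did_insert = False
    if _is_radix_capable(cache):
        # Radix insertion only happens for radix-capable caches...
        did_insert = self._insert_into_radix(cache, rid)  # hypothetical helper
    # ...so for ChunkCache did_insert stays False and KV pages get freed.
    self._free_rid_resources(rid, did_insert=did_insert)
```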
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 63329a33-6882-4fdc-bc8e-b9cf3b71fcd7
📒 Files selected for processing (4)
- pymllm/executor/model_runner.py
- pymllm/models/__init__.py
- pymllm/models/gemma3n.py
- pymllm/orchestrator/model_runner_process.py
Force-pushed from 2126cae to b400087
Updated the PR to address the CodeRabbit comments:
Also re-ran the local pymllm-server verification with Gemma3n E2B text weights and /v1/completions still returns:
Actionable comments posted: 1
🧹 Nitpick comments (2)
pymllm/models/gemma3n.py (2)
1062-1068: ⚡ Quick win — Chain the exception for better debugging context.

The `TypeError` raised when `weights` is not iterable should be chained with the original exception using `from err` to preserve the exception context and aid debugging.

♻️ Proposed fix

```diff
 else:
     try:
         weight_items = iter(weights)
-    except TypeError:
+    except TypeError as err:
         raise TypeError(
             f"weights must be a dict-like state_dict, a module with state_dict(), "
             f"or an iterable of (name, tensor), got {type(weights)}"
-        )
+        ) from err
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/gemma3n.py` around lines 1062 - 1068, The TypeError raised when attempting to call iter(weights) loses the original exception context; update the try/except around iter(weights) to capture the original exception (e.g., except TypeError as err) and re-raise the new TypeError using "from err" so the stack trace is chained and debugging shows the original error; locate the try/except that wraps iter(weights) in the weights handling logic and modify that raise to include the exception chaining.
356-359: 💤 Low value — Minor style: `getattr` with constant attribute after `hasattr` check.

Line 359 uses `getattr(forward_batch, "kv_shared_cache")` immediately after checking `hasattr(forward_batch, "kv_shared_cache")` on line 358. This can be simplified to direct attribute access since the attribute's existence is already verified.

♻️ Suggested simplification

```diff
-        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-            shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
+            shared_kv_cache = forward_batch.kv_shared_cache
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/gemma3n.py` around lines 356 - 359, The code checks hasattr(forward_batch, "kv_shared_cache") then uses getattr(forward_batch, "kv_shared_cache"); replace the getattr call with direct attribute access (forward_batch.kv_shared_cache) to simplify the style while preserving behavior—update the block that assigns shared_kv_cache from forward_batch in the method where forward_batch is used (the snippet using isinstance(forward_batch, dict) / hasattr(forward_batch, "kv_shared_cache")).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 694-742: The delete of forward_batch at the top removes the
variable so locals().get("forward_batch", None) can never find it; restore
extend-mode detection by moving or removing the del forward_batch so that
forward_batch is still present when creating forward_batch_obj, or alternatively
capture forward_batch into a local variable (e.g., forward_batch_obj =
forward_batch or via locals().get before the del) before deleting; update the
logic around forward_batch_obj / is_extend_mode (used in the is_prefill
computation) to use that captured value so extend-mode detection
(is_extend_mode) works for 1-token prompts.
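A small sketch of the capture-before-delete fix described here; the usage comment shows the assumed surrounding variables:

```python
def _detect_extend_mode(forward_batch) -> bool:
    """Evaluate extend mode from a captured reference, not locals()."""
    return (
        forward_batch is not None
        and getattr(forward_batch, "forward_mode", None) == "extend"
    )

# Illustrative usage inside the native path:
#   forward_batch_obj = forward_batch   # capture BEFORE `del forward_batch`
#   del forward_batch                   # a locals() lookup would now fail
#   is_prefill = _detect_extend_mode(forward_batch_obj) or cache_is_empty
```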
---
Nitpick comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1062-1068: The TypeError raised when attempting to call
iter(weights) loses the original exception context; update the try/except around
iter(weights) to capture the original exception (e.g., except TypeError as err)
and re-raise the new TypeError using "from err" so the stack trace is chained
and debugging shows the original error; locate the try/except that wraps
iter(weights) in the weights handling logic and modify that raise to include the
exception chaining.
- Around line 356-359: The code checks hasattr(forward_batch, "kv_shared_cache")
then uses getattr(forward_batch, "kv_shared_cache"); replace the getattr call
with direct attribute access (forward_batch.kv_shared_cache) to simplify the
style while preserving behavior—update the block that assigns shared_kv_cache
from forward_batch in the method where forward_batch is used (the snippet using
isinstance(forward_batch, dict) / hasattr(forward_batch, "kv_shared_cache")).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c588324c-685b-4a73-ae39-477690450de5
📒 Files selected for processing (4)
- pymllm/executor/model_runner.py
- pymllm/models/__init__.py
- pymllm/models/gemma3n.py
- pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
- pymllm/models/__init__.py
Force-pushed from b400087 to fe694cb
Force-pushed from fe694cb to 35dce02
Actionable comments posted: 1
♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)
1061-1067: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Preserve exception context when re-raising TypeError.

At line 1064, re-raising without `from err` drops the original failure context and makes triage harder.

💡 Proposed fix

```diff
-else:
-    try:
-        weight_items = iter(weights)
-    except TypeError:
-        raise TypeError(
-            f"weights must be a dict-like state_dict, a module with state_dict(), "
-            f"or an iterable of (name, tensor), got {type(weights)}"
-        )
+else:
+    try:
+        weight_items = iter(weights)
+    except TypeError as err:
+        raise TypeError(
+            f"weights must be a dict-like state_dict, a module with state_dict(), "
+            f"or an iterable of (name, tensor), got {type(weights)}"
+        ) from err
```

```bash
#!/bin/bash
# Verify non-chained re-raises for this code path.
rg -n -C2 'except TypeError( as \w+)?:' pymllm/models/gemma3n.py
rg -n -C2 'raise TypeError\(' pymllm/models/gemma3n.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@pymllm/models/gemma3n.py` around lines 1061 - 1067, The except TypeError block that attempts to validate weights should preserve the original exception context: capture the caught exception (e.g., "except TypeError as e:") and re-raise the new TypeError with "from e" so the original traceback is chained; update the except block around the weights iterator creation (the try that sets weight_items = iter(weights)) to use "except TypeError as e" and then "raise TypeError(... ) from e" referencing the same error message.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 792-797: is_prefill detection currently uses only
input_ids_hf.shape[1] and self._hf_past_key_values, which lets a one-token new
request be misclassified as decode and reuse stale _hf_past_key_values; change
the logic so prefill is also true when the request boundary indicates a new
prompt (e.g., compare a stored request id/turn counter or track whether a
prefill has been initialized for the current request), and reset/clear
self._hf_past_key_values when a new request begins; update the
is_prefill/past_key_values decision (the variables input_ids_hf, is_prefill,
past_key_values and self._hf_past_key_values) to consult that explicit
per-request boundary flag instead of relying only on token count.
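A sketch of the explicit per-request boundary suggested here; treating `rid` as the request-id field on the batch is an assumption:

```python
def _sync_request_boundary(self, forward_batch) -> None:
    # Remember which request initialized the cache; on a new rid, drop the
    # stale KV so is_prefill computes True even for a 1-token prompt.
    rid = getattr(forward_batch, "rid", None)  # assumed request-id field
    if rid != getattr(self, "_native_active_rid", None):
        self._hf_past_key_values = None
        self._native_active_rid = rid
```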
---
Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1061-1067: The except TypeError block that attempts to validate
weights should preserve the original exception context: capture the caught
exception (e.g., "except TypeError as e:") and re-raise the new TypeError with
"from e" so the original traceback is chained; update the except block around
the weights iterator creation (the try that sets weight_items = iter(weights))
to use "except TypeError as e" and then "raise TypeError(... ) from e"
referencing the same error message.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 77bef586-be68-4847-a4a0-57e7bc0bdd10
📒 Files selected for processing (4)
- pymllm/executor/model_runner.py
- pymllm/models/__init__.py
- pymllm/models/gemma3n.py
- pymllm/orchestrator/model_runner_process.py
✅ Files skipped from review due to trivial changes (1)
- pymllm/models/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
- pymllm/orchestrator/model_runner_process.py
- pymllm/executor/model_runner.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)
1110-1117: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Chain TypeError with `from err` for proper exception context.

Ruff B904 still flags this `raise` inside `except TypeError:` — the previous fix switched to `iter(weights)` but didn't preserve cause chaining. Without `from err`, the original TypeError from `iter(...)` is hidden behind a "During handling of the above exception" trace, which obscures the real failure when callers pass an unsupported `weights` type.

🩹 Proposed fix

```diff
 else:
     try:
         weight_items = iter(weights)
-    except TypeError:
+    except TypeError as err:
         raise TypeError(
             f"weights must be a dict-like state_dict, a module with state_dict(), "
             f"or an iterable of (name, tensor), got {type(weights)}"
-        )
+        ) from err
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@pymllm/models/gemma3n.py` around lines 1110 - 1117, The except TypeError block currently re-raises a new TypeError without chaining the original one, losing context; modify the handler to capture the original exception (e.g., except TypeError as err) and re-raise the new TypeError with "from err" so the original iter(weights) error is preserved; update the block around the iter(weights) call that raises the user-facing TypeError to use "except TypeError as err: raise TypeError(... ) from err".
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1003-1013: Remove the misleading log call that says "falling back
to regular load_weights path." in the Gemma3n streaming loader branch: when
st_files (from model_path.glob("*.safetensors")) is empty, delete or stop
emitting the logger.info(...) message and let the subsequent raise
FileNotFoundError(...) for the missing safetensors shards stand alone; keep the
FileNotFoundError text as-is so callers see the actual failure (references:
st_files, model_path, logger.info, FileNotFoundError).
---
Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 1110-1117: The except TypeError block currently re-raises a new
TypeError without chaining the original one, losing context; modify the handler
to capture the original exception (e.g., except TypeError as err) and re-raise
the new TypeError with "from err" so the original iter(weights) error is
preserved; update the block around the iter(weights) call that raises the
user-facing TypeError to use "except TypeError as err: raise TypeError(... )
from err".
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 760e9807-8fba-4372-b9be-534771a651b2
📒 Files selected for processing (2)
- pymllm/executor/model_runner.py
- pymllm/models/gemma3n.py
🚧 Files skipped from review as they are similar to previous changes (1)
- pymllm/executor/model_runner.py
♻️ Duplicate comments (1)
pymllm/models/gemma3n.py (1)
928-936: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win — Misleading "falling back" log right before raise FileNotFoundError.

The `logger.info("...falling back to regular load_weights path.", ...)` immediately precedes `raise FileNotFoundError(...)`, so the message contradicts what actually happens and will mislead anyone debugging a missing-checkpoint scenario. Drop the stale log and let the exception speak for itself.

🩹 Proposed fix

```diff
 st_files = sorted(model_path.glob("*.safetensors"))
 if not st_files:
-    logger.info(
-        "Gemma3n streaming loader found no safetensors under %s; "
-        "falling back to regular load_weights path.",
-        model_path,
-    )
     raise FileNotFoundError(
         f"No safetensors checkpoint shards found under {model_path}. "
         "Gemma3n native loading currently expects safetensors weights."
     )
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@pymllm/models/gemma3n.py` around lines 928 - 936, Remove the misleading logger.info call that says "falling back to regular load_weights path" immediately before the FileNotFoundError in the Gemma3n streaming loader; locate the logger.info(...) that references model_path in pymllm/models/gemma3n.py (the same block that then raises FileNotFoundError) and delete that log line so only the FileNotFoundError with its clear message remains.
🧹 Nitpick comments (3)
pymllm/models/gemma3n.py (2)
1034-1040: ⚡ Quick win — Chain the re-raised TypeError to preserve the original cause.

Inside the `except TypeError`, re-raising without `from` (or `from None`) loses exception chaining context and trips Ruff B904. Re-raise with `from err` so debugging unexpected `weights` types preserves the underlying iterator failure.

♻️ Proposed fix

```diff
 else:
     try:
         weight_items = iter(weights)
-    except TypeError:
+    except TypeError as err:
         raise TypeError(
             f"weights must be a dict-like state_dict, a module with state_dict(), "
             f"or an iterable of (name, tensor), got {type(weights)}"
-        )
+        ) from err
```

As per coding guidelines: "Adhere to language-specific best practices and idioms (e.g., PEP 8 for Python, Google C++ Style Guide for C++)."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@pymllm/models/gemma3n.py` around lines 1034 - 1040, The except block catching TypeError when calling iter(weights) should preserve exception chaining: change "except TypeError:" to "except TypeError as err" and re-raise the new TypeError with "from err" (keeping the existing message) so the original iterator error is preserved; target the code that assigns weight_items = iter(weights) and the subsequent raise of TypeError for weights.
337: 💤 Low value — Replace constant getattr/setattr with direct attribute access.

Ruff flags B009 at line 337 (`getattr(forward_batch, "kv_shared_cache")` after a `hasattr` check) and B010 at line 641 (`setattr(native_forward_batch, "kv_shared_cache", {})`). Since both attribute names are string literals, prefer plain attribute access for readability and a small perf win.

♻️ Proposed fix

```diff
-        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
-            shared_kv_cache = getattr(forward_batch, "kv_shared_cache")
+        elif forward_batch is not None and hasattr(forward_batch, "kv_shared_cache"):
+            shared_kv_cache = forward_batch.kv_shared_cache
```

```diff
 else:
     native_forward_batch = forward_batch
-    setattr(native_forward_batch, "kv_shared_cache", {})
+    native_forward_batch.kv_shared_cache = {}
```

Also applies to: 641-641
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@pymllm/models/gemma3n.py` at line 337, The getattr/setattr calls using the literal "kv_shared_cache" should be replaced with direct attribute access for clarity and performance: in the code that currently does getattr(forward_batch, "kv_shared_cache") (e.g., inside the forward_batch handling) use forward_batch.kv_shared_cache directly, and where you do setattr(native_forward_batch, "kv_shared_cache", {}) assign native_forward_batch.kv_shared_cache = {} instead; update any related existence checks (hasattr) to use direct attribute access or attribute existence checks as needed.

pymllm/layers/gemma3n.py (1)
108-119: ⚡ Quick win — Precompute `std_multiplier` to avoid per-forward object creation.

`activation_sparsity` is a constant set at construction time, but `_gaussian_topk` instantiates a fresh `torch.tensor` and a `torch.distributions.normal.Normal(0, 1)` on every forward and recomputes `icdf` each call. Cache the multiplier once in `__init__` so the hot path only does the mean/std/ReLU work.

♻️ Proposed refactor

```diff
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
         activation: str,
         activation_sparsity: float = 0.0,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.activation_name = activation
         self.act = _get_gemma3n_hidden_act_fn(activation)
         self.activation_sparsity = float(activation_sparsity)
+        if self.activation_sparsity > 0.0:
+            normal_dist = torch.distributions.normal.Normal(0.0, 1.0)
+            std_multiplier = normal_dist.icdf(
+                torch.tensor(self.activation_sparsity, dtype=torch.float32)
+            )
+            self.register_buffer(
+                "_std_multiplier", std_multiplier, persistent=False
+            )
@@
     def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
-        target_sparsity_tensor = torch.tensor(
-            self.activation_sparsity,
-            dtype=torch.float32,
-            device=inputs.device,
-        )
-        normal_dist = torch.distributions.normal.Normal(0, 1)
-        std_multiplier = normal_dist.icdf(target_sparsity_tensor).to(dtype=inputs.dtype)
+        std_multiplier = self._std_multiplier.to(
+            device=inputs.device, dtype=inputs.dtype
+        )
         inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
         inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
         cutoff_x = inputs_mean + inputs_std * std_multiplier
         return F.relu(inputs - cutoff_x)
```

As per coding guidelines: "Avoid unnecessary object creation in loops or hot paths."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@pymllm/layers/gemma3n.py` around lines 108 - 119, The _gaussian_topk method recreates torch.tensor and torch.distributions.normal.Normal and recomputes icdf on every forward; move that work to construction: compute the std_multiplier from self.activation_sparsity inside __init__ (e.g., compute normal.icdf once, convert to the desired dtype and store as a tensor/buffer on the module) and then in _gaussian_topk use the precomputed self.std_multiplier (or a registered buffer name you choose) so the forward only computes inputs_mean, inputs_std, cutoff_x and the final F.relu.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@pymllm/models/gemma3n.py`:
- Around line 928-936: Remove the misleading logger.info call that says "falling
back to regular load_weights path" immediately before the FileNotFoundError in
the Gemma3n streaming loader; locate the logger.info(...) that references
model_path in pymllm/models/gemma3n.py (the same block that then raises
FileNotFoundError) and delete that log line so only the FileNotFoundError with
its clear message remains.
---
Nitpick comments:
In `@pymllm/layers/gemma3n.py`:
- Around line 108-119: The _gaussian_topk method recreates torch.tensor and
torch.distributions.normal.Normal and recomputes icdf on every forward; move
that work to construction: compute the std_multiplier from
self.activation_sparsity inside __init__ (e.g., compute normal.icdf once,
convert to the desired dtype and store as a tensor/buffer on the module) and
then in _gaussian_topk use the precomputed self.std_multiplier (or a registered
buffer name you choose) so the forward only computes inputs_mean, inputs_std,
cutoff_x and the final F.relu.
In `@pymllm/models/gemma3n.py`:
- Around line 1034-1040: The except block catching TypeError when calling
iter(weights) should preserve exception chaining: change "except TypeError:" to
"except TypeError as err" and re-raise the new TypeError with "from err"
(keeping the existing message) so the original iterator error is preserved;
target the code that assigns weight_items = iter(weights) and the subsequent
raise of TypeError for weights.
- Line 337: The getattr/setattr calls using the literal "kv_shared_cache" should
be replaced with direct attribute access for clarity and performance: in the
code that currently does getattr(forward_batch, "kv_shared_cache") (e.g., inside
the forward_batch handling) use forward_batch.kv_shared_cache directly, and
where you do setattr(native_forward_batch, "kv_shared_cache", {}) assign
native_forward_batch.kv_shared_cache = {} instead; update any related existence
checks (hasattr) to use direct attribute access or attribute existence checks as
needed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 7fed4568-f3ae-4072-a96e-d776f03ded72
📒 Files selected for processing (3)
- pymllm/layers/__init__.py
- pymllm/layers/gemma3n.py
- pymllm/models/gemma3n.py
Summary
This PR adds an initial Gemma3n text-only native path to pymllm.
The implementation focuses on the simplest text-only LLM path first, following the staged direction discussed with the maintainer. It supports loading Gemma3n E2B text weights, native text forward/generation, and basic `pymllm-server` decode with RadixCache disabled.

Main changes
- Added `pymllm.models.gemma3n` with a text-only native Gemma3n implementation, plus the matching registry entries in `pymllm.models`.
- Added a `disable_radix_cache` path so ChunkCache does not enter RadixCache-specific insertion logic.

Verification
Tested locally with Gemma3n E2B text weights.
- Weight loading: loaded=732 skipped=825 missing_in_ckpt=0
- `pymllm-server` works with RadixCache disabled.
- `/v1/completions` multi-token decode returns: