Skip to content

Conversation

@samanklesaria
Copy link
Collaborator

Closes #4971

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@samanklesaria samanklesaria force-pushed the issues/4971 branch 4 times, most recently from 950099f to 3a7b077 Compare September 26, 2025 16:47
@samanklesaria
Copy link
Collaborator Author

samanklesaria commented Sep 26, 2025

When using Jupyter to edit the notebooks, I find that all the {code-cell} blocks in the markdown get replaced with {code-cell} ipython3. That is, jupytext on my machine labels all code cells with the language, while on the main branch, code cells are inconsistently labeled. In the mnist tutorial there's no ipython3 label, but in the randomness tutorial there is. Instead of figuring out how to configure my jupyter settings separately for each notebook file, I'm just going to leave this commit including the extra ipython3 labels.

@samanklesaria samanklesaria changed the title Update tutorial examples to thread explicit RNGs (WIP) Update tutorial examples to thread explicit RNGs Sep 26, 2025
@samanklesaria samanklesaria marked this pull request as ready for review September 26, 2025 17:45
@cgarciae
Copy link
Collaborator

cgarciae commented Oct 2, 2025

Thanks @samanklesaria ! Its looking great. Left some comments.
Maybe the other place we can update is the index.rst which also contains an example that uses dropout.

@samanklesaria samanklesaria requested a review from cgarciae October 2, 2025 18:13
@cgarciae
Copy link
Collaborator

Thanks @samanklesaria ! Wondering if we could also refactor the intro to the Randomness guide as part of this PR? Current intro is a bit outdated in the sense that it tries to explain it in terms of Haiku and Linen. It might be more useful add a simple code example at the beginning as motivation for the whole guide e.g:

class Model(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(20, 10, rngs=rngs)
    self.drop = nnx.Dropout(0.1)

  def __call__(self, x, *, rngs):
    return nnx.relu(self.drop(self.linear(x), rngs=rngs))
 
rngs = nnx.Rngs(0)
model = Model(rngs=rngs)  # pass rngs to initialize parameters
x = rngs.normal((32, 20))  # convenient jax.random methods
y = model(x, rngs=rngs)  # pass rngs for dropout masks

We can remove the current section below in favor of the example above and some comments introducing nnx.Rngs as the main mechanism that NNX uses for propagating random state.

Random state handling in Flax NNX was radically simplified compared to systems like Haiku and Flax Linen because Flax NNX defines the random state as an object state. In essence, this means that in Flax NNX, the random state is: 1) just another type of state; 2) stored in nnx.Variables; and 3) held by the models themselves.

The Flax NNX pseudorandom number generator (PRNG) system has the following main characteristics:

It is explicit.

It is order-based.

It uses dynamic counters.

This is a bit different from Flax Linen’s PRNG system, which is (path + order)-based, and uses static counters.

Copy link
Collaborator

@cgarciae cgarciae left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @samanklesaria ! This is great.
As a follow up PR we could change the Fork section to use .split instead.

@copybara-service copybara-service bot merged commit 74985b2 into google:main Oct 30, 2025
18 checks passed
@Tsukimarf
Copy link

ns module methods under jax.named_scope for profiling

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Emphasize Explicit Randomness in Documentation

3 participants