-
Notifications
You must be signed in to change notification settings - Fork 757
Update tutorial examples to thread explicit RNGs #4975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
950099f to
3a7b077
Compare
|
When using Jupyter to edit the notebooks, I find that all the |
3a7b077 to
f5e5906
Compare
f5e5906 to
cdf9ff8
Compare
|
Thanks @samanklesaria ! Its looking great. Left some comments. |
|
Thanks @samanklesaria ! Wondering if we could also refactor the intro to the Randomness guide as part of this PR? Current intro is a bit outdated in the sense that it tries to explain it in terms of Haiku and Linen. It might be more useful add a simple code example at the beginning as motivation for the whole guide e.g: class Model(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1)
def __call__(self, x, *, rngs):
return nnx.relu(self.drop(self.linear(x), rngs=rngs))
rngs = nnx.Rngs(0)
model = Model(rngs=rngs) # pass rngs to initialize parameters
x = rngs.normal((32, 20)) # convenient jax.random methods
y = model(x, rngs=rngs) # pass rngs for dropout masksWe can remove the current section below in favor of the example above and some comments introducing |
cgarciae
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @samanklesaria ! This is great.
As a follow up PR we could change the Fork section to use .split instead.
|
ns module methods under jax.named_scope for profiling |
Closes #4971