Add `replace: bool` argument to `random.categorical` to sample without replacement using Gumbel-top-k trick #25617
Sounds reasonable – is this something you'd like to contribute?

Yes, I can create a PR for this.
@jakevdp After thinking about it more, since the semantics of shapes and axes would differ from those of sampling with replacement, I think it might be better to create a separate function. What would be a good name for it? Here are a few candidates:
Here's an example implementation:

```python
from jax import lax, nn, numpy as jnp, random


def categorical_wor(key, logits, k, axis=-1, batch_shape=()):
    # Sample k categories without replacement along `axis` (Gumbel-top-k).
    logits = jnp.broadcast_to(logits, batch_shape + logits.shape)
    if axis >= 0:
        # Prepending batch dimensions shifts non-negative axis indices.
        axis += len(batch_shape)
    noise = random.gumbel(key, logits.shape, logits.dtype)
    x = logits + noise
    # return x.argmax(axis)  # sampling with replacement (Gumbel-max)
    x = jnp.moveaxis(x, axis, -1)  # lax.top_k operates on the last axis
    _, outcomes = lax.top_k(x, k)
    outcomes = jnp.moveaxis(outcomes, -1, axis)
    return outcomes


def get_freqs(x, length):
    # Empirical frequency of each category among all sampled indices.
    counts = jnp.bincount(x.flatten(), length=length)
    return counts / counts.sum()


def main():
    key = random.key(0)
    key, subkey = random.split(key)
    logits = jnp.arange(4, dtype=float)
    print(f"probs: {nn.softmax(logits)}")
    axis = -1
    for k in range(logits.shape[axis] + 1):
        outcomes = categorical_wor(key, logits, k, axis, (20, 30, 40))
        print(f"{k=} freqs: {get_freqs(outcomes, logits.shape[axis])}")


if __name__ == "__main__":
    main()
```
(As expected, the sample frequencies become equal by the time we get to
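The behavior at the two extremes of `k` can be checked directly. The sketch below is a simplified, single-axis variant of `categorical_wor`; the helper name `gumbel_top_k` and the sample sizes are our own choices. With `k = 1`, top-k reduces to the Gumbel-max trick, so frequencies should approach `softmax(logits)`; with `k` equal to the number of categories, each row is a random permutation, so every category appears exactly once per row.

```python
from jax import lax, nn, numpy as jnp, random


def gumbel_top_k(key, logits, k, num_samples):
    # Hypothetical helper: perturb the logits with independent Gumbel noise
    # for each of `num_samples` rows, then keep the indices of the k largest.
    noise = random.gumbel(key, (num_samples,) + logits.shape, logits.dtype)
    _, idx = lax.top_k(logits + noise, k)
    return idx  # shape: (num_samples, k)


logits = jnp.arange(4, dtype=jnp.float32)
n = logits.shape[-1]
key1, key2 = random.split(random.key(0))

# k = 1 is the Gumbel-max trick: frequencies should approach softmax(logits).
idx1 = gumbel_top_k(key1, logits, 1, 200_000)
freqs1 = jnp.bincount(idx1.ravel(), length=n) / idx1.size

# k = n: every row is a permutation of range(n), so frequencies are exactly 1/n.
idx_all = gumbel_top_k(key2, logits, n, 1_000)
freqs_all = jnp.bincount(idx_all.ravel(), length=n) / idx_all.size

print("softmax:  ", nn.softmax(logits))
print("k=1 freqs:", freqs1)
print("k=n freqs:", freqs_all)
```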
Can you say more about why this needs its own function?
The output shape of sampling with replacement is
The output shape of sampling without replacement would be
Strictly speaking, it need not be its own function, but the argument signature as well as the documentation describing the output shape of
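To make the shape difference concrete, here is a small comparison, assuming the documented `jax.random.categorical` behavior: by default the result has shape `np.delete(logits.shape, axis)`, and an explicit `shape` must broadcast against that. The without-replacement path is written inline with `lax.top_k`, which works on the last axis; the array sizes are arbitrary.

```python
from jax import lax, numpy as jnp, random

key_a, key_b, key_c = random.split(random.key(0), 3)
logits = jnp.zeros((5, 3, 7))  # batch dims (5, 3), 7 categories on axis -1
k = 4

# With replacement, one draw per batch element: the category axis is
# removed, so the result has shape (5, 3).
one_draw = random.categorical(key_a, logits, axis=-1)

# With replacement, k draws: extra draws are requested through `shape`,
# which must broadcast against (5, 3); the result has shape (k, 5, 3).
k_draws = random.categorical(key_b, logits, axis=-1, shape=(k, 5, 3))

# Without replacement (Gumbel-top-k): the category axis is kept but
# shrinks from 7 to k, so the result has shape (5, 3, k).
noise = random.gumbel(key_c, logits.shape, logits.dtype)
_, no_repl = lax.top_k(logits + noise, k)

print(one_draw.shape, k_draws.shape, no_repl.shape)
```

With replacement the k draws live in a leading dimension, while without replacement they take the place of the category axis, which is one way the two signatures diverge.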
`random.categorical` uses the Gumbel-max trick to sample with replacement. An extension of the Gumbel-max trick allows sampling without replacement. As noted in this paper:
Proposal:

Add a `replace: bool` argument to `random.categorical`. `replace=False` should use the Gumbel-top-k trick (replacing `argmax` with `top_k`) to sample without replacement.

Alternatives:
`random.choice` can sample without replacement, but it uses a less efficient `cumsum`-based method and operates on probabilities rather than logits/log-probabilities.
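Under the proposal, the only difference between the two modes is whether the Gumbel-perturbed logits are reduced with `argmax` or `top_k`. A minimal sketch of that dispatch follows. `sample_categorical` is a hypothetical name, not an existing `jax.random` function, and returning `k` draws in place of the category axis in both modes is our own assumption about the shape semantics.

```python
from jax import lax, numpy as jnp, random


def sample_categorical(key, logits, k, replace=True, axis=-1):
    # Hypothetical helper (not part of jax.random): draws k category indices
    # per batch element. Output shape: logits.shape with `axis` replaced by k.
    x = jnp.moveaxis(logits, axis, -1)  # category axis last: (..., n)
    if replace:
        # Gumbel-max: k independent perturbations, argmax over categories.
        noise = random.gumbel(key, x.shape[:-1] + (k, x.shape[-1]), x.dtype)
        idx = jnp.argmax(x[..., None, :] + noise, axis=-1)  # (..., k)
    else:
        # Gumbel-top-k: one perturbation, keep the indices of the k largest.
        noise = random.gumbel(key, x.shape, x.dtype)
        _, idx = lax.top_k(x + noise, k)  # (..., k), requires k <= n
    return jnp.moveaxis(idx, -1, axis)


key_a, key_b = random.split(random.key(0))
logits = jnp.log(jnp.array([[0.1, 0.2, 0.3, 0.4],
                            [0.7, 0.1, 0.1, 0.1]]))
with_repl = sample_categorical(key_a, logits, 6, replace=True)
without_repl = sample_categorical(key_b, logits, 4, replace=False)
print(with_repl)
print(without_repl)
```

Because `replace` selects a Python branch, it would need to be a static argument under `jax.jit`; with replacement `k` may exceed the number of categories, while without replacement it may not.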