Add replace: bool argument to random.categorical to sample without replacement using Gumbel-top-k trick #25617

Open
carlosgmartin opened this issue Dec 19, 2024 · 5 comments

@carlosgmartin
Contributor

carlosgmartin commented Dec 19, 2024

random.categorical uses the Gumbel-max trick to sample with replacement.

An extension of the Gumbel-max trick allows sampling without replacement. As noted in this paper:

The well-known Gumbel-Max trick for sampling from a categorical distribution can be extended to sample k elements without replacement.

The Gumbel-Max trick [...] allows sampling from the categorical distribution, simply by perturbing the log-probability for each category by adding independent Gumbel distributed noise and returning the category with maximum perturbed log-probability. [...] However, there is more: as was noted (in a blog post) by Vieira (2014), taking the top k largest perturbed log-probabilities (instead of the maximum, or top 1) yields a sample of size k from the categorical distribution without replacement. We refer to this extension of the Gumbel-Max trick as the Gumbel-Top-k trick.
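The quoted description can be sketched for a single 1-D logits vector as follows. This is a minimal illustration, not the implementation in random.categorical; the helper names gumbel_max and gumbel_top_k are hypothetical.

```python
import jax.numpy as jnp
from jax import lax, random


def gumbel_max(key, logits):
    """Draw one category index from 1-D logits (Gumbel-max)."""
    noise = random.gumbel(key, logits.shape, logits.dtype)
    return jnp.argmax(logits + noise)


def gumbel_top_k(key, logits, k):
    """Draw k distinct category indices from 1-D logits (Gumbel-top-k)."""
    noise = random.gumbel(key, logits.shape, logits.dtype)
    # top_k returns (values, indices) sorted by decreasing perturbed logit.
    _, indices = lax.top_k(logits + noise, k)
    return indices


key = random.key(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))

single = gumbel_max(key, logits)       # one index
sample = gumbel_top_k(key, logits, 3)  # three distinct indices
```

With the same key, both helpers see the same noise, so the first entry of the top-k sample coincides with the Gumbel-max draw.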

Proposal:

Add a replace: bool argument to random.categorical. replace=False should use the Gumbel-Top-k trick (replacing argmax with top_k) to sample without replacement.

Alternatives:

random.choice can sample without replacement, but uses a less efficient cumsum-based method, and operates on probabilities rather than logits/log-probabilities. Related:
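For comparison, here is a minimal sketch of the existing route through random.choice with replace=False. It assumes the logits are first converted to probabilities with nn.softmax, since random.choice takes p rather than logits; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import nn, random

key = random.key(0)
logits = jnp.arange(4, dtype=jnp.float32)

# random.choice expects probabilities, so the logits must be normalized first.
probs = nn.softmax(logits)

# Draw 2 distinct category indices from {0, 1, 2, 3}.
picks = random.choice(key, logits.shape[0], shape=(2,), replace=False, p=probs)
```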

@carlosgmartin carlosgmartin added the enhancement New feature or request label Dec 19, 2024
@jakevdp
Collaborator

jakevdp commented Dec 23, 2024

Sounds reasonable – is this something you'd like to contribute?

@jakevdp jakevdp self-assigned this Dec 23, 2024
@carlosgmartin
Contributor Author

Yes, I can create a PR for this.

@carlosgmartin
Contributor Author

@jakevdp After thinking about it more, since the semantics of shapes and axes would differ from those of sampling with replacement, I think it might be better to create a separate function. What would be a good name for it? Here are a few candidates:

  • categorical_wor
  • categorical_without_replacement
  • categorical_no_replace
  • categorical_no_replacement
  • sample_without_replacement

Here's an example implementation:

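```python
from jax import lax, nn, numpy as jnp, random


def categorical_wor(key, logits, k, axis=-1, batch_shape=()):
    """Sample k distinct categories along `axis` using the Gumbel-top-k trick."""
    # Prepend the batch dimensions; a non-negative axis shifts accordingly.
    logits = jnp.broadcast_to(logits, batch_shape + logits.shape)
    if axis >= 0:
        axis += len(batch_shape)

    # Perturb the logits with independent Gumbel noise.
    noise = random.gumbel(key, logits.shape, logits.dtype)
    x = logits + noise

    # return x.argmax(axis)  # sampling with replacement

    # lax.top_k works on the last axis, so move `axis` there and back.
    x = jnp.moveaxis(x, axis, -1)
    _, outcomes = lax.top_k(x, k)
    outcomes = jnp.moveaxis(outcomes, -1, axis)

    return outcomes


def get_freqs(x, length):
    """Return the relative frequency of each category index in x."""
    counts = jnp.bincount(x.flatten(), length=length)
    return counts / counts.sum()


def main():
    key = random.key(0)

    # Note: subkey is unused; the same key is reused for every k below.
    key, subkey = random.split(key)
    logits = jnp.arange(4, dtype=float)
    print(f"probs: {nn.softmax(logits)}")

    axis = -1

    # k ranges from 0 up to and including the number of categories.
    for k in range(logits.shape[axis] + 1):
        outcomes = categorical_wor(key, logits, k, axis, (20, 30, 40))
        print(f"{k=} freqs: {get_freqs(outcomes, logits.shape[axis])}")


if __name__ == "__main__":
    main()
```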
Output:

```
probs: [0.0320586  0.08714432 0.23688284 0.6439143 ]
k=0 freqs: [nan nan nan nan]
k=1 freqs: [0.03220833 0.08495833 0.23579167 0.6470417 ]
k=2 freqs: [0.05175    0.13675    0.34770834 0.46379167]
k=3 freqs: [0.10051389 0.25186113 0.31643057 0.33119443]
k=4 freqs: [0.25 0.25 0.25 0.25]
```

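(As expected, the sample frequencies become equal by the time we get to k == logits.shape[axis], since every category is then drawn exactly once per sample. The nan values at k=0 arise because no outcomes are drawn, so the normalization in get_freqs divides zero by zero.)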
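The uniform frequencies at full k can be checked directly: when k equals the number of categories, each sample is a permutation of all category indices. The following is a small sketch under that reading; the batch size of 1000 and the variable names are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax, random

key = random.key(0)
logits = jnp.arange(4, dtype=jnp.float32)
n = logits.shape[0]

# 1000 independent draws of a full-length (k == n) sample.
noise = random.gumbel(key, (1000, n), logits.dtype)
_, outcomes = lax.top_k(logits + noise, n)  # shape (1000, n)

# Every row contains each category exactly once, in some order.
is_permutation = jnp.all(jnp.sort(outcomes, axis=-1) == jnp.arange(n))

# Hence each category accounts for exactly 1/n of all drawn indices.
freqs = jnp.bincount(outcomes.flatten(), length=n) / outcomes.size
```

The ordering within each row still depends on the logits; only the pooled frequencies are uniform.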

@jakevdp
Collaborator

jakevdp commented Jan 13, 2025

Can you say more about why this needs its own function?

@carlosgmartin
Contributor Author

carlosgmartin commented Jan 14, 2025

The output shape of sampling with replacement is delete(logits.shape, axis).

The output shape of sampling without replacement would be insert(delete(logits.shape, axis), axis2, k), where k is the length of the sequential sampling-without-replacement process and axis2 is some axis (possibly axis).

Strictly speaking, it need not be its own function, but the argument signature as well as the documentation describing the output shape of random.categorical would need to be modified. Perhaps there's a risk of making the function too convoluted by overloading it with the task of doing two different things.
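To make the shape difference concrete, here is a small sketch comparing the two cases. It assumes axis2 == axis for the without-replacement output, which is only one of the possible choices mentioned above; the example shape (5, 7, 3) is arbitrary.

```python
import jax.numpy as jnp
from jax import lax, random

key = random.key(0)
logits = jnp.zeros((5, 7, 3))
axis, k = 1, 4

# With replacement: the category axis is removed.
with_replacement = random.categorical(key, logits, axis=axis)
# with_replacement.shape == (5, 3)

# Without replacement (Gumbel-top-k): the category axis of length 7
# is replaced by a sample axis of length k.
noise = random.gumbel(key, logits.shape, logits.dtype)
perturbed = jnp.moveaxis(logits + noise, axis, -1)
_, indices = lax.top_k(perturbed, k)                   # shape (5, 3, 4)
without_replacement = jnp.moveaxis(indices, -1, axis)  # shape (5, 4, 3)
```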
