Subsampling with replacement #471
base: main
# Conflicts:
#   examples/notebooks/TabPFN_Demo_Local.ipynb
#   src/tabpfn/preprocessing.py
Pull Request Overview
This PR enhances the generate_index_permutations function to support subsampling with replacement, providing more flexibility for data sampling strategies.
- Adds a with_replacement parameter to enable sampling indices multiple times
- Refactors parameter validation logic for better clarity and consistency
- Updates documentation to reflect the new functionality and parameter changes
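As a rough illustration of the two sampling modes the overview refers to, the following sketch contrasts them using NumPy's `Generator` directly. It does not call the project's function; the seed and sizes are arbitrary values chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility
max_index = 5

# Without replacement: take a prefix of a shuffled range, so each index
# appears at most once and the sample cannot exceed the population.
without = rng.permutation(max_index)[:3]

# With replacement: indices may repeat, so the sample size may be larger
# than the population (8 draws from 5 indices must contain duplicates).
with_repl = rng.choice(max_index, size=8, replace=True)

print(without, with_repl)
```

With replacement, the requested sample size is no longer capped by `max_index`, which is the extra flexibility the PR description mentions.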
 
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces subsampling with replacement by adding a with_replacement flag to the generate_index_permutations function. The implementation is well-structured and correctly uses numpy.random.Generator.choice for sampling with replacement. The refactoring also improves the existing logic for determining the sample size and provides clearer validation and error messages.
My main feedback is to address an edge case where max_index is 0, which can cause a ValueError when sampling with replacement. Adding a check at the beginning of the function will make it more robust.
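The edge case can be reproduced outside the project with plain NumPy. This is a minimal sketch that assumes only that `rng` is a `numpy.random.Generator`, as the review states; the seed is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

# Without replacement: slicing a permutation of an empty range simply
# yields an empty array.
empty = rng.permutation(0)[:1]
print(empty.size)

# With replacement: drawing one item from an empty population raises.
try:
    rng.choice(0, size=1, replace=True)
    raised = False
except ValueError:
    raised = True
print(raised)
```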
     _, rng = infer_random_state(random_state)
-    if isinstance(subsample, int):
-        if subsample < 1:
-            raise ValueError(f"{subsample=} must be larger than 1 if int")
-        subsample = min(subsample, max_index)
-
-        return [rng.permutation(max_index)[:subsample] for _ in range(n)]
-
+    # Determine the number of items to subsample (k)
     if isinstance(subsample, float):
-        if not (0 < subsample < 1):
-            raise ValueError(f"{subsample=} must be in (0, 1) if float")
-        subsample = int(subsample * max_index) + 1
-        return [rng.permutation(max_index)[:subsample] for _ in range(n)]
-
-    raise ValueError(f"{subsample=} must be int or float.")
+        if not (0.0 < subsample <= 1.0):
+            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
+        # Ensure at least one sample is drawn
+        k = max(1, int(subsample * max_index))
+    elif isinstance(subsample, int):
+        if subsample < 1:
+            raise ValueError(f"If int, {subsample=} must be at least 1.")
+        k = subsample
+    else:
+        raise TypeError(f"{subsample=} must be an int or float.")
+
+    # Generate n lists of indices based on the replacement strategy
+    if with_replacement:
+        # Sample with replacement. The sample size `k` can be larger than `max_index`.
+        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
+    # Sample without replacement. The sample size cannot exceed the population size.
+    sample_size = min(k, max_index)
+    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]
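Besides adding the replacement branch, the diff changes how a float `subsample` is turned into a sample size. The sketch below isolates the two formulas for comparison; `old_k` and `new_k` are hypothetical helper names introduced here, not functions from the project.

```python
def old_k(subsample: float, max_index: int) -> int:
    # Formula from the removed lines: float must lie in the open interval (0, 1).
    if not (0 < subsample < 1):
        raise ValueError(f"{subsample=} must be in (0, 1) if float")
    return int(subsample * max_index) + 1


def new_k(subsample: float, max_index: int) -> int:
    # Formula from the added lines: 1.0 is now allowed, and the result is
    # clamped to at least one sample instead of always adding one.
    if not (0.0 < subsample <= 1.0):
        raise ValueError(f"If float, {subsample=} must be in (0, 1].")
    return max(1, int(subsample * max_index))


print(old_k(0.5, 10), new_k(0.5, 10))
print(old_k(0.05, 10), new_k(0.05, 10))
print(new_k(1.0, 10))
```

For fractions whose product with `max_index` is at least 1, the new formula returns one sample fewer than the old one, so callers relying on the previous size may see a difference.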
When max_index is 0, k is calculated as at least 1. This causes rng.choice(0, size=k, ...) to raise a ValueError when sampling with replacement, as it's impossible to sample from an empty population. The behavior is also inconsistent with the with_replacement=False path, which correctly returns an empty array.
To handle this edge case gracefully and ensure consistent behavior, it's best to add a check for max_index == 0 at the beginning of the function's logic. It would also be good practice to validate that max_index is not negative.
    if max_index < 0:
        raise ValueError(f"max_index must be non-negative, but got {max_index}")
    if max_index == 0:
        return [np.array([], dtype=np.int64) for _ in range(n)]
    _, rng = infer_random_state(random_state)
    # Determine the number of items to subsample (k)
    if isinstance(subsample, float):
        if not (0.0 < subsample <= 1.0):
            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
        # Ensure at least one sample is drawn
        k = max(1, int(subsample * max_index))
    elif isinstance(subsample, int):
        if subsample < 1:
            raise ValueError(f"If int, {subsample=} must be at least 1.")
        k = subsample
    else:
        raise TypeError(f"{subsample=} must be an int or float.")
    # Generate n lists of indices based on the replacement strategy
    if with_replacement:
        # Sample with replacement. The sample size `k` can be larger than `max_index`.
        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
    # Sample without replacement. The sample size cannot exceed the population size.
    sample_size = min(k, max_index)
    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]
Co-authored-by: Copilot <[email protected]>
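The suggested body above is a fragment, so here is a self-contained sketch that wraps it in a function to make the edge cases easy to try. The function name, its signature, and the `seed` argument are assumptions made for this example; in particular, `np.random.default_rng(seed)` stands in for the project's `infer_random_state`, whose behavior is not shown in this thread.

```python
from __future__ import annotations

import numpy as np


def generate_index_permutations_sketch(
    n: int,
    *,
    max_index: int,
    subsample: int | float,
    with_replacement: bool = False,
    seed: int | None = None,  # assumption: replaces `random_state` handling
) -> list[np.ndarray]:
    if max_index < 0:
        raise ValueError(f"max_index must be non-negative, but got {max_index}")
    if max_index == 0:
        # Empty population: return empty index arrays for both strategies.
        return [np.array([], dtype=np.int64) for _ in range(n)]

    rng = np.random.default_rng(seed)

    # Determine the number of items to subsample (k)
    if isinstance(subsample, float):
        if not (0.0 < subsample <= 1.0):
            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
        k = max(1, int(subsample * max_index))
    elif isinstance(subsample, int):
        if subsample < 1:
            raise ValueError(f"If int, {subsample=} must be at least 1.")
        k = subsample
    else:
        raise TypeError(f"{subsample=} must be an int or float.")

    if with_replacement:
        # k may exceed max_index here.
        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
    # Without replacement, cap the sample at the population size.
    sample_size = min(k, max_index)
    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]


# Usage: an empty population no longer raises when sampling with replacement.
print(generate_index_permutations_sketch(
    2, max_index=0, subsample=3, with_replacement=True
))
```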
…to subsampling-with-replacement
Motivation and Context
Public API Changes
How Has This Been Tested?
Checklist
CHANGELOG.md (if relevant for users).