@noahho (Collaborator) commented Aug 27, 2025

Motivation and Context


Public API Changes

  • No Public API changes
  • Yes, Public API changes (Details below)

How Has This Been Tested?


Checklist

  • The changes have been tested locally.
  • Documentation has been updated (if the public API or usage changes).
  • An entry has been added to CHANGELOG.md (if relevant for users).
  • The code follows the project's style guidelines.
  • I have considered the impact of these changes on the public API.

noahho and others added 30 commits July 15, 2025 16:32
# Conflicts:
#	examples/notebooks/TabPFN_Demo_Local.ipynb
#	src/tabpfn/preprocessing.py
Copilot AI review requested due to automatic review settings August 27, 2025 09:24
Contributor

Copilot AI left a comment

Pull Request Overview

This PR enhances the generate_index_permutations function to support subsampling with replacement, providing more flexibility for data sampling strategies.

  • Adds a with_replacement parameter to enable sampling indices multiple times
  • Refactors parameter validation logic for better clarity and consistency
  • Updates documentation to reflect the new functionality and parameter changes
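
The difference between the two sampling strategies named in the first bullet can be sketched with NumPy's `Generator`. The helper name `sample_indices`, its signature, and the fixed seed are hypothetical and do not come from the PR; only the `rng.choice(..., replace=True)` and `rng.permutation(...)[:k]` calls mirror what appears in the reviewed diff.

```python
import numpy as np


def sample_indices(max_index, k, *, with_replacement, seed=0):
    """Hypothetical helper: draw k indices from range(max_index)."""
    rng = np.random.default_rng(seed)
    if with_replacement:
        # Indices may repeat, so k is allowed to exceed max_index.
        return rng.choice(max_index, size=k, replace=True)
    # Without replacement each index appears at most once,
    # so the result is capped at max_index elements.
    return rng.permutation(max_index)[: min(k, max_index)]


with_rep = sample_indices(5, 8, with_replacement=True)
without_rep = sample_indices(5, 8, with_replacement=False)
```

Here `with_rep` holds 8 indices drawn from `range(5)`, so some must repeat, while `without_rep` is a permutation of all 5 indices.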

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces subsampling with replacement by adding a with_replacement flag to the generate_index_permutations function. The implementation is well-structured and correctly uses numpy.random.Generator.choice for sampling with replacement. The refactoring also improves the existing logic for determining the sample size and provides clearer validation and error messages.
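
Judging only from the diff lines quoted in this review, the refactored size logic also appears to change how many indices a float `subsample` yields: the removed code used `int(subsample * max_index) + 1`, while the new code uses `max(1, int(subsample * max_index))`. The sketch below compares the two expressions; the function names are hypothetical and exist only for illustration.

```python
def old_float_size(subsample, max_index):
    # Expression from the removed lines of the diff.
    return int(subsample * max_index) + 1


def new_float_size(subsample, max_index):
    # Expression from the added lines: at least one sample, no unconditional +1.
    return max(1, int(subsample * max_index))


old_half = old_float_size(0.5, 10)
new_half = new_float_size(0.5, 10)
old_tiny = old_float_size(0.05, 10)
new_tiny = new_float_size(0.05, 10)
```

For `subsample=0.5` and `max_index=10` the old expression gives 6 and the new one gives 5; for very small fractions both give 1.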

My main feedback is to address an edge case where max_index is 0, which can cause a ValueError when sampling with replacement. Adding a check at the beginning of the function will make it more robust.

Comment on lines 289 to 310
         _, rng = infer_random_state(random_state)
    -    if isinstance(subsample, int):
    -        if subsample < 1:
    -            raise ValueError(f"{subsample=} must be larger than 1 if int")
    -        subsample = min(subsample, max_index)
    -
    -        return [rng.permutation(max_index)[:subsample] for _ in range(n)]

    +    # Determine the number of items to subsample (k)
         if isinstance(subsample, float):
    -        if not (0 < subsample < 1):
    -            raise ValueError(f"{subsample=} must be in (0, 1) if float")
    -        subsample = int(subsample * max_index) + 1
    -        return [rng.permutation(max_index)[:subsample] for _ in range(n)]
    -
    -    raise ValueError(f"{subsample=} must be int or float.")
    +        if not (0.0 < subsample <= 1.0):
    +            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
    +        # Ensure at least one sample is drawn
    +        k = max(1, int(subsample * max_index))
    +    elif isinstance(subsample, int):
    +        if subsample < 1:
    +            raise ValueError(f"If int, {subsample=} must be at least 1.")
    +        k = subsample
    +    else:
    +        raise TypeError(f"{subsample=} must be an int or float.")
    +
    +    # Generate n lists of indices based on the replacement strategy
    +    if with_replacement:
    +        # Sample with replacement. The sample size `k` can be larger than `max_index`.
    +        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
    +    # Sample without replacement. The sample size cannot exceed the population size.
    +    sample_size = min(k, max_index)
    +    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]

Contributor

high

When max_index is 0, k is calculated as at least 1. This causes rng.choice(0, size=k, ...) to raise a ValueError when sampling with replacement, as it's impossible to sample from an empty population. The behavior is also inconsistent with the with_replacement=False path, which correctly returns an empty array.
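
The inconsistency described above can be reproduced in isolation with NumPy alone. This is a minimal sketch that uses `np.random.default_rng` directly rather than the project's `infer_random_state`, which is an assumption made to keep the example self-contained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Without replacement: permuting an empty range and slicing yields an empty array.
empty = rng.permutation(0)[:1]

# With replacement: requesting even one sample from an empty population fails.
try:
    rng.choice(0, size=1, replace=True)
    raised = False
except ValueError:
    raised = True
```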

To handle this edge case gracefully and ensure consistent behavior, it's best to add a check for max_index == 0 at the beginning of the function's logic. It would also be good practice to validate that max_index is not negative.

    if max_index < 0:
        raise ValueError(f"max_index must be non-negative, but got {max_index}")
    if max_index == 0:
        return [np.array([], dtype=np.int64) for _ in range(n)]

    _, rng = infer_random_state(random_state)

    # Determine the number of items to subsample (k)
    if isinstance(subsample, float):
        if not (0.0 < subsample <= 1.0):
            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
        # Ensure at least one sample is drawn
        k = max(1, int(subsample * max_index))
    elif isinstance(subsample, int):
        if subsample < 1:
            raise ValueError(f"If int, {subsample=} must be at least 1.")
        k = subsample
    else:
        raise TypeError(f"{subsample=} must be an int or float.")

    # Generate n lists of indices based on the replacement strategy
    if with_replacement:
        # Sample with replacement. The sample size `k` can be larger than `max_index`.
        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
    # Sample without replacement. The sample size cannot exceed the population size.
    sample_size = min(k, max_index)
    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]
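
To check the suggested guard end to end, the snippet can be wrapped in a standalone function. The sketch below is a hypothetical stand-in and not the project's code: the name `generate_index_permutations_sketch`, the keyword-only signature, and the `seed` argument are assumptions, and `np.random.default_rng` replaces `infer_random_state`.

```python
import numpy as np


def generate_index_permutations_sketch(
    n, *, max_index, subsample, with_replacement=False, seed=0
):
    """Hypothetical stand-in for the reviewed function, not the project's code."""
    if max_index < 0:
        raise ValueError(f"max_index must be non-negative, but got {max_index}")
    if max_index == 0:
        # Nothing to sample from: return n empty index arrays on both paths.
        return [np.array([], dtype=np.int64) for _ in range(n)]

    # Assumption: default_rng stands in for the project's infer_random_state.
    rng = np.random.default_rng(seed)

    # Determine the number of items to subsample (k).
    if isinstance(subsample, float):
        if not (0.0 < subsample <= 1.0):
            raise ValueError(f"If float, {subsample=} must be in (0, 1].")
        k = max(1, int(subsample * max_index))
    elif isinstance(subsample, int):
        if subsample < 1:
            raise ValueError(f"If int, {subsample=} must be at least 1.")
        k = subsample
    else:
        raise TypeError(f"{subsample=} must be an int or float.")

    if with_replacement:
        # k may exceed max_index because indices can repeat.
        return [rng.choice(max_index, size=k, replace=True) for _ in range(n)]
    # Without replacement the sample is capped at the population size.
    sample_size = min(k, max_index)
    return [rng.permutation(max_index)[:sample_size] for _ in range(n)]


empty = generate_index_permutations_sketch(
    3, max_index=0, subsample=2, with_replacement=True
)
boot = generate_index_permutations_sketch(
    2, max_index=4, subsample=10, with_replacement=True
)
capped = generate_index_permutations_sketch(2, max_index=4, subsample=10)
```

With the guard in place, `max_index=0` returns empty arrays on the with-replacement path instead of raising, matching the without-replacement behavior.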

@LeoGrin LeoGrin requested review from LeoGrin and removed request for LeoGrin October 14, 2025 13:40