Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support empty batches #530

Closed
wants to merge 11 commits into from
Closed

Support empty batches #530

wants to merge 11 commits into from

Conversation

ffuuugor
Copy link
Contributor

Background

Poisson sampling can sometimes result in an empty input batch, especially if a sampling rate (i.e. expected batch size) is small. This is not out of the ordinary and should be handled accordingly - gradients (signal) should be set to 0 and noise should still be added.

We've made an attempt to support this behaviour, but it wasn't fully covered with tests and got broken over time. As a result, at the moment we have a DataLoader that is capable of producing zero-sized batches, GradSampleModule that only partially supports them and DPOptimizer that doesn't support them at all

This PR addresses Issue #522 (thanks @xichens for reporting)

Improvements

This diff fixes the following

  • DPOptimizer can now handle empty batches
  • BatchMemoryManager can now handle empty batches
  • Adds a PrivacyEngine test with empty batches
  • Adds BatchMemoryManager test with empty batches
  • DataLoader now respects dtype of the inputs (i.e. empty batches only used to work with float input tensors)
  • ExpandedWeights still can's process empty batches, which we call out in our readme (FYI @samdow )

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 25, 2022
@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Copy link
Contributor

@alexandresablayrolles alexandresablayrolles left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall but left a few comments.

x: any object

Returns:
``x.shape`` if attribute exists, empty tuple otherwise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong docstring?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True

@@ -144,6 +162,7 @@ def __init__(
generator=generator,
)
sample_empty_shapes = [[0, *shape_safe(x)] for x in dataset[0]]
dtypes = [dtype_safe(x) for x in dataset[0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is consumed by wrap_collate_with_empty, we need the dtype to be an actual dtype and not a type right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.zeros support normal python types (int/float/bool) just fine:

> torch.zeros((2,2), dtype=int)
tensor([[0, 0],
        [0, 0]])

> torch.zeros((2,2), dtype=float)  
tensor([[0., 0.],
        [0., 0.]], dtype=torch.float64)

> torch.zeros((2,2), dtype=bool)   
tensor([[False, False],
        [False, False]])

Copy link

@joserapa98 joserapa98 Oct 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi there! I've been following the discussion since I ran into the empty batches problem when using poisson sampling with small sampling rate. When I tested this implementation with my own project, it didn't work for me, since it does not support more complex labels. If labels are just numbers (int, float, bool), it's ok, but if you have maybe a label formed by a tuple of numbers, its type will be tuple, thus causing an error in torch.zeros.

I don't know if you expect this kind of things to be supported, but as they work in standard PyTorch (and in fact in Opacus when batches are not empty), maybe it is something worth to be aware of.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an interesting point, thanks

Just thinking out loud, what would be a good way to support this? In case the original label is a tuple, how does collate function handles it - would it output multiple tensors per label?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a dataset in which each sample is of type, say, tuple(torch.Tensor, tuple(int, int)), the dataloader would return tuple(torch.Tensor, tuple(torch.Tensor, torch.Tensor)), where now each tensor has an extra dimension for the batch. Something similar would happen if labels are given as lists, dicts, etc.

This is the code snippet that manages this cases:

if isinstance(elem, collections.abc.Mapping):
    try:
        return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
    except TypeError:
        # The mapping type may not support `__init__(iterable)`.
        return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
    return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.


    if isinstance(elem, tuple):
        return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
    else:
        try:
            return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]

https://github.com/pytorch/pytorch/blob/8349bf1cd1d5df7be73b194940bcf96209159f40/torch/utils/data/_utils/collate.py#L126-L149

I guess a proper solution for supporting empty batches would be to recycle this code but returning tuples/dicts/lists/... of empty tensors with torch.zeros, so that the types are still preserved, though filled with empty tensors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes perfect sense, thanks! You're absolutely right this is the right way forward.

However, I don't see this as a blocker for landing this PR. This PR does solve the problem for the subset of cases and could be considered an atomic improvement. Not to delay merging it, I created an issue to track the proposed improvement (#534), hopefully someone will pick it up soon

ret[layer.weight] = torch.zeros_like(layer.weight).unsqueeze(0)
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = torch.zeros_like(layer.bias).unsqueeze(0)
return ret
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does conv need a special treatment? What happens with other layers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's down to specific grad sampler implementation.
Most grad samplers we have rely on einsums, which are generally pretty good with handling 0-sized vectors.
With conv in particular the culprit is the following line:

backprops = backprops.reshape(n, -1, activations.shape[-1])

That said, it is a good point that we want to be sure all of the layers can handle it - which current tests only partially do. We have a PrivacyEngine test for an empty batch, but nothing on the grad sampler level - I'll check if there's an easy way to do it and update the PR

@facebook-github-bot
Copy link
Contributor

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -29,7 +29,9 @@


def wrap_collate_with_empty(
collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
collate_fn: Optional[_collate_fn_t],
sample_empty_shapes: Sequence[torch.Size],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it have to be torch.Size? How about List or Tuples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question - I'm not aware of what the best practices are.
List os probably out of the question - it's mutable and can never be returned from shape-related methods.

torch.Size vs tuple is more interesting. Obv, passing tuple will work (after all, torch.Size is a tuple subclass with few bells and whistles), but torch will always return Size.

Given suggestive nature of typehints, does it make sense to keep torch.Size to indicate it needs to be returned from torch.Tensor.shape?

@@ -29,7 +29,9 @@


def wrap_collate_with_empty(
collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
collate_fn: Optional[_collate_fn_t],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
collate_fn: Optional[_collate_fn_t],
*,
collate_fn: Optional[_collate_fn_t],

collate_fn: Optional[_collate_fn_t], sample_empty_shapes: Sequence
collate_fn: Optional[_collate_fn_t],
sample_empty_shapes: Sequence[torch.Size],
dtypes: Sequence[torch.dtype],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, it need not just be torch.dtype? Depending on what L85 returns.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

psolikov pushed a commit to psolikov/opacus that referenced this pull request Nov 1, 2022
Copy link
Contributor

@alexandresablayrolles alexandresablayrolles left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Igor Shilov added 2 commits November 4, 2022 15:18
@facebook-github-bot
Copy link
Contributor

@ffuuugor has updated the pull request. You must reimport the pull request before landing.

@facebook-github-bot
Copy link
Contributor

@ffuuugor has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@karthikprasad karthikprasad deleted the ffuuugor_522 branch December 13, 2022 15:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants