Skip to content

Commit d6cbecb

Browse files
swolchok and facebook-github-bot
authored and committed
[PyTorch] Reapply D27404164: Devirtualize is_contiguous (pytorch#55333)
Summary: Pull Request resolved: pytorch#55333 Reapplying without using enum class in a bitfield. See new comments about gcc bug. ghstack-source-id: 125776904 Test Plan: Carefully review OSS test failure logs this time Reviewed By: kimishpatel, bhosmer Differential Revision: D27576623 fbshipit-source-id: 68fb00e5ff5215e56c8b9bc02717e1e7b2fedf9b
1 parent e359842 commit d6cbecb

9 files changed

+83
-32
lines changed

Diff for: aten/src/ATen/BatchedTensorImpl.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, BatchDims bdims)
1616
{
1717
TORCH_INTERNAL_ASSERT(value_.defined());
1818
set_storage_access_should_throw();
19+
set_has_contiguity_policy(HasContiguityPolicy::CustomBehavior);
1920
checkInvariants();
2021

2122
const auto public_dims = value_.dim() - bdims_.size();
@@ -74,7 +75,7 @@ void BatchedTensorImpl::checkInvariants() const {
7475
}
7576

7677
// The following are publically exposed as methods of Tensor
77-
bool BatchedTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
78+
bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
7879
TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
7980
"NYI: querying is_contiguous inside of vmap for memory_format ",
8081
"other than torch.contiguous_format");

Diff for: aten/src/ATen/BatchedTensorImpl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
7373
int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
7474

7575
// Override a bunch of methods inherited from TensorImpl to return error messages.
76-
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
76+
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
7777
void set_size(int64_t dim, int64_t new_size) override;
7878
void set_stride(int64_t dim, int64_t new_stride) override;
7979
void set_storage_offset(int64_t storage_offset) override;

Diff for: aten/src/ATen/OpaqueTensorImpl.h

+1-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
2929
: TensorImpl(key_set, data_type, device),
3030
opaque_handle_(std::move(opaque_handle)) {
3131
set_storage_access_should_throw();
32+
set_has_contiguity_policy(HasContiguityPolicy::ContiguityNotSupported);
3233
sizes_and_strides_.set_sizes(sizes);
3334
refresh_numel();
3435
is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
@@ -43,12 +44,6 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl {
4344
AT_ERROR("opaque tensors do not have strides");
4445
}
4546

46-
bool is_contiguous(
47-
c10::MemoryFormat memory_format =
48-
c10::MemoryFormat::Contiguous) const override {
49-
AT_ERROR("opaque tensors do not have is_contiguous");
50-
}
51-
5247
int64_t stride(int64_t d) const override {
5348
AT_ERROR("opaque tensors do not have strides");
5449
}

Diff for: aten/src/ATen/SparseTensorImpl.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::Typ
5151

5252
is_non_overlapping_and_dense_ = false;
5353
set_storage_access_should_throw();
54+
set_has_contiguity_policy(HasContiguityPolicy::ContiguityNotSupported);
5455
}
5556

5657
void SparseTensorImpl::release_resources() {
@@ -62,9 +63,6 @@ void SparseTensorImpl::release_resources() {
6263
IntArrayRef SparseTensorImpl::strides() const {
6364
AT_ERROR("sparse tensors do not have strides");
6465
}
65-
bool SparseTensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
66-
AT_ERROR("sparse tensors do not have is_contiguous");
67-
}
6866
int64_t SparseTensorImpl::stride(int64_t d) const {
6967
AT_ERROR("sparse tensors do not have strides");
7068
}

Diff for: aten/src/ATen/SparseTensorImpl.h

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ struct TORCH_API SparseTensorImpl : public TensorImpl {
4343
Tensor values() const { return values_; }
4444

4545
IntArrayRef strides() const override;
46-
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
4746
int64_t stride(int64_t d) const override;
4847
void set_size(int64_t dim, int64_t new_size) override;
4948
void set_stride(int64_t dim, int64_t new_stride) override;

Diff for: aten/src/ATen/native/metal/MetalTensorImpl.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
2222
device,
2323
opaque_handle,
2424
sizes),
25-
strides_(strides.vec()) {}
25+
strides_(strides.vec()) {
26+
TensorImpl::set_has_contiguity_policy(TensorImpl::HasContiguityPolicy::CustomBehavior);
27+
}
2628

2729
IntArrayRef strides() const override {
2830
return strides_;
2931
}
3032

31-
bool is_contiguous(
32-
c10::MemoryFormat memory_format =
33-
c10::MemoryFormat::Contiguous) const override {
33+
bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
3434
return true;
3535
}
3636

Diff for: aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
2323
opaque_handle,
2424
sizes,
2525
false),
26-
strides_(strides.vec()) {}
26+
strides_(strides.vec()) {
27+
TensorImpl::set_has_contiguity_policy(TensorImpl::HasContiguityPolicy::CustomBehavior);
28+
}
2729

2830
IntArrayRef strides() const override {
2931
return strides_;
3032
}
3133

32-
bool is_contiguous(
33-
c10::MemoryFormat memory_format =
34-
c10::MemoryFormat::Contiguous) const override {
34+
bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
3535
return true;
3636
}
3737

Diff for: c10/core/TensorImpl.cpp

+16-10
Original file line numberDiff line numberDiff line change
@@ -276,17 +276,22 @@ void TensorImpl::throw_storage_access_error() const {
276276
TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot access storage of ", tensorimpl_type_name());
277277
}
278278

279-
bool TensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
280-
#ifdef DEBUG
281-
AT_ASSERT(compute_contiguous() == is_contiguous_);
282-
#endif
283-
if (memory_format == at::MemoryFormat::ChannelsLast) {
284-
return is_channels_last_contiguous_;
285-
}
286-
else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
287-
return is_channels_last_3d_contiguous_;
279+
bool TensorImpl::is_contiguous_nondefault_policy_impl(at::MemoryFormat memory_format) const {
280+
if (has_contiguity_ == static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
281+
TORCH_CHECK_NOT_IMPLEMENTED(
282+
false, "Tensors of type ", tensorimpl_type_name(),
283+
" do not have is_contiguous");
284+
} else {
285+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(has_contiguity_ == static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
286+
return is_contiguous_custom(memory_format);
288287
}
289-
return is_contiguous_;
288+
}
289+
290+
bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
291+
TORCH_INTERNAL_ASSERT(
292+
false,
293+
"TensorImpl::is_contiguous_custom should never be called; did you "
294+
"set_has_contiguity_policy and forget to override is_contiguous_custom?");
290295
}
291296

292297
static void deletePlacementDeleteContext(void* ptr) {
@@ -381,6 +386,7 @@ void TensorImpl::copy_tensor_metadata_except_version_counter(
381386
dest_impl->device_opt_ = src_impl->device_opt_;
382387
dest_impl->key_set_ = src_impl->key_set_;
383388
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
389+
dest_impl->has_contiguity_ = src_impl->has_contiguity_;
384390
dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
385391
dest_impl->is_channels_last_3d_contiguous_ = src_impl->is_channels_last_3d_contiguous_;
386392
dest_impl->is_channels_last_ = src_impl->is_channels_last_;

Diff for: c10/core/TensorImpl.h

+54-2
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,37 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
483483
* Tensors with non-trivial strides are not contiguous. See
484484
* compute_contiguous() for the exact definition of whether or not
485485
* a tensor is contiguous or not.
486+
*
487+
* NOTE: is_contiguous is only `TENSORIMPL_MAYBE_VIRTUAL` for
488+
* backward compatibility. See `set_has_contiguity_policy` and
489+
* `is_contiguous_custom` for the encouraged customization point.
486490
*/
487-
virtual bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const;
491+
TENSORIMPL_MAYBE_VIRTUAL bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
492+
if (C10_UNLIKELY(has_contiguity_ != static_cast<uint8_t>(HasContiguityPolicy::Default))) {
493+
return is_contiguous_nondefault_policy_impl(memory_format);
494+
}
495+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(compute_contiguous() == is_contiguous_);
496+
if (memory_format == at::MemoryFormat::ChannelsLast) {
497+
return is_channels_last_contiguous_;
498+
}
499+
else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
500+
return is_channels_last_3d_contiguous_;
501+
}
502+
return is_contiguous_;
503+
}
488504

505+
private:
506+
bool is_contiguous_nondefault_policy_impl(at::MemoryFormat) const;
507+
508+
protected:
509+
/**
510+
* Customization point for is_contiguous; must also
511+
* set_has_contiguity_policy(HasContiguityPolicy::Custom) for this
512+
* to be called.
513+
*/
514+
virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
515+
516+
public:
489517
bool is_sparse() const {
490518
// NB: This method is not virtual and avoid dispatches for performance reasons.
491519
return key_set_.has(DispatchKey::SparseCPU) ||
@@ -1725,6 +1753,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
17251753
}
17261754

17271755
protected:
1756+
// Policy for adjusting the behavior of is_contiguous(). Allows
1757+
// subclass customization while still being able to inline
1758+
// is_contiguous() in the common case.
1759+
enum class HasContiguityPolicy : uint8_t {
1760+
// Default behavior: check is_contiguous_ and similar bitflags.
1761+
Default,
1762+
// Throw a generic error message that this tensor type does not
1763+
// support is_contiguous.
1764+
ContiguityNotSupported,
1765+
// Call virtual is_contiguous_custom method to implement custom
1766+
// is_contiguous behavior.
1767+
CustomBehavior,
1768+
};
1769+
1770+
void set_has_contiguity_policy(HasContiguityPolicy p) {
1771+
has_contiguity_ = static_cast<uint8_t>(p);
1772+
}
1773+
17281774
Storage storage_;
17291775

17301776
private:
@@ -1801,13 +1847,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
18011847
c10::optional<c10::Device> device_opt_;
18021848

18031849
// Tensor is contiguous
1804-
bool is_contiguous_ = true;
1850+
bool is_contiguous_ : 1;
1851+
// gcc doesn't like enum class bitfields; see
1852+
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61414
1853+
/* HasContiguityPolicy */ uint8_t has_contiguity_ : 2;
18051854

18061855
// Tensor is a subclass that does not permit storage access.
18071856
bool storage_access_should_throw_ = false;
18081857

18091858
// default member initializers for bit-fields only available with -std=c++2a or -std=gnu++2a
18101859
inline void init_bitfields() {
1860+
is_contiguous_ = true;
1861+
has_contiguity_ = static_cast<uint8_t>(HasContiguityPolicy::Default);
1862+
18111863
is_channels_last_ = false;
18121864
is_channels_last_contiguous_ = false;
18131865
is_channels_last_3d_ = false;

0 commit comments

Comments
 (0)