Skip to content

Commit fd20954

Browse files
kurtamohler authored and pytorchmergebot committed
Add torch.utils.deterministic.fill_uninitialized_memory flag (pytorch#111377)
Part of pytorch#109802 Pull Request resolved: pytorch#111377 Approved by: https://github.com/albanD, https://github.com/aaronenyeshi
1 parent cce5016 commit fd20954

21 files changed

+193
-46
lines changed

aten/src/ATen/Context.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
6565
_deterministic_algorithms_warn_only = warn_only;
6666
}
6767

68+
bool Context::deterministicFillUninitializedMemory() const {
69+
return _deterministic_fill_uninitialized_memory;
70+
}
71+
72+
void Context::setDeterministicFillUninitializedMemory(bool b) {
73+
_deterministic_fill_uninitialized_memory = b;
74+
}
75+
6876
void Context::alertNotDeterministic(c10::string_view const& caller) {
6977
if (globalContext().deterministicAlgorithms()) {
7078
if (globalContext().deterministicAlgorithmsWarnOnly()) {

aten/src/ATen/Context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ class TORCH_API Context {
205205
bool deterministicAlgorithms() const;
206206
bool deterministicAlgorithmsWarnOnly() const;
207207
void setDeterministicAlgorithms(bool, bool);
208+
bool deterministicFillUninitializedMemory() const;
209+
void setDeterministicFillUninitializedMemory(bool);
208210

209211
// Note [Writing Nondeterministic Operations]
210212
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -301,6 +303,7 @@ class TORCH_API Context {
301303
bool deterministic_cudnn = false;
302304
bool _deterministic_algorithms = false;
303305
bool _deterministic_algorithms_warn_only = false;
306+
bool _deterministic_fill_uninitialized_memory = true;
304307
bool enabled_flashSDP = true;
305308
bool enabled_mem_efficientSDP = true;
306309
bool enabled_mathSDP = true;

aten/src/ATen/mps/EmptyTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ TensorBase empty_mps(
6464
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
6565
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
6666
// See Note [Enabling Deterministic Operations]
67-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
67+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
6868
at::native::fill_empty_deterministic_(tensor);
6969
}
7070
return tensor;
@@ -107,7 +107,7 @@ TensorBase empty_strided_mps(
107107
Tensor result = at::detail::empty_strided_generic(
108108
size, stride, allocator, mps_dks, dtype);
109109
// See Note [Enabling Deterministic Operations]
110-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
110+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
111111
at::native::fill_empty_deterministic_(result);
112112
}
113113
return result;

aten/src/ATen/native/Resize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ const Tensor& _resize_(
252252
self_->empty_tensor_restride(memory_format);
253253
}
254254
// See Note [Enabling Deterministic Operations]
255-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
255+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
256256
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
257257
}
258258
return self;

aten/src/ATen/native/TensorFactories.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::opt
255255
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
256256
Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
257257
// See Note [Enabling Deterministic Operations]
258-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
258+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
259259
fill_empty_deterministic_(result);
260260
}
261261
return result;
@@ -327,7 +327,7 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<Sca
327327
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
328328
Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
329329
// See Note [Enabling Deterministic Operations]
330-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
330+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
331331
fill_empty_deterministic_(result);
332332
}
333333
return result;
@@ -348,7 +348,7 @@ Tensor& empty_out(IntArrayRef size,
348348
result.resize_(size);
349349
}
350350
// See Note [Enabling Deterministic Operations]
351-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
351+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
352352
fill_empty_deterministic_(result);
353353
}
354354
return result;

aten/src/ATen/native/cuda/Resize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ const Tensor& resize_cuda_(
6565
self_->empty_tensor_restride(memory_format);
6666
}
6767
// See Note [Enabling Deterministic Operations]
68-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
68+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
6969
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
7070
}
7171
return self;

aten/src/ATen/native/cuda/TensorFactories.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
5454
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
5555
Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
5656
// See Note [Enabling Deterministic Operations]
57-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
57+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
5858
fill_empty_deterministic_(result);
5959
}
6060
return result;
@@ -80,7 +80,7 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
8080
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
8181
Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
8282
// See Note [Enabling Deterministic Operations]
83-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
83+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
8484
fill_empty_deterministic_(result);
8585
}
8686
return result;

aten/src/ATen/native/mps/TensorFactory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ const Tensor& resize_mps_(
120120
self_->empty_tensor_restride(memory_format);
121121
}
122122
// See Note [Enabling Deterministic Operations]
123-
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
123+
if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
124124
at::native::fill_resize_deterministic_(self, old_storage_nbytes);
125125
}
126126
return self;

docs/source/amp.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,4 +392,4 @@ regardless of whether autocast is enabled.
392392
.. py:module:: torch.cpu.amp.autocast_mode
393393
.. py:module:: torch.cuda.amp.autocast_mode
394394
.. py:module:: torch.cuda.amp.common
395-
.. py:module:: torch.cuda.amp.grad_scaler
395+
.. py:module:: torch.cuda.amp.grad_scaler

docs/source/deterministic.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
torch.utils.deterministic
2+
=========================
3+
.. py:module:: torch.utils.deterministic
4+
.. currentmodule:: torch.utils.deterministic
5+
6+
.. attribute:: fill_uninitialized_memory
7+
8+
A :class:`bool` that, if True, causes uninitialized memory to be filled with
9+
a known value when :meth:`torch.use_deterministic_algorithms()` is set to
10+
``True``. Floating point and complex values are set to NaN, and integer
11+
values are set to the maximum value.
12+
13+
Default: ``True``
14+
15+
Filling uninitialized memory is detrimental to performance. So if your
16+
program is valid and does not use uninitialized memory as the input to an
17+
operation, then this setting can be turned off for better performance and
18+
still be deterministic.
19+
20+
The following operations will fill uninitialized memory when this setting is
21+
turned on:
22+
23+
* :func:`torch.Tensor.resize_` when called with a tensor that is not
24+
quantized
25+
* :func:`torch.empty`
26+
* :func:`torch.empty_strided`
27+
* :func:`torch.empty_permuted`
28+
* :func:`torch.empty_like`

0 commit comments

Comments (0)