
Commit fd20954

kurtamohler authored and pytorchmergebot committed
Add torch.utils.deterministic.fill_uninitialized_memory flag (pytorch#111377)
Part of pytorch#109802

Pull Request resolved: pytorch#111377
Approved by: https://github.com/albanD, https://github.com/aaronenyeshi
1 parent cce5016 commit fd20954

21 files changed: +193 -46 lines

aten/src/ATen/Context.cpp (+8)

@@ -65,6 +65,14 @@ void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) {
   _deterministic_algorithms_warn_only = warn_only;
 }
 
+bool Context::deterministicFillUninitializedMemory() const {
+  return _deterministic_fill_uninitialized_memory;
+}
+
+void Context::setDeterministicFillUninitializedMemory(bool b) {
+  _deterministic_fill_uninitialized_memory = b;
+}
+
 void Context::alertNotDeterministic(c10::string_view const& caller) {
   if (globalContext().deterministicAlgorithms()) {
     if (globalContext().deterministicAlgorithmsWarnOnly()) {

aten/src/ATen/Context.h (+3)

@@ -205,6 +205,8 @@ class TORCH_API Context {
   bool deterministicAlgorithms() const;
   bool deterministicAlgorithmsWarnOnly() const;
   void setDeterministicAlgorithms(bool, bool);
+  bool deterministicFillUninitializedMemory() const;
+  void setDeterministicFillUninitializedMemory(bool);
 
   // Note [Writing Nondeterministic Operations]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -301,6 +303,7 @@ class TORCH_API Context {
   bool deterministic_cudnn = false;
   bool _deterministic_algorithms = false;
   bool _deterministic_algorithms_warn_only = false;
+  bool _deterministic_fill_uninitialized_memory = true;
   bool enabled_flashSDP = true;
   bool enabled_mem_efficientSDP = true;
   bool enabled_mathSDP = true;
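The new Context member mirrors the existing `_deterministic_algorithms` pair and defaults to `true`, so the fill behavior stays on unless explicitly disabled. A minimal check from Python (illustrative, not part of this diff):

    import torch

    # The C++ member _deterministic_fill_uninitialized_memory defaults to true,
    # so the Python-facing flag reads True in a fresh interpreter.
    print(torch.utils.deterministic.fill_uninitialized_memory)  # True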

aten/src/ATen/mps/EmptyTensor.cpp (+2 -2)

@@ -64,7 +64,7 @@ TensorBase empty_mps(
   auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
   tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     at::native::fill_empty_deterministic_(tensor);
   }
   return tensor;
@@ -107,7 +107,7 @@ TensorBase empty_strided_mps(
   Tensor result = at::detail::empty_strided_generic(
       size, stride, allocator, mps_dks, dtype);
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     at::native::fill_empty_deterministic_(result);
   }
   return result;

aten/src/ATen/native/Resize.cpp (+1 -1)

@@ -252,7 +252,7 @@ const Tensor& _resize_(
     self_->empty_tensor_restride(memory_format);
   }
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     at::native::fill_resize_deterministic_(self, old_storage_nbytes);
   }
   return self;

aten/src/ATen/native/TensorFactories.cpp (+3 -3)

@@ -255,7 +255,7 @@ Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
     c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
   Tensor result = at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     fill_empty_deterministic_(result);
   }
   return result;
@@ -327,7 +327,7 @@ Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
     c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   Tensor result = at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     fill_empty_deterministic_(result);
   }
   return result;
@@ -348,7 +348,7 @@ Tensor& empty_out(IntArrayRef size,
     result.resize_(size);
   }
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     fill_empty_deterministic_(result);
   }
   return result;
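Every one of these call sites now gates the fill on both settings: `deterministicAlgorithms() && deterministicFillUninitializedMemory()`. A sketch of observing that gate from Python (illustrative; assumes a CPU build):

    import torch

    torch.use_deterministic_algorithms(True)

    # Both flags on: uninitialized float memory is filled with NaN.
    torch.utils.deterministic.fill_uninitialized_memory = True
    assert torch.empty(4).isnan().all()

    # Fill flag off: torch.empty returns arbitrary, unfilled memory.
    torch.utils.deterministic.fill_uninitialized_memory = False
    _ = torch.empty(4)  # contents are undefined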

aten/src/ATen/native/cuda/Resize.cpp (+1 -1)

@@ -65,7 +65,7 @@ const Tensor& resize_cuda_(
     self_->empty_tensor_restride(memory_format);
   }
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     at::native::fill_resize_deterministic_(self, old_storage_nbytes);
   }
   return self;

aten/src/ATen/native/cuda/TensorFactories.cu (+2 -2)

@@ -54,7 +54,7 @@ Tensor& eye_out_cuda(int64_t n, int64_t m, Tensor& result) {
 Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
   Tensor result = at::detail::empty_cuda(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     fill_empty_deterministic_(result);
   }
   return result;
@@ -80,7 +80,7 @@ Tensor _efficientzerotensor_cuda(IntArrayRef size,
 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   Tensor result = at::detail::empty_strided_cuda(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     fill_empty_deterministic_(result);
   }
   return result;
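The CUDA factories apply the same gate as the CPU path. A hedged sketch of checking it on a CUDA device (requires a CUDA build; skipped otherwise):

    import torch

    if torch.cuda.is_available():
        torch.use_deterministic_algorithms(True)
        torch.utils.deterministic.fill_uninitialized_memory = True
        # Same semantics as the CPU path: the float fill value is NaN.
        assert torch.empty(4, device="cuda").isnan().all()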

aten/src/ATen/native/mps/TensorFactory.cpp (+1 -1)

@@ -120,7 +120,7 @@ const Tensor& resize_mps_(
     self_->empty_tensor_restride(memory_format);
   }
   // See Note [Enabling Deterministic Operations]
-  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms())) {
+  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
     at::native::fill_resize_deterministic_(self, old_storage_nbytes);
   }
   return self;

docs/source/amp.rst (+1 -1; the -/+ pair differs only by the newline added at end of file)

@@ -392,4 +392,4 @@ regardless of whether autocast is enabled.
 .. py:module:: torch.cpu.amp.autocast_mode
 .. py:module:: torch.cuda.amp.autocast_mode
 .. py:module:: torch.cuda.amp.common
-.. py:module:: torch.cuda.amp.grad_scaler
+.. py:module:: torch.cuda.amp.grad_scaler

docs/source/deterministic.rst (new file, +28)

@@ -0,0 +1,28 @@
+torch.utils.deterministic
+=========================
+.. py:module:: torch.utils.deterministic
+.. currentmodule:: torch.utils.deterministic
+
+.. attribute:: fill_uninitialized_memory
+
+    A :class:`bool` that, if True, causes uninitialized memory to be filled with
+    a known value when :meth:`torch.use_deterministic_algorithms()` is set to
+    ``True``. Floating point and complex values are set to NaN, and integer
+    values are set to the maximum value.
+
+    Default: ``True``
+
+    Filling uninitialized memory is detrimental to performance. So if your
+    program is valid and does not use uninitialized memory as the input to an
+    operation, then this setting can be turned off for better performance and
+    still be deterministic.
+
+    The following operations will fill uninitialized memory when this setting is
+    turned on:
+
+    * :func:`torch.Tensor.resize_` when called with a tensor that is not
+      quantized
+    * :func:`torch.empty`
+    * :func:`torch.empty_strided`
+    * :func:`torch.empty_permuted`
+    * :func:`torch.empty_like`
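A short usage sketch of the documented behavior (illustrative; the comments show expected values, not captured output):

    import torch

    torch.use_deterministic_algorithms(True)  # fill flag already defaults to True

    print(torch.empty(3))  # tensor([nan, nan, nan])
    i = torch.empty(3, dtype=torch.int32)
    print((i == torch.iinfo(torch.int32).max).all())  # tensor(True)

    # Opt out of the fill for performance when no uninitialized
    # memory is ever read.
    torch.utils.deterministic.fill_uninitialized_memory = False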

docs/source/index.rst (+1)

@@ -112,6 +112,7 @@ Features described in this documentation are classified by release status:
    torch.utils.checkpoint <checkpoint>
    torch.utils.cpp_extension <cpp_extension>
    torch.utils.data <data>
+   torch.utils.deterministic <deterministic>
    torch.utils.jit <jit_utils>
    torch.utils.dlpack <dlpack>
    torch.utils.mobile_optimizer <mobile_optimizer>

docs/source/notes/randomness.rst (+16)

@@ -144,6 +144,22 @@ CUDA RNN and LSTM
 In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior.
 See :meth:`torch.nn.RNN` and :meth:`torch.nn.LSTM` for details and workarounds.
 
+Filling uninitialized memory
+----------------------------
+Operations like :meth:`torch.empty` and :meth:`torch.Tensor.resize_` can return
+tensors with uninitialized memory that contain undefined values. Using such a
+tensor as an input to another operation is invalid if determinism is required,
+because the output will be nondeterministic. But there is nothing to actually
+prevent such invalid code from being run. So for safety,
+:attr:`torch.utils.deterministic.fill_uninitialized_memory` is set to ``True``
+by default, which fills the uninitialized memory with a known value if
+:code:`torch.use_deterministic_algorithms(True)` is set, preventing this kind
+of nondeterministic behavior.
+
+However, filling uninitialized memory is detrimental to performance. So if your
+program is valid and does not use uninitialized memory as the input to an
+operation, then this setting can be turned off for better performance.
+
 DataLoader
 ..........

test/test_torch.py (+34 -2)

@@ -1292,7 +1292,7 @@ def test_deterministic_resize(self, device, dtype):
         else:
             a = torch.empty_strided(size, stride, dtype=dtype, device=device).fill_(0)
             old_storage = a.untyped_storage().clone()
-        with DeterministicGuard(True):
+        with DeterministicGuard(True, fill_uninitialized_memory=True):
             a.resize_(resize_size)
 
         new_storage = a.untyped_storage()
@@ -1336,7 +1336,7 @@ def test_deterministic_empty(self, device, dtype):
         ]
 
         for gen_fn in gen_fns:
-            with DeterministicGuard(True):
+            with DeterministicGuard(True, fill_uninitialized_memory=True):
                 res = gen_fn()
 
             if dtype.is_floating_point or dtype.is_complex:
@@ -8689,6 +8689,38 @@ def test_deterministic_flag(self):
                 r"_set_deterministic_algorithms\(\): argument 'warn_only' must be bool, not int"):
             torch.use_deterministic_algorithms(False, warn_only=1)
 
+    # Tests that torch.utils.deterministic.fill_uninitialized_memory can be set as expected
+    def test_deterministic_fill_uninitialized_memory(self):
+        with DeterministicGuard(True, fill_uninitialized_memory=False):
+            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
+            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            with DeterministicGuard(True, fill_uninitialized_memory=True):
+                self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
+                self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
+            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            torch.utils.deterministic.fill_uninitialized_memory = False
+            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
+            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            torch.utils.deterministic.fill_uninitialized_memory = True
+            self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
+            self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            torch._C._set_deterministic_fill_uninitialized_memory(False)
+            self.assertFalse(torch.utils.deterministic.fill_uninitialized_memory)
+            self.assertFalse(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            torch._C._set_deterministic_fill_uninitialized_memory(True)
+            self.assertTrue(torch.utils.deterministic.fill_uninitialized_memory)
+            self.assertTrue(torch._C._get_deterministic_fill_uninitialized_memory())
+
+            with self.assertRaisesRegex(RuntimeError, r"expected a bool, but got int"):
+                torch.utils.deterministic.fill_uninitialized_memory = 1
+
     def test_type_conversion_via_dtype_name(self):
         x = torch.tensor([1])
         self.assertEqual(x.byte().dtype, torch.uint8)
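`DeterministicGuard` is a helper from `torch.testing._internal.common_utils`; its new `fill_uninitialized_memory` argument saves and restores the flag around the guarded block. A rough equivalent built only on public APIs (hypothetical `deterministic_guard` name, not the actual implementation):

    import contextlib
    import torch

    @contextlib.contextmanager
    def deterministic_guard(mode, fill_uninitialized_memory=True):
        # Save the current state, apply the requested state, restore on exit.
        prev_mode = torch.are_deterministic_algorithms_enabled()
        prev_warn = torch.is_deterministic_algorithms_warn_only_enabled()
        prev_fill = torch.utils.deterministic.fill_uninitialized_memory
        try:
            torch.use_deterministic_algorithms(mode)
            torch.utils.deterministic.fill_uninitialized_memory = fill_uninitialized_memory
            yield
        finally:
            torch.use_deterministic_algorithms(prev_mode, warn_only=prev_warn)
            torch.utils.deterministic.fill_uninitialized_memory = prev_fill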

torch/_C/__init__.pyi.in (+2)

@@ -1098,6 +1098,8 @@ def _set_deterministic_algorithms(
     *,
     warn_only: _bool = ...,
 ) -> None: ...  # THPModule_setDeterministicAlgorithms
+def _get_deterministic_fill_uninitialized_memory() -> _bool: ...  # THPModule_deterministicFillUninitializedMemory
+def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ...  # THPModule_setDeterministicFillUninitializedMemory
 def _get_warnAlways() -> _bool: ...  # THPModule_warnAlways
 def _set_warnAlways(arg: _bool) -> None: ...  # THPModule_setWarnAlways
 def _get_cudnn_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuDNN
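The private bindings and the public `torch.utils.deterministic.fill_uninitialized_memory` attribute read and write the same Context state, as the new test exercises. For illustration:

    import torch

    torch._C._set_deterministic_fill_uninitialized_memory(False)
    assert torch.utils.deterministic.fill_uninitialized_memory is False

    torch.utils.deterministic.fill_uninitialized_memory = True
    assert torch._C._get_deterministic_fill_uninitialized_memory() is True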

torch/__init__.py (+5 -8)

@@ -729,14 +729,6 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.bool = False) -> None:
     * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
     * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor
     * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor
-    * :func:`torch.Tensor.resize_`, when called with a tensor that is not
-      quantized, sets new elements to a known value. Floating point or
-      complex values are set to NaN. Integer values are set to the maximum
-      value.
-    * :func:`torch.empty`, :func:`torch.empty_like`, :func:`torch.empty_strided`,
-      and :func:`torch.empty_permuted` will fill the output tensor with a known
-      value. Floating point or complex dtype tensors are filled with NaN. Integer
-      dtype tensors are filled with the maximum value.
 
     The following normally-nondeterministic operations will throw a
     :class:`RuntimeError` when ``mode=True``:
@@ -781,6 +773,11 @@ def use_deterministic_algorithms(mode: builtins.bool, *, warn_only: builtins.bool = False) -> None:
     * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor
     * :func:`torch.Tensor.resize_` when called with a quantized tensor
 
+    In addition, several operations fill uninitialized memory when this setting
+    is turned on and when
+    :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on.
+    See the documentation for that attribute for more information.
+
     A handful of CUDA operations are nondeterministic if the CUDA version is
     10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
     or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
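Since the fill now depends on two settings, code that needs to know the effective behavior has to check both. A hypothetical helper (not part of this commit):

    import torch

    def will_fill_uninitialized_memory() -> bool:
        # Mirrors the C++ condition guarding fill_empty_deterministic_.
        return (torch.are_deterministic_algorithms_enabled()
                and torch.utils.deterministic.fill_uninitialized_memory)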

torch/_tensor_docs.py (+6 -4)

@@ -4217,10 +4217,12 @@ def callable(a, b) -> number
 
 .. note::
 
-    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, new
-    elements are initialized to prevent nondeterministic behavior from using
-    the result as an input to an operation. Floating point and complex values
-    are set to NaN, and integer values are set to the maximum value.
+    If :func:`torch.use_deterministic_algorithms()` and
+    :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
+    ``True``, new elements are initialized to prevent nondeterministic behavior
+    from using the result as an input to an operation. Floating point and
+    complex values are set to NaN, and integer values are set to the maximum
+    value.
 
 Args:
     sizes (torch.Size or int...): the desired size
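This note documents :func:`torch.Tensor.resize_`; a sketch of the behavior it describes (expected values shown as comments):

    import torch

    torch.use_deterministic_algorithms(True)  # fill flag defaults to True
    t = torch.ones(2)
    t.resize_(4)
    # Existing elements are preserved; the newly exposed ones are NaN.
    print(t)  # tensor([1., 1., nan, nan])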

torch/_torch_docs.py (+24 -20)

@@ -12335,11 +12335,12 @@ def merge_dicts(*dicts):
     defined by the variable argument :attr:`size`.
 
 .. note::
-    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
-    output tensor is initialized to prevent any possible nondeterministic
-    behavior from using the data as an input to an operation. Floating point
-    and complex tensors are filled with NaN, and integer tensors are filled
-    with the maximum value.
+    If :func:`torch.use_deterministic_algorithms()` and
+    :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
+    ``True``, the output tensor is initialized to prevent any possible
+    nondeterministic behavior from using the data as an input to an operation.
+    Floating point and complex tensors are filled with NaN, and integer tensors
+    are filled with the maximum value.
 
 Args:
     size (int...): a sequence of integers defining the shape of the output tensor.
@@ -12374,11 +12375,12 @@ def merge_dicts(*dicts):
     ``torch.empty(input.size(), dtype=input.dtype, layout=input.layout, device=input.device)``.
 
 .. note::
-    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
-    output tensor is initialized to prevent any possible nondeterministic
-    behavior from using the data as an input to an operation. Floating point
-    and complex tensors are filled with NaN, and integer tensors are filled
-    with the maximum value.
+    If :func:`torch.use_deterministic_algorithms()` and
+    :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
+    ``True``, the output tensor is initialized to prevent any possible
+    nondeterministic behavior from using the data as an input to an operation.
+    Floating point and complex tensors are filled with NaN, and integer tensors
+    are filled with the maximum value.
 
 Args:
     {input}
@@ -12413,11 +12415,12 @@ def merge_dicts(*dicts):
     in memory) its behavior is undefined.
 
 .. note::
-    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
-    output tensor is initialized to prevent any possible nondeterministic
-    behavior from using the data as an input to an operation. Floating point
-    and complex tensors are filled with NaN, and integer tensors are filled
-    with the maximum value.
+    If :func:`torch.use_deterministic_algorithms()` and
+    :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
+    ``True``, the output tensor is initialized to prevent any possible
+    nondeterministic behavior from using the data as an input to an operation.
+    Floating point and complex tensors are filled with NaN, and integer tensors
+    are filled with the maximum value.
 
 Args:
     size (tuple of int): the shape of the output tensor
@@ -12465,11 +12468,12 @@ def merge_dicts(*dicts):
     :func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
 
 .. note::
-    If :func:`torch.use_deterministic_algorithms()` is set to ``True``, the
-    output tensor is initialized to prevent any possible nondeterministic
-    behavior from using the data as an input to an operation. Floating point
-    and complex tensors are filled with NaN, and integer tensors are filled
-    with the maximum value.
+    If :func:`torch.use_deterministic_algorithms()` and
+    :attr:`torch.utils.deterministic.fill_uninitialized_memory` are both set to
+    ``True``, the output tensor is initialized to prevent any possible
+    nondeterministic behavior from using the data as an input to an operation.
+    Floating point and complex tensors are filled with NaN, and integer tensors
+    are filled with the maximum value.
 
 Args:
     size (tuple of int): the shape of the output tensor
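The same note is repeated for `torch.empty`, `torch.empty_like`, `torch.empty_strided`, and `torch.empty_permuted`. One illustrative check for the integer case (assumes the default CPU device):

    import torch

    torch.use_deterministic_algorithms(True)
    x = torch.empty_strided((2, 2), (2, 1), dtype=torch.int64)
    # Integer tensors are filled with the dtype's maximum value.
    assert (x == torch.iinfo(torch.int64).max).all()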
