Commit 6620176

ezyang authored and pytorchmergebot committed
Add documentation for meta device (pytorch#119119)
Fixes pytorch#119098

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#119119
Approved by: https://github.com/bdhirsh
1 parent dab16b6 commit 6620176

3 files changed: +89 -2 lines changed

docs/source/index.rst

+1
@@ -69,6 +69,7 @@ Features described in this documentation are classified by release status:
    torch.cuda.memory <torch_cuda_memory>
    mps
    xpu
+   meta
    torch.backends <backends>
    torch.export <export>
    torch.distributed <distributed>

docs/source/meta.rst

+84
@@ -0,0 +1,84 @@
Meta device
============

The "meta" device is an abstract device which denotes a tensor which records
only metadata, but no actual data. Meta tensors have two primary use cases:

* Models can be loaded on the meta device, allowing you to load a
  representation of the model without actually loading the actual parameters
  into memory. This can be helpful if you need to make transformations on
  the model before you load the actual data.

* Most operations can be performed on meta tensors, producing new meta
  tensors that describe what the result would have been if you performed
  the operation on a real tensor. You can use this to perform abstract
  analysis without needing to spend time on compute or space to represent
  the actual tensors. Because meta tensors do not have real data, you cannot
  perform data-dependent operations like :func:`torch.nonzero` or
  :meth:`~torch.Tensor.item`. In some cases, not all device types (e.g., CPU
  and CUDA) have exactly the same output metadata for an operation; we
  typically prefer representing the CUDA behavior faithfully in this
  situation.
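
The second use case can be sketched as follows. This is a minimal illustration rather than part of the commit: the tensor sizes are arbitrary, and the exact exception raised by a data-dependent call is assumed to vary between PyTorch versions, so the sketch catches it broadly.

```python
import torch

# Two meta tensors: shape and dtype are tracked, but no storage is allocated.
a = torch.empty(1024, 4096, device="meta")
b = torch.empty(4096, 512, device="meta")

# The matmul only runs shape/dtype inference and returns another meta tensor.
out = a @ b
print(out.shape, out.dtype, out.device)

# Data-dependent operations cannot work, because there are no values to read.
item_failed = False
try:
    out.sum().item()
except Exception as exc:  # exact exception type is version-dependent (assumption)
    item_failed = True
    print(f"data-dependent op failed: {type(exc).__name__}")
```

The output shape is available immediately, even though a real ``a`` and ``b`` of this size would occupy roughly 24 MB between them.
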

.. warning::

    Although in principle meta tensor computation should always be faster than
    an equivalent CPU/CUDA computation, many meta tensor implementations are
    implemented in Python and have not been ported to C++ for speed, so you
    may find that you get lower absolute framework latency with small CPU tensors.

Idioms for working with meta tensors
------------------------------------

An object can be loaded with :func:`torch.load` onto meta device by specifying
``map_location='meta'``::

    >>> torch.save(torch.randn(2), 'foo.pt')
    >>> torch.load('foo.pt', map_location='meta')
    tensor(..., device='meta', size=(2,))
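
The same idiom can be used to inspect a checkpoint's layout without reading its weights into memory. The sketch below is illustrative only: it uses an in-memory buffer in place of a file, the ``Linear(20, 30)`` checkpoint is a stand-in, and it assumes a PyTorch version recent enough to accept ``map_location='meta'``.

```python
import io

import torch
from torch import nn

# Stand-in checkpoint: a small state dict saved to an in-memory buffer.
buf = io.BytesIO()
torch.save(nn.Linear(20, 30).state_dict(), buf)
buf.seek(0)

# Load only the metadata; every tensor in the result lives on the meta device.
meta_state = torch.load(buf, map_location="meta")

# Collect the parameter shapes without touching any real data.
shapes = {name: tuple(t.shape) for name, t in meta_state.items()}
print(shapes)
```
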

If you have some arbitrary code which performs some tensor construction without
explicitly specifying a device, you can override it to instead construct on meta device by using
the :func:`torch.device` context manager::

    >>> with torch.device('meta'):
    ...     print(torch.randn(30, 30))
    ...
    tensor(..., device='meta', size=(30, 30))
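
As a rough sketch of how the context manager interacts with existing code: ``build_inputs`` below is a hypothetical helper that never mentions a device, so its factory calls pick up the meta default, while a call that names a device explicitly is left alone.

```python
import torch


def build_inputs(batch):
    # Hypothetical helper: constructs tensors without specifying a device.
    images = torch.zeros(batch, 3, 224, 224)
    labels = torch.arange(batch)
    return images, labels


with torch.device("meta"):
    # Factory calls inside the helper now default to the meta device.
    images, labels = build_inputs(8)
    # An explicit device argument still takes precedence over the context.
    explicit = torch.ones(2, device="cpu")

print(images.device, labels.device, explicit.device)
```
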

This is especially helpful for NN module construction, where you often are not
able to explicitly pass in a device for initialization::

    >>> from torch.nn.modules import Linear
    >>> with torch.device('meta'):
    ...     print(Linear(20, 30))
    ...
    Linear(in_features=20, out_features=30, bias=True)
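
One thing this enables is sizing a model before committing memory to it. The following is a small sketch with a made-up three-layer network; the layer sizes are arbitrary and not taken from the commit.

```python
import torch
from torch import nn

# Hypothetical model, built entirely on the meta device.
with torch.device("meta"):
    model = nn.Sequential(
        nn.Linear(4096, 4096),
        nn.ReLU(),
        nn.Linear(4096, 1000),
    )

# Parameter counts come from metadata alone; no weights were allocated.
n_params = sum(p.numel() for p in model.parameters())
all_meta = all(p.is_meta for p in model.parameters())
print(n_params, all_meta)  # 20878312 True
```
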

You cannot convert a meta tensor directly to a CPU/CUDA tensor, because the
meta tensor stores no data and we do not know what the correct data values for
your new tensor are::

    >>> torch.ones(5, device='meta').to("cpu")
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NotImplementedError: Cannot copy out of meta tensor; no data!

Use a factory function like :func:`torch.empty_like` to explicitly specify how
you would like the missing data to be filled in.

NN modules have a convenience method :meth:`torch.nn.Module.to_empty` that
allows you to move the module to another device, leaving all parameters
uninitialized. You are expected to reinitialize the parameters
manually::

    >>> from torch.nn.modules import Linear
    >>> with torch.device('meta'):
    ...     m = Linear(20, 30)
    ...
    >>> m.to_empty(device="cpu")
    Linear(in_features=20, out_features=30, bias=True)
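
One possible way to do the reinitialization is sketched below. It relies on the ``reset_parameters()`` method that ``Linear`` and many other built-in layers define; this is a convention rather than a guarantee, so custom modules without it would need their own initialization code.

```python
import torch
from torch import nn

with torch.device("meta"):
    m = nn.Linear(20, 30)

# Allocate real, but uninitialized, storage on the CPU.
m = m.to_empty(device="cpu")

# Reinitialize every submodule that exposes reset_parameters().
for sub in m.modules():
    if hasattr(sub, "reset_parameters"):
        sub.reset_parameters()

# The module is now usable like one constructed directly on the CPU.
y = m(torch.randn(4, 20))
print(y.shape)
```
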

:mod:`torch._subclasses.meta_utils` contains undocumented utilities for taking
an arbitrary Tensor and constructing an equivalent meta Tensor with high
fidelity. These APIs are experimental and may be changed in a BC breaking way
at any time.

docs/source/tensor_attributes.rst

+4 -2
@@ -139,8 +139,10 @@ torch.device
 A :class:`torch.device` is an object representing the device on which a :class:`torch.Tensor` is
 or will be allocated.
 
-The :class:`torch.device` contains a device type (``'cpu'``, ``'cuda'`` or ``'mps'``) and optional device
-ordinal for the device type. If the device ordinal is not present, this object will always represent
+The :class:`torch.device` contains a device type (most commonly "cpu" or
+"cuda", but also potentially :doc:`"mps" <mps>`, :doc:`"xpu" <xpu>`,
+`"xla" <https://github.com/pytorch/xla/>`_ or :doc:`"meta" <meta>`) and optional
+device ordinal for the device type. If the device ordinal is not present, this object will always represent
 the current device for the device type, even after :func:`torch.cuda.set_device()` is called; e.g.,
 a :class:`torch.Tensor` constructed with device ``'cuda'`` is equivalent to ``'cuda:X'`` where X is
 the result of :func:`torch.cuda.current_device()`.
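
The device type and optional ordinal described in this hunk can be inspected directly. The sketch below only constructs ``torch.device`` objects, which does not require the named hardware to be present.

```python
import torch

cpu = torch.device("cpu")
meta = torch.device("meta")
cuda1 = torch.device("cuda", 1)  # type plus explicit ordinal

# Without an ordinal, index is None and the device means "current device".
print(cpu.type, cpu.index)
print(meta.type, meta.index)

# The string form "cuda:1" names the same device as ("cuda", 1).
print(cuda1 == torch.device("cuda:1"))
```
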
