
Commit 07d3bdd

nicholas-leonard authored and soumith committed
nn.GPU (#835)
* Added nn.GPU
1 parent 2207e45 commit 07d3bdd

File tree: 4 files changed (+328, −0 lines)


GPU.lua

Lines changed: 273 additions & 0 deletions
------------------------------------------------------------------------
--[[ GPU ]]--
-- Decorates a module such that its parameters are
-- hosted on a specified GPU device.
-- The operations are also executed on that device.
-- Arguments input and gradOutput are converted to the specified device
-- before being fed to the decorated module.
-- Returned output is on the specified outdevice (defaults to device).
-- Returned gradInput is allocated on the same device as the input.
-- The unit test is located in cunn.
------------------------------------------------------------------------
local GPU, parent = torch.class("nn.GPU", "nn.Container")

function GPU:__init(module, device, outdevice)
   parent.__init(self)
   assert(torch.type(device) == 'number')
   self.device = device
   self.outdevice = outdevice or device

   assert(torch.isTypeOf(module, 'nn.Module'))
   self.modules[1] = module

   if module:type() == 'torch.CudaTensor' then
      self:cuda()
   end
end

-- moves every CUDA tensor found in obj (a module or a table thereof) to device
function GPU.recursiveModuleDevice(obj, device)
   if type(obj) == 'table' and not torch.isTypeOf(obj, 'nn.GPU') and not obj.__noGPU__ then
      for k,v in pairs(obj) do
         obj[k] = GPU.recursiveModuleDevice(v, device)
      end
   elseif torch.type(obj):match('torch.Cuda.*Tensor') then
      if obj:getDevice() ~= device then
         obj = obj:clone() -- this will reallocate it to device
         local newdevice = obj:getDevice()
         -- when nElement() == 0, newdevice is 0
         assert(newdevice == device or newdevice == 0)
      end
   end
   assert(obj ~= nil)
   return obj
end

-- set the device of the decorated module
function GPU:setDevice(device)
   self.device = device or self.device

   assert(self.modules[1])
   self.modules[1] = cutorch.withDevice(self.device, function()
      return self.recursiveModuleDevice(self.modules[1], self.device)
   end)
   return self
end

-- when proto is a device number, moves every element of src to that device;
-- when proto is a table/tensor, makes dst identical to src, but on the same device(s) as proto
function GPU.recursiveSetDevice(dst, src, proto)
   local device, prototable
   if torch.isTensor(proto) then
      device = proto:getDevice()
   elseif torch.type(proto) == 'number' then
      device = proto
   elseif torch.type(proto) == 'table' then
      prototable = true
   else
      error"Expecting number, table or tensor for arg 3 (proto)"
   end
   if torch.type(src) == 'table' then
      dst = torch.type(dst) == 'table' and dst or {}
      for k,v in ipairs(src) do
         dst[k] = GPU.recursiveSetDevice(dst[k], v, prototable and proto[k] or device)
      end
      for k=#src+1,#dst do
         dst[k] = nil
      end
   elseif torch.type(src):match('torch.Cuda.*Tensor') and src:getDevice() ~= device and src:getDevice() ~= 0 then
      if not (torch.type(dst):match('torch.Cuda.*Tensor') and dst:getDevice() == device) then
         dst = src.new()
      end
      cutorch.withDevice(device, function() dst:resizeAs(src):copy(src) end)
   else
      dst = src
   end
   return dst
end

function GPU:updateOutput(input)
   if self._type == 'torch.CudaTensor' then
      self._input = self.recursiveSetDevice(self._input, input, self.device)

      local output = cutorch.withDevice(self.device, function()
         return self.modules[1]:updateOutput(self._input)
      end)

      if self.device ~= self.outdevice then
         self.output = self.recursiveSetDevice(self.output, output, self.outdevice)
      else
         self.output = output
      end
   else
      self.output = self.modules[1]:updateOutput(input)
   end

   return self.output
end

function GPU:updateGradInput(input, gradOutput)
   if self._type == 'torch.CudaTensor' then
      self._gradOutput = self.recursiveSetDevice(self._gradOutput, gradOutput, self.device)

      local gradInput = cutorch.withDevice(self.device, function()
         return self.modules[1]:updateGradInput(self._input, self._gradOutput)
      end)

      -- allocate gradInput on the same device(s) as input
      self.gradInput = self.recursiveSetDevice(self.gradInput, gradInput, input)
   else
      self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
   end

   return self.gradInput
end

function GPU:accGradParameters(input, gradOutput, scale)
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function()
         self.modules[1]:accGradParameters(self._input, self._gradOutput, scale)
      end)
   else
      self.modules[1]:accGradParameters(input, gradOutput, scale)
   end
end

function GPU:apply(callback)
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.apply(self, callback) end)
   else
      parent.apply(self, callback)
   end
end

function GPU:type(type, typecache)
   if type and type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.type(self, type, typecache) end)
      self:setDevice()
   else
      self.output = nil
      self.gradInput = nil
      self._input = nil
      self._gradOutput = nil
      parent.type(self, type, typecache)
   end
   return self
end

function GPU:clearState()
   nn.utils.clear(self, 'output', 'gradInput')
   self._input = nil
   self._gradOutput = nil
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.clearState(self) end)
   else
      parent.clearState(self)
   end
end

function GPU:zeroGradParameters()
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.zeroGradParameters(self) end)
   else
      parent.zeroGradParameters(self)
   end
end

function GPU:updateParameters(lr)
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.updateParameters(self, lr) end)
   else
      parent.updateParameters(self, lr)
   end
end

function GPU:training()
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.training(self) end)
   else
      parent.training(self)
   end
end

function GPU:evaluate()
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.evaluate(self) end)
   else
      parent.evaluate(self)
   end
end

function GPU:share(mlp, ...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.share(self, mlp, unpack(args)) end)
   else
      parent.share(self, mlp, unpack(args))
   end
   return self
end

function GPU:reset(...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      cutorch.withDevice(self.device, function() parent.reset(self, unpack(args)) end)
   else
      parent.reset(self, unpack(args))
   end
   return self
end

function GPU:clone(...)
   local args = {...}
   if self._type == 'torch.CudaTensor' then
      return cutorch.withDevice(self.device, function() return parent.clone(self, unpack(args)) end)
   else
      return parent.clone(self, unpack(args))
   end
end

function GPU:write(file)
   -- Write all values in the object as a table.
   local object = {}
   for k, v in pairs(self) do
      object[k] = v
   end
   local header = {self._type, self.device}
   file:writeObject(header)
   file:writeObject(object)
end

function GPU:read(file)
   local header = file:readObject()
   local object
   if header[1] == 'torch.CudaTensor' then
      local device = header[2]
      if device > cutorch.getDeviceCount() then
         print"Warning: model was saved with more devices than available on current host."
         print"Attempting to load module onto device 1"
         device = 1
      end
      object = cutorch.withDevice(device, function() return file:readObject() end)
   else
      object = file:readObject()
   end

   for k, v in pairs(object) do
      self[k] = v
   end
end

function GPU:__tostring__()
   if self.modules[1].__tostring__ then
      return torch.type(self) .. '(' .. self.device .. ') @ ' .. self.modules[1]:__tostring__()
   else
      return torch.type(self) .. '(' .. self.device .. ') @ ' .. torch.type(self.modules[1])
   end
end

function GPU:accUpdateGradParameters(input, gradOutput, lr)
   error("Not Implemented for "..torch.type(self))
end

function GPU:sharedAccUpdateGradParameters(input, gradOutput, lr)
   error("Not Implemented for "..torch.type(self))
end
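
The `write`/`read` overrides above make saved `GPU` modules device-aware. A minimal round-trip sketch of the resulting behaviour (not part of this commit; it assumes `cunn` is installed, a second device exists, and uses a hypothetical file name `gpu.t7`):

```lua
require 'cunn'

-- hypothetical example: save a GPU-decorated module and load it back
gpu = nn.GPU(nn.Linear(10, 10), 2):cuda()
torch.save('gpu.t7', gpu)   -- GPU:write records the type and device in a header
gpu2 = torch.load('gpu.t7') -- GPU:read restores tensors on device 2,
                            -- falling back to device 1 if device 2 is unavailable
print(gpu2)                 -- nn.GPU(2) @ nn.Linear(10 -> 10)
```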

doc/simple.md

Lines changed: 48 additions & 0 deletions
@@ -51,6 +51,7 @@
 * [Padding](#nn.Padding) : adds padding to a dimension ;
 * [L1Penalty](#nn.L1Penalty) : adds an L1 penalty to an input (for sparsity) ;
 * [GradientReversal](#nn.GradientReversal) : reverses the gradient (to maximize an objective function) ;
+* [GPU](#nn.GPU) : decorates a module so that it can be executed on a specific GPU device.
 
 <a name="nn.Linear"></a>
 ## Linear ##
@@ -1404,3 +1405,50 @@ One can also call:

```lua
module:setLambda(lambda)
```

to set the hyper-parameter `lambda` dynamically during training.

<a name="nn.GPU"></a>
## GPU ##

```lua
gpu = nn.GPU(module, device, [outdevice])
require 'cunn'
gpu:cuda()
```

Decorates an encapsulated `module` so that it can be executed on a specific GPU `device`.
The decorated module's `parameters` are thus hosted on the specified GPU `device`.
All operations on the `gpu` module are executed on that device.
Calls to `forward`/`backward` transfer the `input` and `gradOutput` arguments to the specified `device`
before feeding them to the decorated `module`.
The returned `output` is located on the specified `outdevice` (defaults to `device`).
The returned `gradInput` is allocated on the same device as the `input`.
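
For instance, a minimal sketch of where each tensor ends up (not part of this commit; it assumes `cunn` is installed and at least two devices are visible):

```lua
require 'cunn'

-- hypothetical example: host a Linear layer on device 2
gpu = nn.GPU(nn.Linear(100, 100), 2):cuda()

input = torch.CudaTensor(8, 100):uniform() -- allocated on the current device (say, device 1)
output = gpu:forward(input)                -- computed on device 2; returned on device 2 (the default outdevice)
gradInput = gpu:backward(input, output)    -- returned on the same device as input (device 1)
```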

When serialized/deserialized, the `gpu` module will be run on the same `device` that it was serialized with.
To prevent this from happening, the module can be converted to float/double before serialization:

```lua
gpu:float()
gpustr = torch.serialize(gpu)
```
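
To move such a deserialized module back onto its GPU, one can simply cast it again (a sketch; the stored `device` attribute survives the float conversion):

```lua
gpu2 = torch.deserialize(gpustr) -- comes back as a float module
gpu2:cuda()                      -- re-hosts the parameters on gpu2.device
```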

The module is located in the __nn__ package instead of __cunn__, as this allows
it to be used in CPU-only environments, which are common for production models.

The module supports nested table `input` and `gradOutput` tensors originating from multiple devices.
Each nested tensor in the returned `gradInput` is transferred to the same device as its corresponding tensor in the `input`.
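
To illustrate the multi-device table support, a sketch (not from this commit; it assumes at least three visible devices):

```lua
require 'cunn'

-- hypothetical example: the two inputs live on devices 1 and 3, the module runs on device 2
gpu = nn.GPU(nn.CAddTable(), 2):cuda()

input = {
   cutorch.withDevice(1, function() return torch.CudaTensor(4):fill(1) end),
   cutorch.withDevice(3, function() return torch.CudaTensor(4):fill(2) end),
}

output = gpu:forward(input)             -- both tensors are copied to device 2 before the sum
gradInput = gpu:backward(input, output)
print(gradInput[1]:getDevice(), gradInput[2]:getDevice()) -- 1 3 : each matches its input tensor
```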

The intended use-case is not model-parallelism, where the models are executed in parallel on multiple devices, but
rather sequential models where a single GPU doesn't have enough memory.

Example using 4 GPUs:

```lua
mlp = nn.Sequential()
   :add(nn.GPU(nn.Linear(10000,10000), 1))
   :add(nn.GPU(nn.Linear(10000,10000), 2))
   :add(nn.GPU(nn.Linear(10000,10000), 3))
   :add(nn.GPU(nn.Linear(10000,10000), 4, cutorch.getDevice()))
```

Note how the last `GPU` instance will return an `output` tensor on the same device as the current device (`cutorch.getDevice`).
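
A forward pass through this stack might then look as follows (a sketch; it assumes each 10000x10000 layer fits on its device):

```lua
input = torch.CudaTensor(32, 10000):uniform() -- lives on the current device
output = mlp:forward(input)                   -- hops across devices 1 to 4; returned on the current device
```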

init.lua

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,8 @@
 require('nn.VolumetricAveragePooling')
 require('nn.VolumetricBatchNormalization')
 require('nn.VolumetricReplicationPadding')
 
+require('nn.GPU')
+
 require('nn.ParallelTable')
 require('nn.Identity')
 require('nn.ConcatTable')

test.lua

Lines changed: 5 additions & 0 deletions
@@ -6359,6 +6359,11 @@ function nntest.ErrorHandling()
    )
 end
 
+function nntest.GPU()
+   -- this is a placeholder to let you know that the nn.GPU unit test
+   -- is located in the cunn package.
+end
+
 mytester:add(nntest)
 
 jac = nn.Jacobian
