From 40cb39f692267cc9a5982996cc18216e37320726 Mon Sep 17 00:00:00 2001 From: Laurens van der Maaten Date: Tue, 21 Feb 2017 14:18:10 -0500 Subject: [PATCH] [DPT] Add option for synchronous copy of model to GPUs --- DataParallelTable.lua | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/DataParallelTable.lua b/DataParallelTable.lua index 9e07978c..22c77394 100644 --- a/DataParallelTable.lua +++ b/DataParallelTable.lua @@ -85,12 +85,12 @@ function DataParallelTable:add(module, gpus) return self end -function DataParallelTable:threads(initFunc) +function DataParallelTable:threads(initFunc, syncCopies) require 'threads' self.impl:close() - self.impl = Impls.Threads(self, initFunc) + self.impl = Impls.Threads(self, initFunc, syncCopies) return self -end +end -- NOTE: Setting syncCopies will copy model to GPUs synchronously. function DataParallelTable:__tostring() return 'DataParallelTable: ' .. #self.gpuAssignments .. ' x ' .. tostring(self.modules[1]) @@ -678,12 +678,15 @@ function BasicImpl:close() end -- Multi-threaded dispatch -function ThreadsImpl:__init(dpt, initFunc) +function ThreadsImpl:__init(dpt, initFunc, syncCopies) self.dpt = dpt self.initFunc = initFunc + self.syncCopies = syncCopies + -- This makes initial copy of models to GPUs synchronous. Set this option + -- in case your model serialization code is not thread-safe. end -function ThreadsImpl:applyChanges() +function ThreadsImpl:applyChanges(sync) if self.__threads then local module = self.dpt.modules[1] for i, gpu in ipairs(self.dpt.gpuAssignments) do @@ -697,6 +700,9 @@ function ThreadsImpl:applyChanges() _G.module = module:clone() end end) + if sync then + self.__threads:synchronize() + end -- if sync is set, changes are applied synchronously end self.__threads:synchronize() end @@ -711,7 +717,7 @@ function ThreadsImpl:setup() function() require 'cunn' end, self.initFunc) self.__threads:specific(true) - self:applyChanges() + self:applyChanges(self.syncCopies) end end