-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo_server.lua
230 lines (182 loc) · 8.1 KB
/
demo_server.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
------------------------------------------------------------------------
-- demo_server.lua
--
-- This is the example of a class that is used to implement a server in
-- server.lua. This class has an _init(opt) function that takes in
-- the global parameters, loads in the data and builds the model on
-- the parameter server. The class also has a run() function that
-- forks out the child clients and executes the function 'worker'
-- on each corresponding client.
--
-- If you wish to develop your own SGD model, create a new class that is
-- similar to this.
------------------------------------------------------------------------
-- Declare the 'demo_server' class through torch's class system; its methods
-- (__init, run, load_data, build) are attached below and the class table is
-- returned at the bottom of the file for use by server.lua.
local demo_server = torch.class('demo_server')
------------
-- Worker code
------------
-- Code executed on every forked child client (shipped over by
-- parallel.children:exec in demo_server:run). Protocol per loop iteration:
--   1. parallel.yield()            -- parent may send 'break' to terminate
--   2. parallel.parent:receive()   -- receive a package (globals first, then batches)
--   3. parallel.parent:send(...)   -- reply (ack, then gradient packages)
-- The first package carries the command-line spec (cmd/arg/ext); later
-- packages carry model parameters plus a batch index to train on.
-- NOTE(review): several variables here (ischild, ext, m, cmd, arg, opt, data,
-- funcs, train_data, valid_data, model, criterion) are intentionally global so
-- the files loaded via loadfile/require can see them.
function worker()
-- Used to check files
require "lfs"
-- Used to update path
require 'package'
-- Alert successfully started up
parallel.print('Im a worker, my ID is: ', parallel.id, ' and my IP: ', parallel.ip)
-- Global indicating is a child
ischild = true
-- Extension to lua-lua folder from home directory. Set to no extension as default
ext = ""
-- Number of packages received
local n_pkg = 0
while true do
-- Allow the parent to terminate the child
m = parallel.yield()
if m == 'break' then break end
-- Receive data
local pkg = parallel.parent:receive()
-- Make sure to clean everything up since big files are being passed
io.write('.') io.flush()
collectgarbage()
if n_pkg == 0 then
-- This is the first time receiving a package, it has the globals
-- Receive and parse global parameters
-- NOTE(review): "Recieved" is a typo in this log message; left as-is since
-- this is a documentation-only pass.
parallel.print('Recieved initialization parameters')
cmd, arg, ext = pkg.cmd, pkg.arg, pkg.ext
opt = cmd:parse(arg)
-- Update path
package.path = opt.add_to_path .. package.path
-- Add in additional necessary parameters
opt.print = parallel.print
opt.parallel = true
-- Library used to handle data types
local data_loc = ext .. 'End-To-End-Generative-Dialogue/src/data'
if not lfs.attributes(data_loc .. '.lua') then
print('The file data.lua could not be found in ' .. data_loc .. '.lua')
os.exit()
end
data = require(data_loc)
-- Load in helper functions for this model defined in End-To-End-Generative-Dialogue
local model_funcs_loc = ext .. "End-To-End-Generative-Dialogue/src/model_functions.lua"
if not lfs.attributes(model_funcs_loc) then
print('The file model_functions.lua could not be found in ' .. model_funcs_loc)
os.exit()
end
-- Execute model_functions.lua in the global environment so load_data/build/
-- train_ind become available below.
funcs = loadfile(model_funcs_loc)
funcs()
-- Change the locations of the datafiles based on new extension
opt.data_file = ext .. opt.data_file
opt.val_data_file = ext .. opt.val_data_file
--point the wordvec to the right place if exists
-- NOTE(review): this uses opt.extension while the lines above use ext —
-- presumably these hold the same value, but verify against the option spec.
if opt.pre_word_vecs ~= "" then
opt.pre_word_vecs = opt.extension .. opt.pre_word_vecs
end
-- Load in data to client
train_data, valid_data, opt = load_data(opt)
-- Build the model on the client
model, criterion = build()
-- send some data back
parallel.parent:send('Received parameters and loaded data successfully')
else
parallel.print('received params from batch with index: ', pkg.index)
-- Load in the parameters sent from the parent
-- Copy the parent's parameter tensors into this client's model in place.
for i = 1, #model.params do
model.params[i]:copy(pkg.parameters[i])
end
-- Training the model at the given index
local pkg_o = train_ind(pkg.index, model, criterion, train_data)
-- send some data back
parallel.print('sending back derivative for batch with index: ', pkg.index)
parallel.parent:send(pkg_o)
end
n_pkg = n_pkg + 1
end
end
------------
-- Server class
------------
-- Initialization function for the server object. Here we load in the data, build our
-- model, and then add any remote client objects if necessary.
-- Initialization function for the server object. Loads the data-handling
-- library and model helper functions, loads the training/validation data,
-- builds the model, and registers remote client machines when requested.
--
-- @param opt  table of parsed command-line options; relevant fields read here:
--             remote, localhost, username, torch_path
function demo_server:__init(opt)
    -- Save the command line options
    self.opt = opt
    -- Used to check files
    require "lfs"
    -- Library used to handle data types
    local data_loc = 'End-To-End-Generative-Dialogue/src/data'
    if not lfs.attributes(data_loc .. '.lua') then
        print('The file data.lua could not be found in ' .. data_loc .. '.lua')
        os.exit()
    end
    data = require(data_loc)
    -- Load in helper functions (load_data, build, train, ...) defined in
    -- End-To-End-Generative-Dialogue; executed in the global environment.
    local model_funcs_loc = "End-To-End-Generative-Dialogue/src/model_functions.lua"
    if not lfs.attributes(model_funcs_loc) then
        print('The file model_functions.lua could not be found in ' .. model_funcs_loc)
        os.exit()
    end
    funcs = loadfile(model_funcs_loc)
    funcs()
    -- Load in the data
    self:load_data()
    -- Setup and build the model
    self:build()
    -- Add remote computers if necessary
    if self.opt.remote then
        -- FIX: log message typo ("Runnings" -> "Running")
        parallel.print('Running clients remotely')
        -- Open the list of client ip addresses (one per line)
        local fh, err = io.open("../client_list.txt")
        -- FIX: test the handle itself; io.open returns nil, err on failure
        if not fh then print("../client_list.txt not found"); return end
        -- Register each listed machine as a remote worker host
        for line in fh:lines() do
            local addr = self.opt.username .. '@' .. line
            addr = string.gsub(addr, "\n", "") -- remove line breaks
            -- Add the remote server by ip address
            parallel.addremote( {ip=addr, cores=4, lua=self.opt.torch_path, protocol='ssh -ttq -o "StrictHostKeyChecking no" -i ~/.ssh/dist-sgd-sshkey'})
            parallel.print('Adding address ', addr)
        end
        -- FIX: close the file handle (was previously leaked)
        fh:close()
    elseif opt.localhost then
        -- Clients launched through ssh to localhost
        parallel.print('Running clients through localhost')
        parallel.addremote({ip='localhost', cores=4, lua=self.opt.torch_path, protocol='ssh -o "StrictHostKeyChecking no" -i ~/.ssh/dist-sgd-sshkey'})
    end
end
-- Main function that runs the server. Here the child clients are forked off and
-- the code in the 'worker' function is sent to the clients to be run. Once
-- the connection is established, :send() and :recieve() are used to pass
-- parameters between the client and the server
-- Main function that runs the server. Forks self.opt.n_proc child clients,
-- ships the 'worker' function to them, performs the initial parameter
-- handshake, trains the model, and finally tells every worker to terminate.
-- Relies on the globals cmd/arg set when the command line was parsed.
function demo_server:run()
    parallel.print('Forking ', self.opt.n_proc, ' processes')
    parallel.sfork(self.opt.n_proc)
    parallel.print('Forked')
    -- Execute the worker code in each child process
    parallel.children:exec(worker)
    parallel.print('Finished telling workers to execute')
    -- Sync with the children's first yield, then send the global parameters
    parallel.children:join()
    parallel.print('Sending parameters to children')
    parallel.children:send({cmd = cmd, arg = arg, ext = self.opt.extension})
    -- Get the acknowledgement from each child
    -- FIX: 'replies' was an accidental global; declared local
    local replies = parallel.children:receive()
    parallel.print('Replies from children', replies)
    -- Train the model (train is defined in model_functions.lua)
    train(self.model, self.criterion, self.train_data, self.valid_data)
    parallel.print('Finished training the model')
    -- Sync/terminate when all workers are done
    parallel.children:join('break')
    parallel.print('All processes terminated')
end
-- Function loads in the training and validation data into self.train_data and
-- seld.valid_data.
-- Load the training and validation sets into self.train_data and
-- self.valid_data, delegating to the load_data function provided by
-- "End-To-End-Generative-Dialogue/src/model_functions.lua". load_data may
-- augment the options table, so self.opt is refreshed as well.
function demo_server:load_data()
    local train_set, valid_set, updated_opt = load_data(self.opt)
    self.train_data = train_set
    self.valid_data = valid_set
    self.opt = updated_opt
end
-- Function loads in the nn model and criterion into self.model and self.criterion
function demo_server:build()
-- Simply calls the build function defined in "End-To-End-Generative-Dialogue/src/model_functions.lua"
self.model, self.criterion = build()
end
-- Return the server
return demo_server