Skip to content

Commit ba5c1d1

Browse files
committed
Fix cancellation of trainings
If a user ran the sequence training -> cancel -> training -> cancel, the last cancel did not terminate properly and the API blocked, because the worker process never joined. Since we are using a process pool with only one slot, we can get the running process in the pool and kill it before terminating the pool.
1 parent da76307 commit ba5c1d1

File tree

2 files changed

+39
-27
lines changed

2 files changed

+39
-27
lines changed

deepaas/model/v2/wrapper.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import multiprocessing
2323
import multiprocessing.pool
2424
import os
25+
import signal
2526
import tempfile
2627

2728
from aiohttp import web
@@ -80,6 +81,8 @@ def __init__(self, name, model_obj, app):
8081
self._train_workers = CONF.train_workers
8182
self._train_executor = self._init_train_executor()
8283

84+
self._setup_cleanup()
85+
8386
schema = getattr(self.model_obj, "schema", None)
8487

8588
if isinstance(schema, dict):
@@ -109,6 +112,13 @@ def __init__(self, name, model_obj, app):
109112

110113
self.response_schema = schema
111114

115+
def _setup_cleanup(self):
116+
self._app.on_cleanup.append(self._close_executors)
117+
118+
async def _close_executors(self, app):
119+
self._train_executor.shutdown()
120+
self._predict_executor.shutdown()
121+
112122
def _init_predict_executor(self):
113123
n = self._predict_workers
114124
executor = concurrent.futures.ThreadPoolExecutor(max_workers=n)
@@ -119,22 +129,6 @@ def _init_train_executor(self):
119129
executor = CancellablePool(max_workers=n)
120130
return executor
121131

122-
# run = sconcurrent.futures.elf.loop.run_in_executor
123-
124-
# fs = [run(executor, self.warm, path) for i in range(0, n)]
125-
# await asyncio.gather(*fs)
126-
#
127-
async def close_executor():
128-
self._executor.shutdown()
129-
130-
# async def close_executor():
131-
# fs = [run(executor, self.clean) for i in range(0, n)]
132-
# await asyncio.shield(asyncio.gather(*fs))
133-
# executor.shutdown(wait=True)
134-
135-
self._app.on_cleanup.append(close_executor)
136-
# app['executor'] = executor
137-
138132
@contextlib.contextmanager
139133
def _catch_error(self):
140134
name = self.name
@@ -408,6 +402,13 @@ def _on_err(err):
408402
try:
409403
return await fut
410404
except asyncio.CancelledError:
405+
# This is ugly, but since our pools only have one slot we can
406+
# kill the process before termination
407+
try:
408+
pool._pool[0].kill()
409+
except AttributeError:
410+
os.kill(pool._pool[0].pid,
411+
signal.SIGKILL)
411412
pool.terminate()
412413
usable_pool = self._new_pool()
413414
finally:
@@ -416,6 +417,6 @@ def _on_err(err):
416417
self._change.set()
417418

418419
def shutdown(self):
419-
for p in self._working | self._free:
420+
for p in self._working:
420421
p.terminate()
421422
self._free.clear()

deepaas/tests/test_v2_models.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def get_train_args(self):
6969
self.assertRaises(NotImplementedError, m.train)
7070
self.assertRaises(NotImplementedError, m.get_train_args)
7171

72-
def test_bad_schema(self):
72+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
73+
def test_bad_schema(self, m_clean):
7374
class Model(object):
7475
schema = []
7576

@@ -80,7 +81,8 @@ class Model(object):
8081
self.app
8182
)
8283

83-
def test_validate_no_schema(self):
84+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
85+
def test_validate_no_schema(self, m_clean):
8486
class Model(object):
8587
schema = None
8688

@@ -91,7 +93,8 @@ class Model(object):
9193
None
9294
)
9395

94-
def test_invalid_schema(self):
96+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
97+
def test_invalid_schema(self, m_clean):
9598
class Model(object):
9699
schema = object()
97100

@@ -102,7 +105,8 @@ class Model(object):
102105
self.app
103106
)
104107

105-
def test_marshmallow_schema(self):
108+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
109+
def test_marshmallow_schema(self, m_clean):
106110
class Schema(marshmallow.Schema):
107111
foo = m_fields.Str()
108112

@@ -118,7 +122,8 @@ class Model(object):
118122
{"foo": 1.0}
119123
)
120124

121-
def test_dict_schema(self):
125+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
126+
def test_dict_schema(self, m_clean):
122127
class Model(object):
123128
schema = {
124129
"foo": m_fields.Str()
@@ -151,8 +156,9 @@ def test_dummy_model(self):
151156
for arg, val in itertools.chain(pargs.items(), targs.items()):
152157
self.assertIsInstance(val, fields.Field)
153158

159+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
154160
@test_utils.unittest_run_loop
155-
async def test_dummy_model_with_wrapper(self):
161+
async def test_dummy_model_with_wrapper(self, m_clean):
156162
w = v2_wrapper.ModelWrapper("foo", v2_test.TestModel(), self.app)
157163
task = w.predict()
158164
await task
@@ -175,8 +181,10 @@ async def test_dummy_model_with_wrapper(self):
175181
for arg, val in itertools.chain(pargs.items(), targs.items()):
176182
self.assertIsInstance(val, fields.Field)
177183

184+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
178185
@test_utils.unittest_run_loop
179-
async def test_model_with_not_implemented_attributes_and_wrapper(self):
186+
async def test_model_with_not_implemented_attributes_and_wrapper(self,
187+
m_clean):
180188
w = v2_wrapper.ModelWrapper("foo", object(), self.app)
181189

182190
# NOTE(aloga): Cannot use assertRaises here directly, as testtools
@@ -203,25 +211,28 @@ async def test_model_with_not_implemented_attributes_and_wrapper(self):
203211
for arg, val in itertools.chain(pargs.items(), targs.items()):
204212
self.assertIsInstance(val, fields.Field)
205213

214+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
206215
@mock.patch('deepaas.model.loading.get_available_models')
207-
def test_loading_ok(self, mock_loading):
216+
def test_loading_ok(self, mock_loading, m_clean):
208217
mock_loading.return_value = {uuid.uuid4().hex: "bar"}
209218
deepaas.model.v2.register_models(self.app)
210219
mock_loading.assert_called()
211220
for m in deepaas.model.v2.MODELS.values():
212221
self.assertIsInstance(m, v2_wrapper.ModelWrapper)
213222

223+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
214224
@mock.patch('deepaas.model.loading.get_available_models')
215-
def test_loading_ok_singleton(self, mock_loading):
225+
def test_loading_ok_singleton(self, mock_loading, m_clean):
216226
mock_loading.return_value = {uuid.uuid4().hex: "bar"}
217227
deepaas.model.v2.register_models(self.app)
218228
deepaas.model.v2.register_models(self.app)
219229
mock_loading.assert_called_once()
220230
for m in deepaas.model.v2.MODELS.values():
221231
self.assertIsInstance(m, v2_wrapper.ModelWrapper)
222232

233+
@mock.patch("deepaas.model.v2.wrapper.ModelWrapper._setup_cleanup")
223234
@mock.patch('deepaas.model.loading.get_available_models')
224-
def test_loading_error(self, mock_loading):
235+
def test_loading_error(self, mock_loading, m_clean):
225236
mock_loading.return_value = {}
226237
deepaas.model.v2.register_models(self.app)
227238
mock_loading.assert_called()

0 commit comments

Comments
 (0)