|
1 | 1 | from __future__ import print_function |
2 | 2 |
|
3 | | -from itertools import chain |
4 | | -import multiprocessing |
5 | | -import os |
6 | | -import signal |
7 | 3 | import socket |
8 | | -import sys |
9 | | -import traceback |
10 | 4 | import unittest |
11 | 5 |
|
12 | 6 | import six |
@@ -213,70 +207,19 @@ def test_socket_error(self): |
213 | 207 |
|
214 | 208 | def test_exception_handling(self): |
215 | 209 | """Tests closing socket when custom exception raised""" |
216 | | - queue = multiprocessing.Queue() |
217 | | - process = multiprocessing.Process(target=worker, args=(self.mc, queue)) |
218 | | - process.start() |
219 | | - if queue.get() != 'loop started': |
220 | | - raise ValueError( |
221 | | - 'Expected "loop started" message from the child process' |
222 | | - ) |
| 210 | + class CustomException(Exception): |
| 211 | + pass |
223 | 212 |
|
224 | | - # maximum test duration is 0.5 second |
225 | | - num_iters = 50 |
226 | | - timeout = 0.01 |
227 | | - for i in range(num_iters): |
228 | | - os.kill(process.pid, signal.SIGUSR1) |
| 213 | + self.mc.set('error', 1) |
| 214 | + with patch.object(self.mc, '_recv_value', |
| 215 | + Mock(side_effect=CustomException('custom error'))): |
229 | 216 | try: |
230 | | - exc = WorkerError(*queue.get(timeout=timeout)) |
231 | | - raise exc |
232 | | - except six.moves.queue.Empty: |
| 217 | + self.mc.get('error') |
| 218 | + except CustomException: |
233 | 219 | pass |
234 | | - if not process.is_alive(): |
235 | | - break |
236 | | - |
237 | | - if process.is_alive(): |
238 | | - os.kill(process.pid, signal.SIGTERM) |
239 | | - process.join() |
240 | | - |
241 | | - |
242 | | -class SignalException(Exception): |
243 | | - pass |
244 | | - |
245 | | - |
246 | | -def sighandler(signum, frame): |
247 | | - raise SignalException() |
248 | | - |
249 | | - |
250 | | -class WorkerError(Exception): |
251 | | - def __init__(self, exc, assert_tb, signal_tb=None): |
252 | | - super(WorkerError, self).__init__( |
253 | | - ''.join(chain(assert_tb, signal_tb or [])) |
254 | | - ) |
255 | | - self.cause = exc |
256 | | - |
257 | | - |
258 | | -def worker(mc, queue): |
259 | | - signal.signal(signal.SIGUSR1, sighandler) |
260 | | - |
261 | | - signal_tb = None |
262 | | - for i in range(100000): |
263 | | - if i == 0: |
264 | | - queue.put('loop started') |
265 | | - try: |
266 | | - k = str(i) |
267 | | - mc.set(k, i) |
268 | | - # This loop is just to increase chance to get previous value |
269 | | - # for clarity |
270 | | - for j in range(10): |
271 | | - mc.get(str(i-1)) |
272 | | - res = mc.get(k) |
273 | | - assert res == i, 'Expected {} but was {}'.format(i, res) |
274 | | - except AssertionError as e: |
275 | | - assert_tb = traceback.format_exception(*sys.exc_info()) |
276 | | - queue.put((e, assert_tb, signal_tb)) |
277 | | - break |
278 | | - except SignalException as e: |
279 | | - signal_tb = traceback.format_exception(*sys.exc_info()) |
| 220 | + self.assertIs(self.mc.servers[0].socket, None) |
| 221 | + self.assertEqual(self.mc.set('error', 2), True) |
| 222 | + self.assertEqual(self.mc.get('error'), 2) |
280 | 223 |
|
281 | 224 |
|
282 | 225 | if __name__ == '__main__': |
|
0 commit comments