Skip to content

Commit fcfb982

Browse files
authored
Merge pull request #185 from labthings/handle-action-exceptions
Handle action exceptions If an HTTPException is thrown by an action's thread, before the request has been responded to, propagate the exception in the request handler so that it aborts the response with the right error code. This means marshmallow validation on action arguments now works.
2 parents ea56731 + 1c97565 commit fcfb982

File tree

9 files changed

+223
-35
lines changed

9 files changed

+223
-35
lines changed

src/labthings/actions/pool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,21 @@ def start(self, thread: ActionThread):
2929
self.add(thread)
3030
thread.start()
3131

32-
def spawn(self, action: str, function, *args, **kwargs):
32+
def spawn(self, action: str, function, *args, http_error_lock=None, **kwargs):
3333
"""
3434
3535
:param function:
3636
:param *args:
3737
:param **kwargs:
3838
3939
"""
40-
thread = ActionThread(action, target=function, args=args, kwargs=kwargs)
40+
thread = ActionThread(
41+
action,
42+
target=function,
43+
http_error_lock=http_error_lock,
44+
args=args,
45+
kwargs=kwargs,
46+
)
4147
self.start(thread)
4248
return thread
4349

src/labthings/actions/thread.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, Callable, Dict, Iterable, Optional
88

99
from flask import copy_current_request_context, has_request_context, request
10-
from werkzeug.exceptions import BadRequest
10+
from werkzeug.exceptions import BadRequest, HTTPException
1111

1212
from ..deque import LockableDeque
1313
from ..utilities import TimeoutTracker
@@ -22,6 +22,42 @@ class ActionKilledException(SystemExit):
2222
class ActionThread(threading.Thread):
2323
"""
2424
A native thread with extra functionality for tracking progress and thread termination.
25+
26+
Arguments:
27+
* `action` is the name of the action that's running
28+
* `target`, `name`, `args`, `kwargs` and `daemon` are passed to `threading.Thread`
29+
(though the defualt for `daemon` is changed to `True`)
30+
* `default_stop_timeout` specifies how long we wait for the `target` function to
31+
stop nicely (e.g. by checking the `stopping` Event )
32+
* `log_len` gives the number of log entries before we start dumping them
33+
* `http_error_lock` allows the calling thread to handle some
34+
errors initially. See below.
35+
36+
## Error propagation
37+
If the `target` function throws an Exception, by default this will result in:
38+
* The thread terminating
39+
* The Action's status being set to `error`
40+
* The exception appearing in the logs with a traceback
41+
* The exception being raised in the background thread.
42+
However, `HTTPException` subclasses are used in Flask/Werkzeug web apps to
43+
return HTTP status codes indicating specific errors, and so merit being
44+
handled differently.
45+
46+
Normally, when an Action is initiated, the thread handling the HTTP request
47+
does not return immediately - it waits for a short period to check whether
48+
the Action has completed or returned an error. If an HTTPError is raised
49+
in the Action thread before the initiating thread has sent an HTTP response,
50+
we **don't** want to propagate the error here, but instead want to re-raise
51+
it in the calling thread. This will then mean that the HTTP request is
52+
answered with the appropriate error code, rather than returning a `201`
53+
code, along with a description of the task (showing that it was successfully
54+
started, but also showing that it subsequently failed with an error).
55+
56+
In order to activate this behaviour, we must pass in a `threading.Lock`
57+
object. This lock should already be acquired by the request-handling
58+
thread. If an error occurs, and this lock is acquired, the exception
59+
should not be re-raised until the calling thread has had the chance to deal
60+
with it.
2561
"""
2662

2763
def __init__(
@@ -34,6 +70,7 @@ def __init__(
3470
daemon: bool = True,
3571
default_stop_timeout: int = 5,
3672
log_len: int = 100,
73+
http_error_lock: Optional[threading.Lock] = None,
3774
):
3875
threading.Thread.__init__(
3976
self,
@@ -56,6 +93,8 @@ def __init__(
5693
# Event to track if the user has requested stop
5794
self.stopping: threading.Event = threading.Event()
5895
self.default_stop_timeout: int = default_stop_timeout
96+
# Allow the calling thread to handle HTTP errors for a short time at the start
97+
self.http_error_lock = http_error_lock or threading.Lock()
5998

6099
# Make _target, _args, and _kwargs available to the subclass
61100
self._target: Optional[Callable] = target
@@ -85,6 +124,7 @@ def __init__(
85124
self._request_time: datetime.datetime = datetime.datetime.now()
86125
self._start_time: Optional[datetime.datetime] = None # Task start time
87126
self._end_time: Optional[datetime.datetime] = None # Task end time
127+
self._exception: Optional[Exception] = None # Propagate exceptions helpfully
88128

89129
# Public state properties
90130
self.progress: Optional[int] = None # Percent progress of the task
@@ -151,6 +191,11 @@ def cancelled(self) -> bool:
151191
"""Alias of `stopped`"""
152192
return self.stopped
153193

194+
@property
195+
def exception(self) -> Optional[Exception]:
196+
"""The Exception that caused the action to fail."""
197+
return self._exception
198+
154199
def update_progress(self, progress: int):
155200
"""
156201
Update the progress of the ActionThread.
@@ -214,15 +259,29 @@ def wrapped(*args, **kwargs):
214259
# Set state to stopped
215260
self._status = "cancelled"
216261
self.progress = None
262+
except HTTPException as e:
263+
self._exception = e
264+
# If the lock is acquired elsewhere, assume the error
265+
# will be handled there.
266+
if self.http_error_lock.acquire(blocking=False):
267+
self.http_error_lock.release()
268+
logging.error(
269+
"An HTTPException occurred in an action thread, but "
270+
"the parent request was no longer waiting for it."
271+
)
272+
logging.error(traceback.format_exc())
273+
raise e
217274
except Exception as e: # skipcq: PYL-W0703
275+
self._exception = e
218276
logging.error(traceback.format_exc())
219-
self._return_value = str(e)
220-
self._status = "error"
221277
raise e
222278
finally:
223279
self._end_time = datetime.datetime.now()
224280
logging.getLogger().removeHandler(handler) # Stop logging this thread
225281
# If we don't remove the handler, it's a memory leak.
282+
if self._exception:
283+
self._return_value = str(self._exception)
284+
self._status = "error"
226285

227286
return wrapped
228287

src/labthings/schema.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ def preprocess(self, data, **_):
8282

8383

8484
class ActionSchema(Schema):
85-
""" """
85+
"""Represents a running or completed Action
86+
87+
Actions can run in the background, started by one request
88+
and subsequently polled for updates. This schema represents
89+
one Action."""
8690

8791
action = fields.String()
8892
_ID = fields.String(data_key="id")

src/labthings/views/__init__.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import datetime
2+
import threading
23
from collections import OrderedDict
34
from typing import Callable, Dict, List, Optional, Set, cast
45

56
from flask import request
67
from flask.views import MethodView
78
from typing_extensions import Protocol
9+
from werkzeug.exceptions import HTTPException
810
from werkzeug.wrappers import Response as ResponseBase
911

1012
from ..actions.pool import Pool
@@ -215,25 +217,38 @@ def dispatch_request(self, *args, **kwargs):
215217
pool = (
216218
current_labthing().actions if current_labthing() else self._emergency_pool
217219
)
218-
# Make a task out of the views `post` method
219-
task = pool.spawn(self.endpoint, meth, *args, **kwargs)
220-
# Optionally override the threads default_stop_timeout
221-
if self.default_stop_timeout is not None:
222-
task.default_stop_timeout = self.default_stop_timeout
223-
224-
# Wait up to 2 second for the action to complete or error
225-
try:
226-
task.get(block=True, timeout=self.wait_for)
227-
except TimeoutError:
228-
pass
229-
230-
# Log the action to the view's deque
231-
self._deque.append(task)
220+
# We pass in this lock to tell the Action thread that we'll deal
221+
# with HTTP errors in this thread
222+
error_lock = threading.RLock()
223+
with error_lock:
224+
# Make a task out of the views `post` method
225+
task = pool.spawn(
226+
self.endpoint, meth, *args, http_error_lock=error_lock, **kwargs
227+
)
228+
# Optionally override the threads default_stop_timeout
229+
if self.default_stop_timeout is not None:
230+
task.default_stop_timeout = self.default_stop_timeout
231+
232+
# Log the action to the view's deque
233+
self._deque.append(task)
234+
235+
# Wait up to 2 second for the action to complete or error
236+
try:
237+
task.get(block=True, timeout=self.wait_for)
238+
except TimeoutError:
239+
pass
232240

233241
# If the action returns quickly, and returns a valid Response, return it as-is
234242
if task.output and isinstance(task.output, ResponseBase):
235243
return self.represent_response((task.output, 200))
236244

245+
# If the action fails quickly with an HTTPException, propagate it.
246+
# This allows us to handle validation errors nicely.
247+
# Similarly, calling Flask's `abort(404)` will work during the
248+
# timeout period, as it uses the same mechanism.
249+
if task.exception and isinstance(task.exception, HTTPException):
250+
raise task.exception
251+
237252
return self.represent_response((ActionSchema().dump(task), 201))
238253

239254

tests/conftest.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import os
3+
import time
34

45
import jsonschema
56
import pytest
67
from apispec import APISpec
78
from apispec.ext.marshmallow import MarshmallowPlugin
8-
from flask import Flask
9+
from flask import Flask, abort
910
from flask.testing import FlaskClient
1011
from flask.views import MethodView
1112
from marshmallow import validate
@@ -188,7 +189,7 @@ class TestAction(ActionView):
188189
def post(self):
189190
return "POST"
190191

191-
thing.add_view(TestAction, "TestAction")
192+
thing.add_view(TestAction, "/TestAction")
192193

193194
class TestProperty(PropertyView):
194195
schema = {"count": fields.Integer()}
@@ -199,7 +200,7 @@ def get(self):
199200
def post(self, args):
200201
pass
201202

202-
thing.add_view(TestProperty, "TestProperty")
203+
thing.add_view(TestProperty, "/TestProperty")
203204

204205
class TestFieldProperty(PropertyView):
205206
schema = fields.String(validate=validate.OneOf(["one", "two"]))
@@ -210,7 +211,35 @@ def get(self):
210211
def post(self, args):
211212
pass
212213

213-
thing.add_view(TestFieldProperty, "TestFieldProperty")
214+
thing.add_view(TestFieldProperty, "/TestFieldProperty")
215+
216+
class FailAction(ActionView):
217+
wait_for = 0.1
218+
219+
def post(self):
220+
raise Exception("This action is meant to fail with an Exception")
221+
222+
thing.add_view(FailAction, "/FailAction")
223+
224+
class AbortAction(ActionView):
225+
wait_for = 0.1
226+
args = {"abort_after": fields.Number()}
227+
228+
def post(self, args):
229+
if args.get("abort_after", 0) > 0:
230+
time.sleep(args["abort_after"])
231+
abort(418, "I'm a teapot! This action should abort with an HTTP code 418")
232+
233+
thing.add_view(AbortAction, "/AbortAction")
234+
235+
class ActionWithValidation(ActionView):
236+
wait_for = 0.1
237+
args = {"test_arg": fields.String(validate=validate.OneOf(["one", "two"]))}
238+
239+
def post(self, args):
240+
return True
241+
242+
thing.add_view(ActionWithValidation, "/ActionWithValidation")
214243

215244
return thing
216245

tests/test_action_api.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
import logging
3+
import time
4+
5+
import pytest
6+
7+
from labthings import LabThing
8+
from labthings.views import ActionView
9+
10+
11+
@pytest.mark.filterwarnings("ignore:Exception in thread")
12+
def test_action_exception_handling(thing_with_some_views, client):
13+
"""Check errors in an Action are handled correctly
14+
15+
16+
17+
`/FieldProperty` has a validation constraint - it
18+
should return a "bad response" error if invoked with
19+
anything other than
20+
"""
21+
# `/FailAction` raises an `Exception`.
22+
# This ought to return a 201 code representing the
23+
# action that was successfully started - but should
24+
# show that it failed through the "status" field.
25+
26+
# This is correct for the current (24/7/2021) behaviour
27+
# but may want to change for the next version, e.g.
28+
# returning a 500 code. For further discussion...
29+
r = client.post("/FailAction")
30+
assert r.status_code == 201
31+
action = r.get_json()
32+
assert action["status"] == "error"
33+
34+
35+
def test_action_abort(thing_with_some_views, client):
36+
"""Check HTTPExceptions result in error codes.
37+
38+
Subclasses of HTTPError should result in a non-200 return code, not
39+
just failures. This covers Marshmallow validation (400) and
40+
use of `abort()`.
41+
"""
42+
# `/AbortAction` should return a 418 error code
43+
r = client.post("/AbortAction")
44+
assert r.status_code == 418
45+
46+
47+
@pytest.mark.filterwarnings("ignore:Exception in thread")
48+
def test_action_abort_late(thing_with_some_views, client, caplog):
49+
"""Check HTTPExceptions raised late are just regular errors."""
50+
caplog.set_level(logging.ERROR)
51+
caplog.clear()
52+
r = client.post("/AbortAction", data=json.dumps({"abort_after": 0.2}))
53+
assert r.status_code == 201 # Should have started OK
54+
time.sleep(0.3)
55+
# Now check the status - should be error
56+
r2 = client.get(r.get_json()["links"]["self"]["href"])
57+
assert r2.get_json()["status"] == "error"
58+
# Check it was logged as well
59+
error_was_raised = False
60+
for r in caplog.records:
61+
if r.levelname == "ERROR" and "HTTPException" in r.message:
62+
error_was_raised = True
63+
assert error_was_raised
64+
65+
66+
def test_action_validate(thing_with_some_views, client):
67+
"""Validation errors should result in 422 return codes."""
68+
# `/ActionWithValidation` should fail with a 400 error
69+
# if `test_arg` is not either `one` or `two`
70+
r = client.post("/ActionWithValidation", data=json.dumps({"test_arg": "one"}))
71+
assert r.status_code in [200, 201]
72+
assert r.get_json()["status"] == "completed"
73+
r = client.post("/ActionWithValidation", data=json.dumps({"test_arg": "three"}))
74+
assert r.status_code in [422]

tests/test_labthing_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_blank_exception(app):
6565
e = Exception()
6666
e.message = None
6767

68-
# Test a 404 HTTPException
68+
# Test an empty Exception
6969
response = error_handler.std_handler(e)
7070

7171
response_json = json.dumps(response[0])

0 commit comments

Comments
 (0)