Skip to content

Commit 23422c7

Browse files
authored
Added result parsing on return. (#415)
1 parent f445296 commit 23422c7

File tree

7 files changed

+110
-6
lines changed

7 files changed

+110
-6
lines changed

taskiq/abc/broker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Optional,
1919
TypeVar,
2020
Union,
21+
get_type_hints,
2122
overload,
2223
)
2324
from uuid import uuid4
@@ -327,12 +328,18 @@ def inner(
327328
inner_task_name = f"{fmodule}:{fname}"
328329
wrapper = wraps(func)
329330

331+
sign = get_type_hints(func)
332+
return_type = None
333+
if "return" in sign:
334+
return_type = sign["return"]
335+
330336
decorated_task = wrapper(
331337
self.decorator_class(
332338
broker=self,
333339
original_func=func,
334340
labels=inner_labels,
335341
task_name=inner_task_name,
342+
return_type=return_type, # type: ignore
336343
),
337344
)
338345

taskiq/brokers/shared_broker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]:
3030
task_name=self.task_name,
3131
broker=broker,
3232
labels=self.labels,
33+
return_type=self.return_type,
3334
)
3435

3536

taskiq/compat.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# flake8: noqa
22
from functools import lru_cache
3-
from typing import Any, Dict, Optional, Type, TypeVar, Union
3+
from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union
44

55
import pydantic
66
from importlib_metadata import version
@@ -12,13 +12,13 @@
1212
IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0")
1313

1414
if IS_PYDANTIC2:
15-
T = TypeVar("T")
15+
T = TypeVar("T", bound=Hashable)
1616

1717
@lru_cache()
18-
def create_type_adapter(annot: T) -> pydantic.TypeAdapter[T]:
18+
def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]:
1919
return pydantic.TypeAdapter(annot)
2020

21-
def parse_obj_as(annot: T, obj: Any) -> T:
21+
def parse_obj_as(annot: Type[T], obj: Any) -> T:
2222
return create_type_adapter(annot).validate_python(obj)
2323

2424
def model_validate(

taskiq/decor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
Callable,
99
Dict,
1010
Generic,
11+
Optional,
12+
Type,
1113
TypeVar,
1214
Union,
1315
overload,
@@ -51,11 +53,13 @@ def __init__(
5153
task_name: str,
5254
original_func: Callable[_FuncParams, _ReturnType],
5355
labels: Dict[str, Any],
56+
return_type: Optional[Type[_ReturnType]] = None,
5457
) -> None:
5558
self.broker = broker
5659
self.task_name = task_name
5760
self.original_func = original_func
5861
self.labels = labels
62+
self.return_type = return_type
5963

6064
# This is a hack to make ProcessPoolExecutor work
6165
# with decorated functions.
@@ -204,6 +208,7 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]:
204208
task_name=self.task_name,
205209
broker=self.broker,
206210
labels=self.labels,
211+
return_type=self.return_type,
207212
)
208213

209214
def __repr__(self) -> str:

taskiq/kicker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Dict,
1010
Generic,
1111
Optional,
12+
Type,
1213
TypeVar,
1314
Union,
1415
overload,
@@ -46,12 +47,14 @@ def __init__(
4647
task_name: str,
4748
broker: "AsyncBroker",
4849
labels: Dict[str, Any],
50+
return_type: Optional[Type[_ReturnType]] = None,
4951
) -> None:
5052
self.task_name = task_name
5153
self.broker = broker
5254
self.labels = labels
5355
self.custom_task_id: Optional[str] = None
5456
self.custom_schedule_id: Optional[str] = None
57+
self.return_type = return_type
5558

5659
def with_labels(
5760
self,
@@ -169,6 +172,7 @@ async def kiq(
169172
return AsyncTaskiqTask(
170173
task_id=message.task_id,
171174
result_backend=self.broker.result_backend,
175+
return_type=self.return_type, # type: ignore # (pyright issue)
172176
)
173177

174178
async def schedule_by_cron(

taskiq/task.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
2+
from logging import getLogger
23
from time import time
3-
from typing import TYPE_CHECKING, Any, Generic, Optional
4+
from typing import TYPE_CHECKING, Any, Generic, Optional, Type
45

56
from typing_extensions import TypeVar
67

8+
from taskiq.compat import parse_obj_as
79
from taskiq.exceptions import (
810
ResultGetError,
911
ResultIsReadyError,
@@ -15,6 +17,8 @@
1517
from taskiq.depends.progress_tracker import TaskProgress
1618
from taskiq.result import TaskiqResult
1719

20+
logger = getLogger("taskiq.task")
21+
1822
_ReturnType = TypeVar("_ReturnType")
1923

2024

@@ -25,9 +29,11 @@ def __init__(
2529
self,
2630
task_id: str,
2731
result_backend: "AsyncResultBackend[_ReturnType]",
32+
return_type: Optional[Type[_ReturnType]] = None,
2833
) -> None:
2934
self.task_id = task_id
3035
self.result_backend = result_backend
36+
self.return_type = return_type
3137

3238
async def is_ready(self) -> bool:
3339
"""
@@ -53,10 +59,19 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType
5359
:return: task's return value.
5460
"""
5561
try:
56-
return await self.result_backend.get_result(
62+
res = await self.result_backend.get_result(
5763
self.task_id,
5864
with_logs=with_logs,
5965
)
66+
if self.return_type is not None:
67+
try:
68+
res.return_value = parse_obj_as(
69+
self.return_type,
70+
res.return_value,
71+
)
72+
except ValueError:
73+
logger.warning("Cannot parse return type into %s", self.return_type)
74+
return res
6075
except Exception as exc:
6176
raise ResultGetError from exc
6277

tests/test_task.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import uuid
2+
from typing import Dict, TypeVar
3+
4+
import pytest
5+
from pydantic import BaseModel
6+
7+
from taskiq import serializers
8+
from taskiq.abc import AsyncResultBackend
9+
from taskiq.abc.serializer import TaskiqSerializer
10+
from taskiq.compat import model_dump, model_validate
11+
from taskiq.result.v1 import TaskiqResult
12+
from taskiq.task import AsyncTaskiqTask
13+
14+
_ReturnType = TypeVar("_ReturnType")
15+
16+
17+
class SerializingBackend(AsyncResultBackend[_ReturnType]):
18+
def __init__(self, serializer: TaskiqSerializer) -> None:
19+
self._serializer = serializer
20+
self._results: Dict[str, bytes] = {}
21+
22+
async def set_result(
23+
self,
24+
task_id: str,
25+
result: TaskiqResult[_ReturnType], # type: ignore
26+
) -> None:
27+
"""Set result with dumping."""
28+
self._results[task_id] = self._serializer.dumpb(model_dump(result))
29+
30+
async def is_result_ready(self, task_id: str) -> bool:
31+
"""Check if result is ready."""
32+
return task_id in self._results
33+
34+
async def get_result(
35+
self,
36+
task_id: str,
37+
with_logs: bool = False,
38+
) -> TaskiqResult[_ReturnType]:
39+
"""Get result with loading."""
40+
data = self._results[task_id]
41+
return model_validate(TaskiqResult, self._serializer.loadb(data))
42+
43+
44+
@pytest.mark.parametrize(
45+
"serializer",
46+
[
47+
serializers.MSGPackSerializer(),
48+
serializers.CBORSerializer(),
49+
serializers.PickleSerializer(),
50+
serializers.JSONSerializer(),
51+
],
52+
)
53+
@pytest.mark.anyio
54+
async def test_res_parsing_success(serializer: TaskiqSerializer) -> None:
55+
class MyResult(BaseModel):
56+
name: str
57+
age: int
58+
59+
res = MyResult(name="test", age=10)
60+
res_back: AsyncResultBackend[MyResult] = SerializingBackend(serializer)
61+
test_id = str(uuid.uuid4())
62+
await res_back.set_result(
63+
test_id,
64+
TaskiqResult(
65+
is_err=False,
66+
return_value=res,
67+
execution_time=0.0,
68+
),
69+
)
70+
sent_task = AsyncTaskiqTask(test_id, res_back, MyResult)
71+
parsed = await sent_task.wait_result()
72+
assert isinstance(parsed.return_value, MyResult)

0 commit comments

Comments
 (0)