Skip to content

Commit 0cf4e30

Browse files
committed
fix: TaskiqAdminMiddleware work with dataclasses
1 parent 02f338a commit 0cf4e30

File tree

7 files changed

+197
-109
lines changed

7 files changed

+197
-109
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ dev = [
6868
"freezegun>=1.5.5",
6969
"tzdata>=2025.2; sys_platform == 'win32'",
7070
"opentelemetry-test-utils (>=0.59b0,<1)",
71+
"polyfactory>=3.1.0",
7172
]
7273

7374
[project.urls]
@@ -172,8 +173,8 @@ lint.ignore = [
172173
"PLR0913", # Too many arguments for function call
173174
"D106", # Missing docstring in public nested class
174175
]
175-
exclude = [".venv/"]
176176
lint.mccabe = { max-complexity = 10 }
177+
exclude = [".venv/"]
177178
line-length = 88
178179

179180
[tool.ruff.lint.per-file-ignores]

taskiq/middlewares/taskiq_admin_middleware.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import aiohttp
88

99
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.compat import model_dump
1011
from taskiq.message import TaskiqMessage
1112
from taskiq.result import TaskiqResult
1213

@@ -115,12 +116,13 @@ async def post_send(self, message: TaskiqMessage) -> None:
115116
116117
:param message: kicked message.
117118
"""
119+
dict_message: dict[str, Any] = model_dump(message)
118120
await self._spawn_request(
119121
f"/api/tasks/{message.task_id}/queued",
120122
{
121-
"args": message.args,
122-
"kwargs": message.kwargs,
123-
"labels": message.labels,
123+
"args": dict_message["args"],
124+
"kwargs": dict_message["kwargs"],
125+
"labels": dict_message["labels"],
124126
"queuedAt": self._now_iso(),
125127
"taskName": message.task_name,
126128
"worker": self.__ta_broker_name,
@@ -137,12 +139,13 @@ async def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
137139
:param message: incoming parsed taskiq message.
138140
:return: modified message.
139141
"""
142+
dict_message: dict[str, Any] = model_dump(message)
140143
await self._spawn_request(
141144
f"/api/tasks/{message.task_id}/started",
142145
{
143-
"args": message.args,
144-
"kwargs": message.kwargs,
145-
"labels": message.labels,
146+
"args": dict_message["args"],
147+
"kwargs": dict_message["kwargs"],
148+
"labels": dict_message["labels"],
146149
"startedAt": self._now_iso(),
147150
"taskName": message.task_name,
148151
"worker": self.__ta_broker_name,
@@ -164,12 +167,13 @@ async def post_execute(
164167
:param message: incoming message.
165168
:param result: result of execution for current task.
166169
"""
170+
dict_result: dict[str, Any] = model_dump(result)
167171
await self._spawn_request(
168172
f"/api/tasks/{message.task_id}/executed",
169173
{
170174
"finishedAt": self._now_iso(),
171175
"executionTime": result.execution_time,
172176
"error": None if result.error is None else repr(result.error),
173-
"returnValue": {"return_value": result.return_value},
177+
"returnValue": {"return_value": dict_result["return_value"]},
174178
},
175179
)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
from aiohttp import web
3+
from aiohttp.test_utils import TestServer
4+
from typing_extensions import AsyncGenerator
5+
6+
from taskiq.brokers.inmemory_broker import InMemoryBroker
7+
from taskiq.brokers.shared_broker import async_shared_broker
8+
from taskiq.middlewares import TaskiqAdminMiddleware
9+
from tests.middlewares.admin_middleware.dto import DataclassDTO, PydanticDTO, TypedDictDTO
10+
11+
12+
@pytest.fixture(scope="session")
13+
async def admin_api_server() -> AsyncGenerator[TestServer, None]:
14+
"""Создает тестовый HTTP сервер, который всегда отвечает 200."""
15+
16+
async def handle_queued(request: web.Request) -> web.Response:
17+
return web.json_response({"status": "ok"}, status=200)
18+
19+
async def handle_started(request: web.Request) -> web.Response:
20+
return web.json_response({"status": "ok"}, status=200)
21+
22+
async def handle_executed(request: web.Request) -> web.Response:
23+
return web.json_response({"status": "ok"}, status=200)
24+
25+
# Создаем приложение
26+
app = web.Application()
27+
app.router.add_post("/api/tasks/{task_id}/queued", handle_queued)
28+
app.router.add_post("/api/tasks/{task_id}/started", handle_started)
29+
app.router.add_post("/api/tasks/{task_id}/executed", handle_executed)
30+
31+
# Создаем и запускаем тестовый сервер
32+
server = TestServer(app)
33+
await server.start_server()
34+
35+
yield server
36+
37+
# Останавливаем сервер после теста
38+
await server.close()
39+
40+
41+
@pytest.fixture
42+
async def broker(admin_api_server: TestServer) -> AsyncGenerator[InMemoryBroker, None]:
43+
broker = InMemoryBroker().with_middlewares(
44+
TaskiqAdminMiddleware(
45+
str(admin_api_server.make_url("/")), # URL тестового сервера
46+
"supersecret",
47+
taskiq_broker_name="InMemory",
48+
),
49+
)
50+
51+
broker.register_task(task_with_dataclass, task_name="task_with_dataclass")
52+
broker.register_task(task_with_typed_dict, task_name="task_with_typed_dict")
53+
broker.register_task(task_with_pydantic_model, task_name="task_with_pydantic_model")
54+
async_shared_broker.default_broker(broker)
55+
56+
await broker.startup()
57+
yield broker
58+
await broker.shutdown()
59+
60+
61+
async def task_with_dataclass(dto: DataclassDTO) -> None:
62+
assert dto
63+
64+
65+
async def task_with_typed_dict(dto: TypedDictDTO) -> None:
66+
assert dto
67+
68+
69+
async def task_with_pydantic_model(dto: PydanticDTO) -> None:
70+
assert dto
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from dataclasses import dataclass
2+
from typing import TypedDict
3+
4+
import pydantic
5+
6+
7+
@dataclass(frozen=True, slots=True)
8+
class DataclassNestedDTO:
9+
id: int
10+
name: str
11+
12+
@dataclass(frozen=True, slots=True)
13+
class DataclassDTO:
14+
nested: DataclassNestedDTO
15+
recipients: list[str]
16+
subject: str
17+
attachments: list[str] | None = None
18+
text: str | None = None
19+
html: str | None = None
20+
21+
22+
class PydanticDTO(pydantic.BaseModel):
23+
number: int
24+
text: str
25+
flag: bool
26+
list: list[float]
27+
dictionary: dict[str, str] | None = None
28+
29+
30+
class TypedDictDTO(TypedDict):
31+
id: int
32+
name: str
33+
active: bool
34+
scores: list[int]
35+
metadata: dict[str, str] | None
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import asyncio
2+
3+
import pytest
4+
from polyfactory.factories import BaseFactory, DataclassFactory, TypedDictFactory
5+
from polyfactory.factories.pydantic_factory import ModelFactory
6+
7+
from taskiq.brokers.inmemory_broker import InMemoryBroker
8+
from taskiq.decor import AsyncTaskiqDecoratedTask
9+
from tests.middlewares.admin_middleware.dto import DataclassDTO, PydanticDTO, TypedDictDTO
10+
11+
12+
class DataclassDTOFactory(DataclassFactory[DataclassDTO]):
13+
__model__ = DataclassDTO
14+
15+
16+
class TypedDictDTOFactory(TypedDictFactory[TypedDictDTO]):
17+
__model__ = TypedDictDTO
18+
19+
20+
class PydanticDTOFactory(ModelFactory[PydanticDTO]):
21+
__model__ = PydanticDTO
22+
23+
24+
class TestArgumentsFormattingInAdminMiddleware:
25+
@pytest.mark.parametrize(
26+
"dto_factory, task_name",
27+
[
28+
pytest.param(DataclassDTOFactory, "task_with_dataclass", id="dataclass"),
29+
pytest.param(TypedDictDTOFactory, "task_with_typed_dict", id="typeddict"),
30+
pytest.param(PydanticDTOFactory, "task_with_pydantic_model", id="pydantic"),
31+
],
32+
)
33+
async def test_when_task_dto_passed__then_middleware_succesfully_send_request(
34+
self,
35+
broker: InMemoryBroker,
36+
dto_factory: type[BaseFactory],
37+
task_name: str,
38+
) -> None:
39+
# given
40+
task_arguments = dto_factory.build()
41+
task: AsyncTaskiqDecoratedTask = broker.find_task(task_name)
42+
assert task is not None, f"Task {task_name} should be registered in the broker"
43+
# when
44+
kicked_task = await task.kiq(task_arguments)
45+
await asyncio.sleep(1)
46+
# then
47+
result = await kicked_task.get_result()
48+
assert result.error is None # we just expect no errors during post_send/pre_execute/post_execute

tests/middlewares/test_taskiq_admin_middleware.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

uv.lock

Lines changed: 31 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)