Commit e4f74e3

Merge pull request #12 from taskiq-python/add-schedule-source
feat: add broker and schedule source
2 parents: 4ab19a8 + 4658e89

File tree: 11 files changed, +914 −21 lines

.github/workflows/test.yaml

Lines changed: 3 additions & 3 deletions
@@ -42,9 +42,9 @@ jobs:
       - name: Set up PostgreSQL
         uses: ikalnytskyi/action-setup-postgres@v8
         with:
-          username: postgres
-          password: postgres
-          database: taskiqpsqlpy
+          username: taskiq_psqlpy
+          password: look_in_vault
+          database: taskiq_psqlpy
         id: postgres
       - name: Set up uv and enable cache
         id: setup-uv

docker-compose.yml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+services:
+  postgres:
+    container_name: taskiq_psqlpy
+    image: postgres:18
+    environment:
+      POSTGRES_DB: taskiq_psqlpy
+      POSTGRES_USER: taskiq_psqlpy
+      POSTGRES_PASSWORD: look_in_vault
+    ports:
+      - "5432:5432"

pyproject.toml

Lines changed: 15 additions & 2 deletions
@@ -42,13 +42,15 @@ dev = [
     {include-group = "lint"},
     {include-group = "test"},
     "pre-commit>=4.5.0",
-    "anyio>=4.12.0",
 ]
 test = [
     "pytest>=9.0.1",
     "pytest-cov>=7.0.0",
     "pytest-env>=1.2.0",
     "pytest-xdist>=3.8.0",
+    "pytest-asyncio>=1.3.0",
+    "polyfactory>=3.1.0",
+    "sqlalchemy-utils>=0.42.1",
 ]
 lint = [
     "black>=25.11.0",
@@ -80,7 +82,7 @@ module-root = ""
 module-name = "taskiq_psqlpy"
 
 [tool.ruff]
-line-length = 88
+line-length = 120
 
 [tool.ruff.lint]
 # List of enabled rulsets.
@@ -147,3 +149,14 @@ allow-magic-value-types = ["int", "str", "float"]
 
 [tool.ruff.lint.flake8-bugbear]
 extend-immutable-calls = ["taskiq_dependencies.Depends", "taskiq.TaskiqDepends"]
+
+[tool.pytest.ini_options]
+pythonpath = [
+    "."
+]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
+markers = [
+    "unit: marks unit tests",
+    "integration: marks tests with real infrastructure env",
+]
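The new [tool.pytest.ini_options] block turns on pytest-asyncio's auto mode and declares the two markers, so `pytest -m unit` can run without any infrastructure while `pytest -m integration` expects the Postgres container from docker-compose.yml. A hypothetical test module (illustrative only, not part of this commit) might use them like this:

import pytest

from taskiq_psqlpy import PSQLPyResultBackend


@pytest.mark.unit
def test_construct_backend() -> None:
    # No database needed: just build the backend with its defaults.
    PSQLPyResultBackend()


@pytest.mark.integration
async def test_backend_startup() -> None:
    # Needs the Postgres service from docker-compose.yml; asyncio_mode = "auto"
    # collects this coroutine without a per-test decorator.
    backend = PSQLPyResultBackend()
    await backend.startup()
    await backend.shutdown()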

taskiq_psqlpy/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
+from taskiq_psqlpy.broker import PSQLPyBroker
 from taskiq_psqlpy.result_backend import PSQLPyResultBackend
+from taskiq_psqlpy.schedule_source import PSQLPyScheduleSource
 
 __all__ = [
+    "PSQLPyBroker",
     "PSQLPyResultBackend",
+    "PSQLPyScheduleSource",
 ]

taskiq_psqlpy/broker.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
+import asyncio
+import logging
+import typing as tp
+from collections.abc import AsyncGenerator
+from dataclasses import dataclass
+from datetime import datetime
+
+import psqlpy
+from psqlpy.exceptions import ConnectionExecuteError
+from psqlpy.extra_types import JSONB
+from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage
+
+from taskiq_psqlpy.queries import (
+    CLAIM_MESSAGE_QUERY,
+    CREATE_MESSAGE_TABLE_QUERY,
+    DELETE_MESSAGE_QUERY,
+    INSERT_MESSAGE_QUERY,
+)
+
+logger = logging.getLogger("taskiq.psqlpy_broker")
+_T = tp.TypeVar("_T")
+
+
+@dataclass
+class MessageRow:
+    """Message in db table."""
+
+    id: int
+    task_id: str
+    task_name: str
+    message: str
+    labels: JSONB
+    status: str
+    created_at: datetime
+
+
+class PSQLPyBroker(AsyncBroker):
+    """Broker that uses PostgreSQL and PSQLPy with LISTEN/NOTIFY."""
+
+    _read_conn: psqlpy.Connection
+    _write_pool: psqlpy.ConnectionPool
+    _listener: psqlpy.Listener
+
+    def __init__(
+        self,
+        dsn: (
+            str | tp.Callable[[], str]
+        ) = "postgresql://taskiq_psqlpy:look_in_vault@localhost:5432/taskiq_psqlpy",
+        result_backend: AsyncResultBackend[_T] | None = None,
+        task_id_generator: tp.Callable[[], str] | None = None,
+        channel_name: str = "taskiq",
+        table_name: str = "taskiq_messages",
+        max_retry_attempts: int = 5,
+        read_kwargs: dict[str, tp.Any] | None = None,
+        write_kwargs: dict[str, tp.Any] | None = None,
+    ) -> None:
+        """
+        Construct a new broker.
+
+        Args:
+            dsn: connection string to PostgreSQL, or callable returning one.
+            result_backend: Custom result backend.
+            task_id_generator: Custom task_id generator.
+            channel_name: Name of the channel to listen on.
+            table_name: Name of the table to store messages.
+            max_retry_attempts: Maximum number of message processing attempts.
+            read_kwargs: Additional arguments for read connection creation.
+            write_kwargs: Additional arguments for write pool creation.
+
+        """
+        super().__init__(
+            result_backend=result_backend,
+            task_id_generator=task_id_generator,
+        )
+        self._dsn: str | tp.Callable[[], str] = dsn
+        self.channel_name: str = channel_name
+        self.table_name: str = table_name
+        self.read_kwargs: dict[str, tp.Any] = read_kwargs or {}
+        self.write_kwargs: dict[str, tp.Any] = write_kwargs or {}
+        self.max_retry_attempts: int = max_retry_attempts
+        self._queue: asyncio.Queue[str] | None = None
+
+    @property
+    def dsn(self) -> str:
+        """
+        Get the DSN string.
+
+        Returns:
+            A string with dsn or None if dsn isn't set yet.
+
+        """
+        if callable(self._dsn):
+            return self._dsn()
+        return self._dsn
+
+    async def startup(self) -> None:
+        """Initialize the broker."""
+        await super().startup()
+        self._read_conn = await psqlpy.connect(
+            dsn=self.dsn,
+            **self.read_kwargs,
+        )
+        self._write_pool = psqlpy.ConnectionPool(
+            dsn=self.dsn,
+            **self.write_kwargs,
+        )
+
+        # create messages table if it doesn't exist
+        async with self._write_pool.acquire() as conn:
+            await conn.execute(CREATE_MESSAGE_TABLE_QUERY.format(self.table_name))
+
+        # listen to notification channel
+        self._listener = self._write_pool.listener()
+        await self._listener.add_callback(self.channel_name, self._notification_handler)
+        await self._listener.startup()
+        self._listener.listen()
+
+        self._queue = asyncio.Queue()
+
+    async def shutdown(self) -> None:
+        """Close all connections on shutdown."""
+        await super().shutdown()
+        if self._read_conn is not None:
+            self._read_conn.close()
+        if self._write_pool is not None:
+            self._write_pool.close()
+        if self._listener is not None:
+            self._listener.abort_listen()
+            await self._listener.shutdown()
+
+    async def _notification_handler(
+        self,
+        connection: psqlpy.Connection,
+        payload: str,
+        channel: str,
+        process_id: int,
+    ) -> None:
+        """
+        Handle NOTIFY messages.
+
+        https://psqlpy-python.github.io/components/listener.html#usage
+        """
+        logger.debug("Received notification on channel %s: %s", channel, payload)
+        if self._queue is not None:
+            self._queue.put_nowait(payload)
+
+    async def kick(self, message: BrokerMessage) -> None:
+        """
+        Send message to the channel.
+
+        Inserts the message into the database and sends a NOTIFY.
+
+        :param message: Message to send.
+        """
+        async with self._write_pool.acquire() as conn:
+            # insert message into db table
+            message_inserted_id = tp.cast(
+                "int",
+                await conn.fetch_val(
+                    INSERT_MESSAGE_QUERY.format(self.table_name),
+                    [
+                        message.task_id,
+                        message.task_name,
+                        message.message.decode(),
+                        JSONB(message.labels),
+                    ],
+                ),
+            )
+
+            delay_value = tp.cast("str | None", message.labels.get("delay"))
+            if delay_value is not None:
+                delay_seconds = int(delay_value)
+                asyncio.create_task(  # noqa: RUF006
+                    self._schedule_notification(message_inserted_id, delay_seconds),
+                )
+            else:
+                # Send NOTIFY with message ID as payload
+                _ = await conn.execute(
+                    f"NOTIFY {self.channel_name}, '{message_inserted_id}'",
+                )
+
+    async def _schedule_notification(self, message_id: int, delay_seconds: int) -> None:
+        """Schedule a notification to be sent after a delay."""
+        await asyncio.sleep(delay_seconds)
+        async with self._write_pool.acquire() as conn:
+            # Send NOTIFY with message ID as payload
+            _ = await conn.execute(f"NOTIFY {self.channel_name}, '{message_id}'")
+
+    async def listen(self) -> AsyncGenerator[AckableMessage, None]:
+        """
+        Listen to the channel.
+
+        Yields messages as they are received.
+
+        :yields: AckableMessage instances.
+        """
+        while True:
+            try:
+                payload = await self._queue.get()  # type: ignore[union-attr]
+                message_id = int(payload)  # payload is the message id
+                try:
+                    async with self._write_pool.acquire() as conn:
+                        claimed_message = await conn.fetch_row(
+                            CLAIM_MESSAGE_QUERY.format(self.table_name),
+                            [message_id],
+                        )
+                except ConnectionExecuteError:  # message was claimed by another worker
+                    continue
+                message_row_result = tp.cast(
+                    "MessageRow",
+                    tp.cast("object", claimed_message.as_class(MessageRow)),
+                )
+                message_data = message_row_result.message.encode()
+
+                async def ack(*, _message_id: int = message_id) -> None:
+                    async with self._write_pool.acquire() as conn:
+                        _ = await conn.execute(
+                            DELETE_MESSAGE_QUERY.format(self.table_name),
+                            [_message_id],
+                        )
+
+                yield AckableMessage(data=message_data, ack=ack)
+            except Exception:
+                logger.exception("Error processing message")
+                continue
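To show how the new broker is meant to be wired up, here is a minimal usage sketch (not part of this commit; it assumes the package is installed, the Postgres service from docker-compose.yml is running, and a worker process is started separately via the taskiq CLI):

import asyncio

from taskiq_psqlpy import PSQLPyBroker, PSQLPyResultBackend

DSN = "postgresql://taskiq_psqlpy:look_in_vault@localhost:5432/taskiq_psqlpy"

# Messages are stored in taskiq_messages and workers are woken via NOTIFY;
# results are persisted by the existing PSQLPyResultBackend.
broker = PSQLPyBroker(
    dsn=DSN,
    result_backend=PSQLPyResultBackend(dsn=DSN),
)


@broker.task
async def add(a: int, b: int) -> int:
    return a + b


async def main() -> None:
    await broker.startup()
    # kick() inserts a row and issues NOTIFY; a listening worker claims it.
    task = await add.kiq(1, 2)
    result = await task.wait_result(timeout=10)
    print(result.return_value)
    await broker.shutdown()


asyncio.run(main())

Delayed delivery goes through the same path: a "delay" label on the message (for example via taskiq's kicker().with_labels(delay="5")) makes kick() defer the NOTIFY through _schedule_notification instead of sending it immediately.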

taskiq_psqlpy/queries.py

Lines changed: 54 additions & 0 deletions
@@ -29,3 +29,57 @@
 DELETE_RESULT_QUERY = """
 DELETE FROM {} WHERE task_id = $1
 """
+
+CREATE_SCHEDULES_TABLE_QUERY = """
+CREATE TABLE IF NOT EXISTS {} (
+    id UUID PRIMARY KEY,
+    task_name VARCHAR(100) NOT NULL,
+    schedule JSONB NOT NULL,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
+    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
+"""
+
+INSERT_SCHEDULE_QUERY = """
+INSERT INTO {} (id, task_name, schedule)
+VALUES ($1, $2, $3)
+ON CONFLICT (id) DO UPDATE
+SET task_name = EXCLUDED.task_name,
+    schedule = EXCLUDED.schedule,
+    updated_at = NOW();
+"""
+
+SELECT_SCHEDULES_QUERY = """
+SELECT id, task_name, schedule
+FROM {};
+"""
+
+DELETE_ALL_SCHEDULES_QUERY = """
+DELETE FROM {};
+"""
+
+DELETE_SCHEDULE_QUERY = """
+DELETE FROM {} WHERE id = $1;
+"""
+
+CREATE_MESSAGE_TABLE_QUERY = """
+CREATE TABLE IF NOT EXISTS {} (
+    id SERIAL PRIMARY KEY,
+    task_id VARCHAR NOT NULL,
+    task_name VARCHAR NOT NULL,
+    message TEXT NOT NULL,
+    labels JSONB NOT NULL,
+    status TEXT NOT NULL DEFAULT 'pending',
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
+"""
+
+INSERT_MESSAGE_QUERY = """
+INSERT INTO {} (task_id, task_name, message, labels)
+VALUES ($1, $2, $3, $4)
+RETURNING id
+"""
+
+CLAIM_MESSAGE_QUERY = "UPDATE {} SET status = 'processing' WHERE id = $1 AND status = 'pending' RETURNING *"
+
+DELETE_MESSAGE_QUERY = "DELETE FROM {} WHERE id = $1"
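Note the two kinds of placeholders used throughout this module: the bare {} is filled on the Python side with str.format(), because PostgreSQL cannot take a table name as a bind parameter, while $1, $2, ... remain driver-level bind parameters. A tiny illustration with the default table name (this just prints the string the broker ultimately executes):

from taskiq_psqlpy.queries import CLAIM_MESSAGE_QUERY

# Table name interpolated in Python; $1 is left for the driver to bind.
print(CLAIM_MESSAGE_QUERY.format("taskiq_messages"))
# -> UPDATE taskiq_messages SET status = 'processing'
#    WHERE id = $1 AND status = 'pending' RETURNING *

Because the UPDATE in CLAIM_MESSAGE_QUERY only matches rows still in status = 'pending', a message announced to several listening workers is claimed by at most one of them; the others get no row back and skip it, which is what the broker's listen() loop relies on.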

taskiq_psqlpy/result_backend.py

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ class PSQLPyResultBackend(AsyncResultBackend[_ReturnType]):
 
     def __init__(
         self,
-        dsn: str | None = "postgres://postgres:postgres@localhost:5432/postgres",
+        dsn: (
+            str | None
+        ) = "postgresql://taskiq_psqlpy:look_in_vault@localhost:5432/taskiq_psqlpy",
         keep_results: bool = True,
         table_name: str = "taskiq_results",
         field_for_task_id: Literal["VarChar", "Text"] = "VarChar",
