Skip to content

Commit 06030dc

Browse files
committed
Refactor: use async sessions factory instead of sync
1 parent ffed9d1 commit 06030dc

File tree

5 files changed

+73
-63
lines changed

5 files changed

+73
-63
lines changed

src/aleph/api_entrypoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import aleph.config
99
from aleph.chains.signature_verifier import SignatureVerifier
10-
from aleph.db.connection import make_engine, make_session_factory
10+
from aleph.db.connection import make_async_engine, make_async_session_factory
1111
from aleph.services.cache.node_cache import NodeCache
1212
from aleph.services.ipfs import IpfsService
1313
from aleph.services.p2p import init_p2p_client
@@ -34,12 +34,12 @@ async def configure_aiohttp_app(
3434
with sentry_sdk.start_transaction(name="init-api-server"):
3535
p2p_client = await init_p2p_client(config, service_name="api-server-aiohttp")
3636

37-
engine = make_engine(
37+
engine = make_async_engine(
3838
config,
3939
echo=config.logging.level.value == logging.DEBUG,
4040
application_name="aleph-api",
4141
)
42-
session_factory = make_session_factory(engine)
42+
session_factory = make_async_session_factory(engine)
4343

4444
node_cache = NodeCache(
4545
redis_host=config.redis.host.value, redis_port=config.redis.port.value

src/aleph/chains/nuls2.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from aleph.schemas.chains.tx_context import TxContext
3030
from aleph.schemas.pending_messages import BasePendingMessage
3131
from aleph.toolkit.timestamp import utc_now
32-
from aleph.types.db_session import DbSessionFactory
32+
from aleph.types.db_session import AsyncDbSessionFactory
3333
from aleph.utils import run_in_executor
3434

3535
from ..db.models import ChainTxDb
@@ -80,7 +80,7 @@ async def verify_signature(self, message: BasePendingMessage) -> bool:
8080
class Nuls2Connector(ChainWriter):
8181
def __init__(
8282
self,
83-
session_factory: DbSessionFactory,
83+
session_factory: AsyncDbSessionFactory,
8484
pending_tx_publisher: PendingTxPublisher,
8585
chain_data_service: ChainDataService,
8686
):
@@ -90,8 +90,8 @@ def __init__(
9090

9191
async def get_last_height(self, sync_type: ChainEventType) -> int:
9292
"""Returns the last height for which we already have the nuls data."""
93-
with self.session_factory() as session:
94-
last_height = get_last_height(
93+
async with self.session_factory() as session:
94+
last_height = await get_last_height(
9595
session=session, chain=Chain.NULS2, sync_type=sync_type
9696
)
9797

@@ -133,15 +133,15 @@ async def _request_transactions(
133133
LOGGER.info("Incoming logic data is not JSON, ignoring. %r" % ldata)
134134

135135
if last_height:
136-
with self.session_factory() as session:
137-
upsert_chain_sync_status(
136+
async with self.session_factory() as session:
137+
await upsert_chain_sync_status(
138138
session=session,
139139
chain=Chain.NULS2,
140140
sync_type=ChainEventType.SYNC,
141141
height=last_height,
142142
update_datetime=utc_now(),
143143
)
144-
session.commit()
144+
await session.commit()
145145

146146
async def fetcher(self, config: Config):
147147
last_stored_height = await self.get_last_height(sync_type=ChainEventType.SYNC)
@@ -158,11 +158,11 @@ async def fetcher(self, config: Config):
158158
tx = ChainTxDb.from_sync_tx_context(
159159
tx_context=context, tx_data=jdata
160160
)
161-
with self.session_factory() as db_session:
161+
async with self.session_factory() as db_session:
162162
await self.pending_tx_publisher.add_and_publish_pending_tx(
163163
session=db_session, tx=tx
164164
)
165-
db_session.commit()
165+
await db_session.commit()
166166

167167
await asyncio.sleep(10)
168168

@@ -182,9 +182,9 @@ async def packer(self, config: Config):
182182
nonce = await get_nonce(server, address, chain_id)
183183

184184
while True:
185-
with self.session_factory() as session:
186-
if (count_pending_txs(session=session, chain=Chain.NULS2)) or (
187-
count_pending_messages(session=session, chain=Chain.NULS2)
185+
async with self.session_factory() as session:
186+
if (await count_pending_txs(session=session, chain=Chain.NULS2)) or (
187+
await count_pending_messages(session=session, chain=Chain.NULS2)
188188
):
189189
await asyncio.sleep(30)
190190
continue
@@ -195,7 +195,7 @@ async def packer(self, config: Config):
195195
i = 0
196196

197197
messages = list(
198-
get_unconfirmed_messages(
198+
await get_unconfirmed_messages(
199199
session=session, limit=10000, chain=Chain.ETH
200200
)
201201
)
@@ -208,7 +208,7 @@ async def packer(self, config: Config):
208208
)
209209
)
210210
# Required to apply update to the files table in get_chaindata
211-
session.commit()
211+
await session.commit()
212212

213213
content = sync_event_payload.json()
214214
tx = await prepare_transfer_tx(

src/aleph/commands.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from aleph.chains.chain_data_service import ChainDataService, PendingTxPublisher
2727
from aleph.chains.connector import ChainConnector
2828
from aleph.cli.args import parse_args
29-
from aleph.db.connection import make_db_url, make_engine, make_session_factory
29+
from aleph.db.connection import (
30+
make_async_engine,
31+
make_async_session_factory,
32+
make_db_url,
33+
)
3034
from aleph.exceptions import InvalidConfigException, KeyNotFoundException
3135
from aleph.jobs import start_jobs
3236
from aleph.jobs.cron.balance_job import BalanceCronJob
@@ -124,12 +128,12 @@ async def main(args: List[str]) -> None:
124128
run_db_migrations(config)
125129
LOGGER.info("Database initialized.")
126130

127-
engine = make_engine(
131+
engine = make_async_engine(
128132
config,
129133
echo=args.loglevel == logging.DEBUG,
130134
application_name="aleph-conn-manager",
131135
)
132-
session_factory = make_session_factory(engine)
136+
session_factory = make_async_session_factory(engine)
133137

134138
setup_logging(args.loglevel)
135139

src/aleph/handlers/message_handler.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aleph_message.models import ItemHash, ItemType, MessageType
99
from configmanager import Config
1010
from pydantic import ValidationError
11+
from sqlalchemy import select
1112
from sqlalchemy.dialects.postgresql import insert
1213

1314
from aleph.chains.signature_verifier import SignatureVerifier
@@ -39,7 +40,7 @@
3940
from aleph.schemas.pending_messages import parse_message
4041
from aleph.storage import StorageService
4142
from aleph.toolkit.timestamp import timestamp_to_datetime
42-
from aleph.types.db_session import DbSession, DbSessionFactory
43+
from aleph.types.db_session import AsyncDbSession, AsyncDbSessionFactory
4344
from aleph.types.files import FileType
4445
from aleph.types.message_processing_result import ProcessedMessage, RejectedMessage
4546
from aleph.types.message_status import (
@@ -122,7 +123,7 @@ async def fetch_pending_message(
122123

123124
return validated_message
124125

125-
async def fetch_related_content(self, session: DbSession, message: MessageDb):
126+
async def fetch_related_content(self, session: AsyncDbSession, message: MessageDb):
126127
content_handler = self.get_content_handler(message.type)
127128

128129
try:
@@ -135,7 +136,7 @@ async def fetch_related_content(self, session: DbSession, message: MessageDb):
135136
) from e
136137

137138
async def load_fetched_content(
138-
self, session: DbSession, pending_message: PendingMessageDb
139+
self, session: AsyncDbSession, pending_message: PendingMessageDb
139140
) -> PendingMessageDb:
140141
if pending_message.item_type != ItemType.inline:
141142
pending_message.fetched = False
@@ -161,7 +162,7 @@ class MessagePublisher(BaseMessageHandler):
161162

162163
def __init__(
163164
self,
164-
session_factory: DbSessionFactory,
165+
session_factory: AsyncDbSessionFactory,
165166
storage_service: StorageService,
166167
config: Config,
167168
pending_message_exchange: aio_pika.abc.AbstractExchange,
@@ -191,21 +192,22 @@ async def add_pending_message(
191192
origin: Optional[MessageOrigin] = MessageOrigin.P2P,
192193
) -> Optional[PendingMessageDb]:
193194
# TODO: this implementation is just messy, improve it.
194-
with self.session_factory() as session:
195+
async with self.session_factory() as session:
195196
try:
196197
# we don't check signatures yet.
197198
message = parse_message(message_dict)
198199
except InvalidMessageException as e:
199200
LOGGER.warning(e)
200-
reject_new_pending_message(
201+
await reject_new_pending_message(
201202
session=session,
202203
pending_message=message_dict,
203204
exception=e,
204205
tx_hash=tx_hash,
205206
)
206-
session.commit()
207+
await session.commit()
207208
return None
208209

210+
# TODO: fix this pydanticV2 issue
209211
pending_message = PendingMessageDb.from_obj(
210212
message,
211213
reception_time=reception_time,
@@ -220,25 +222,24 @@ async def add_pending_message(
220222
)
221223
except InvalidMessageException as e:
222224
LOGGER.warning("Invalid message: %s - %s", message.item_hash, str(e))
223-
reject_new_pending_message(
225+
await reject_new_pending_message(
224226
session=session,
225227
pending_message=message_dict,
226228
exception=e,
227229
tx_hash=tx_hash,
228230
)
229-
session.commit()
231+
await session.commit()
230232
return None
231233

232234
# Check if there are an already existing record
233-
existing_message = (
234-
session.query(PendingMessageDb)
235-
.filter_by(
236-
sender=pending_message.sender,
237-
item_hash=pending_message.item_hash,
238-
signature=pending_message.signature,
239-
)
240-
.one_or_none()
235+
stmt = select(PendingMessageDb).filter_by(
236+
sender=pending_message.sender,
237+
item_hash=pending_message.item_hash,
238+
signature=pending_message.signature,
241239
)
240+
241+
existing_message = (await session.execute(stmt)).scalars().one_or_none()
242+
242243
if existing_message:
243244
return existing_message
244245

@@ -255,9 +256,9 @@ async def add_pending_message(
255256
)
256257

257258
try:
258-
session.execute(upsert_message_status_stmt)
259-
session.execute(insert_pending_message_stmt)
260-
session.commit()
259+
await session.execute(upsert_message_status_stmt)
260+
await session.execute(insert_pending_message_stmt)
261+
await session.commit()
261262
except sqlalchemy.exc.IntegrityError:
262263
# Handle the unique constraint violation.
263264
LOGGER.warning("Duplicate pending message detected trying to save it.")
@@ -269,14 +270,14 @@ async def add_pending_message(
269270
pending_message.item_hash,
270271
str(e),
271272
)
272-
session.rollback()
273-
reject_new_pending_message(
273+
await session.rollback()
274+
await reject_new_pending_message(
274275
session=session,
275276
pending_message=message_dict,
276277
exception=e,
277278
tx_hash=tx_hash,
278279
)
279-
session.commit()
280+
await session.commit()
280281
return None
281282

282283
await self._publish_pending_message(pending_message)
@@ -307,53 +308,56 @@ async def verify_signature(self, pending_message: PendingMessageDb):
307308

308309
@staticmethod
309310
async def confirm_existing_message(
310-
session: DbSession,
311+
session: AsyncDbSession,
311312
existing_message: MessageDb,
312313
pending_message: PendingMessageDb,
313314
):
314315
if pending_message.signature != existing_message.signature:
315316
raise InvalidSignature(f"Invalid signature for {pending_message.item_hash}")
316317

317-
delete_pending_message(session=session, pending_message=pending_message)
318+
await delete_pending_message(session=session, pending_message=pending_message)
318319
if tx_hash := pending_message.tx_hash:
319-
session.execute(
320+
await session.execute(
320321
make_confirmation_upsert_query(
321322
item_hash=pending_message.item_hash, tx_hash=tx_hash
322323
)
323324
)
324325

325326
@staticmethod
326327
async def confirm_existing_forgotten_message(
327-
session: DbSession,
328+
session: AsyncDbSession,
328329
forgotten_message: ForgottenMessageDb,
329330
pending_message: PendingMessageDb,
330331
):
331332
if pending_message.signature != forgotten_message.signature:
332333
raise InvalidSignature(f"Invalid signature for {pending_message.item_hash}")
333334

334-
delete_pending_message(session=session, pending_message=pending_message)
335+
await delete_pending_message(session=session, pending_message=pending_message)
335336

336337
async def insert_message(
337-
self, session: DbSession, pending_message: PendingMessageDb, message: MessageDb
338+
self,
339+
session: AsyncDbSession,
340+
pending_message: PendingMessageDb,
341+
message: MessageDb,
338342
):
339-
session.execute(make_message_upsert_query(message))
343+
await session.execute(make_message_upsert_query(message))
340344
if message.item_type != ItemType.inline:
341-
upsert_file(
345+
await upsert_file(
342346
session=session,
343347
file_hash=message.item_hash,
344348
size=message.size,
345349
file_type=FileType.FILE,
346350
)
347-
insert_content_file_pin(
351+
await insert_content_file_pin(
348352
session=session,
349353
file_hash=message.item_hash,
350354
owner=message.sender,
351355
item_hash=message.item_hash,
352356
created=timestamp_to_datetime(message.content["time"]),
353357
)
354358

355-
delete_pending_message(session=session, pending_message=pending_message)
356-
session.execute(
359+
await delete_pending_message(session=session, pending_message=pending_message)
360+
await session.execute(
357361
make_message_status_upsert_query(
358362
item_hash=message.item_hash,
359363
new_status=MessageStatus.PROCESSED,
@@ -363,18 +367,18 @@ async def insert_message(
363367
)
364368

365369
if tx_hash := pending_message.tx_hash:
366-
session.execute(
370+
await session.execute(
367371
make_confirmation_upsert_query(
368372
item_hash=message.item_hash, tx_hash=tx_hash
369373
)
370374
)
371375

372376
async def insert_costs(
373-
self, session: DbSession, costs: List[AccountCostsDb], message: MessageDb
377+
self, session: AsyncDbSession, costs: List[AccountCostsDb], message: MessageDb
374378
):
375379
if len(costs) > 0:
376380
insert_stmt = make_costs_upsert_query(costs)
377-
session.execute(insert_stmt)
381+
await session.execute(insert_stmt)
378382

379383
async def verify_message(self, pending_message: PendingMessageDb) -> MessageDb:
380384
await self.verify_signature(pending_message=pending_message)
@@ -385,7 +389,7 @@ async def verify_message(self, pending_message: PendingMessageDb) -> MessageDb:
385389
return validated_message
386390

387391
async def process(
388-
self, session: DbSession, pending_message: PendingMessageDb
392+
self, session: AsyncDbSession, pending_message: PendingMessageDb
389393
) -> ProcessedMessage | RejectedMessage:
390394
"""
391395
Process a pending message.
@@ -401,7 +405,7 @@ async def process(
401405
"""
402406

403407
# Note: Check if message already exists (and confirm it)
404-
existing_message = get_message_by_item_hash(
408+
existing_message = await get_message_by_item_hash(
405409
session=session, item_hash=ItemHash(pending_message.item_hash)
406410
)
407411
if existing_message:
@@ -414,7 +418,7 @@ async def process(
414418

415419
# Note: Check if message is already forgotten (and confirm it)
416420
# this is to avoid race conditions when a confirmation arrives after the FORGET message has been preocessed
417-
forgotten_message = get_forgotten_message(
421+
forgotten_message = await get_forgotten_message(
418422
session=session, item_hash=ItemHash(pending_message.item_hash)
419423
)
420424
if forgotten_message:

0 commit comments

Comments
 (0)