88from aleph_message .models import ItemHash , ItemType , MessageType
99from configmanager import Config
1010from pydantic import ValidationError
11+ from sqlalchemy import select
1112from sqlalchemy .dialects .postgresql import insert
1213
1314from aleph .chains .signature_verifier import SignatureVerifier
3940from aleph .schemas .pending_messages import parse_message
4041from aleph .storage import StorageService
4142from aleph .toolkit .timestamp import timestamp_to_datetime
42- from aleph .types .db_session import DbSession , DbSessionFactory
43+ from aleph .types .db_session import AsyncDbSession , AsyncDbSessionFactory
4344from aleph .types .files import FileType
4445from aleph .types .message_processing_result import ProcessedMessage , RejectedMessage
4546from aleph .types .message_status import (
@@ -122,7 +123,7 @@ async def fetch_pending_message(
122123
123124 return validated_message
124125
125- async def fetch_related_content (self , session : DbSession , message : MessageDb ):
126+ async def fetch_related_content (self , session : AsyncDbSession , message : MessageDb ):
126127 content_handler = self .get_content_handler (message .type )
127128
128129 try :
@@ -135,7 +136,7 @@ async def fetch_related_content(self, session: DbSession, message: MessageDb):
135136 ) from e
136137
137138 async def load_fetched_content (
138- self , session : DbSession , pending_message : PendingMessageDb
139+ self , session : AsyncDbSession , pending_message : PendingMessageDb
139140 ) -> PendingMessageDb :
140141 if pending_message .item_type != ItemType .inline :
141142 pending_message .fetched = False
@@ -161,7 +162,7 @@ class MessagePublisher(BaseMessageHandler):
161162
162163 def __init__ (
163164 self ,
164- session_factory : DbSessionFactory ,
165+ session_factory : AsyncDbSessionFactory ,
165166 storage_service : StorageService ,
166167 config : Config ,
167168 pending_message_exchange : aio_pika .abc .AbstractExchange ,
@@ -191,21 +192,22 @@ async def add_pending_message(
191192 origin : Optional [MessageOrigin ] = MessageOrigin .P2P ,
192193 ) -> Optional [PendingMessageDb ]:
193194 # TODO: this implementation is just messy, improve it.
194- with self .session_factory () as session :
195+ async with self .session_factory () as session :
195196 try :
196197 # we don't check signatures yet.
197198 message = parse_message (message_dict )
198199 except InvalidMessageException as e :
199200 LOGGER .warning (e )
200- reject_new_pending_message (
201+ await reject_new_pending_message (
201202 session = session ,
202203 pending_message = message_dict ,
203204 exception = e ,
204205 tx_hash = tx_hash ,
205206 )
206- session .commit ()
207+ await session .commit ()
207208 return None
208209
210+ # TODO: fix this pydanticV2 issue
209211 pending_message = PendingMessageDb .from_obj (
210212 message ,
211213 reception_time = reception_time ,
@@ -220,25 +222,24 @@ async def add_pending_message(
220222 )
221223 except InvalidMessageException as e :
222224 LOGGER .warning ("Invalid message: %s - %s" , message .item_hash , str (e ))
223- reject_new_pending_message (
225+ await reject_new_pending_message (
224226 session = session ,
225227 pending_message = message_dict ,
226228 exception = e ,
227229 tx_hash = tx_hash ,
228230 )
229- session .commit ()
231+ await session .commit ()
230232 return None
231233
232234 # Check if there are an already existing record
233- existing_message = (
234- session .query (PendingMessageDb )
235- .filter_by (
236- sender = pending_message .sender ,
237- item_hash = pending_message .item_hash ,
238- signature = pending_message .signature ,
239- )
240- .one_or_none ()
235+ stmt = select (PendingMessageDb ).filter_by (
236+ sender = pending_message .sender ,
237+ item_hash = pending_message .item_hash ,
238+ signature = pending_message .signature ,
241239 )
240+
241+ existing_message = (await session .execute (stmt )).scalars ().one_or_none ()
242+
242243 if existing_message :
243244 return existing_message
244245
@@ -255,9 +256,9 @@ async def add_pending_message(
255256 )
256257
257258 try :
258- session .execute (upsert_message_status_stmt )
259- session .execute (insert_pending_message_stmt )
260- session .commit ()
259+ await session .execute (upsert_message_status_stmt )
260+ await session .execute (insert_pending_message_stmt )
261+ await session .commit ()
261262 except sqlalchemy .exc .IntegrityError :
262263 # Handle the unique constraint violation.
263264 LOGGER .warning ("Duplicate pending message detected trying to save it." )
@@ -269,14 +270,14 @@ async def add_pending_message(
269270 pending_message .item_hash ,
270271 str (e ),
271272 )
272- session .rollback ()
273- reject_new_pending_message (
273+ await session .rollback ()
274+ await reject_new_pending_message (
274275 session = session ,
275276 pending_message = message_dict ,
276277 exception = e ,
277278 tx_hash = tx_hash ,
278279 )
279- session .commit ()
280+ await session .commit ()
280281 return None
281282
282283 await self ._publish_pending_message (pending_message )
@@ -307,53 +308,56 @@ async def verify_signature(self, pending_message: PendingMessageDb):
307308
308309 @staticmethod
309310 async def confirm_existing_message (
310- session : DbSession ,
311+ session : AsyncDbSession ,
311312 existing_message : MessageDb ,
312313 pending_message : PendingMessageDb ,
313314 ):
314315 if pending_message .signature != existing_message .signature :
315316 raise InvalidSignature (f"Invalid signature for { pending_message .item_hash } " )
316317
317- delete_pending_message (session = session , pending_message = pending_message )
318+ await delete_pending_message (session = session , pending_message = pending_message )
318319 if tx_hash := pending_message .tx_hash :
319- session .execute (
320+ await session .execute (
320321 make_confirmation_upsert_query (
321322 item_hash = pending_message .item_hash , tx_hash = tx_hash
322323 )
323324 )
324325
325326 @staticmethod
326327 async def confirm_existing_forgotten_message (
327- session : DbSession ,
328+ session : AsyncDbSession ,
328329 forgotten_message : ForgottenMessageDb ,
329330 pending_message : PendingMessageDb ,
330331 ):
331332 if pending_message .signature != forgotten_message .signature :
332333 raise InvalidSignature (f"Invalid signature for { pending_message .item_hash } " )
333334
334- delete_pending_message (session = session , pending_message = pending_message )
335+ await delete_pending_message (session = session , pending_message = pending_message )
335336
336337 async def insert_message (
337- self , session : DbSession , pending_message : PendingMessageDb , message : MessageDb
338+ self ,
339+ session : AsyncDbSession ,
340+ pending_message : PendingMessageDb ,
341+ message : MessageDb ,
338342 ):
339- session .execute (make_message_upsert_query (message ))
343+ await session .execute (make_message_upsert_query (message ))
340344 if message .item_type != ItemType .inline :
341- upsert_file (
345+ await upsert_file (
342346 session = session ,
343347 file_hash = message .item_hash ,
344348 size = message .size ,
345349 file_type = FileType .FILE ,
346350 )
347- insert_content_file_pin (
351+ await insert_content_file_pin (
348352 session = session ,
349353 file_hash = message .item_hash ,
350354 owner = message .sender ,
351355 item_hash = message .item_hash ,
352356 created = timestamp_to_datetime (message .content ["time" ]),
353357 )
354358
355- delete_pending_message (session = session , pending_message = pending_message )
356- session .execute (
359+ await delete_pending_message (session = session , pending_message = pending_message )
360+ await session .execute (
357361 make_message_status_upsert_query (
358362 item_hash = message .item_hash ,
359363 new_status = MessageStatus .PROCESSED ,
@@ -363,18 +367,18 @@ async def insert_message(
363367 )
364368
365369 if tx_hash := pending_message .tx_hash :
366- session .execute (
370+ await session .execute (
367371 make_confirmation_upsert_query (
368372 item_hash = message .item_hash , tx_hash = tx_hash
369373 )
370374 )
371375
372376 async def insert_costs (
373- self , session : DbSession , costs : List [AccountCostsDb ], message : MessageDb
377+ self , session : AsyncDbSession , costs : List [AccountCostsDb ], message : MessageDb
374378 ):
375379 if len (costs ) > 0 :
376380 insert_stmt = make_costs_upsert_query (costs )
377- session .execute (insert_stmt )
381+ await session .execute (insert_stmt )
378382
379383 async def verify_message (self , pending_message : PendingMessageDb ) -> MessageDb :
380384 await self .verify_signature (pending_message = pending_message )
@@ -385,7 +389,7 @@ async def verify_message(self, pending_message: PendingMessageDb) -> MessageDb:
385389 return validated_message
386390
387391 async def process (
388- self , session : DbSession , pending_message : PendingMessageDb
392+ self , session : AsyncDbSession , pending_message : PendingMessageDb
389393 ) -> ProcessedMessage | RejectedMessage :
390394 """
391395 Process a pending message.
@@ -401,7 +405,7 @@ async def process(
401405 """
402406
403407 # Note: Check if message already exists (and confirm it)
404- existing_message = get_message_by_item_hash (
408+ existing_message = await get_message_by_item_hash (
405409 session = session , item_hash = ItemHash (pending_message .item_hash )
406410 )
407411 if existing_message :
@@ -414,7 +418,7 @@ async def process(
414418
415419 # Note: Check if message is already forgotten (and confirm it)
416420 # this is to avoid race conditions when a confirmation arrives after the FORGET message has been preocessed
417- forgotten_message = get_forgotten_message (
421+ forgotten_message = await get_forgotten_message (
418422 session = session , item_hash = ItemHash (pending_message .item_hash )
419423 )
420424 if forgotten_message :
0 commit comments