@@ -108,6 +108,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
108
108
"""
109
109
110
110
frame_annotations_type : ClassVar [Type [SqlFrameAnnotation ]] = SqlFrameAnnotation
111
+ sequence_annotations_type : ClassVar [Type [SqlSequenceAnnotation ]] = (
112
+ SqlSequenceAnnotation
113
+ )
111
114
112
115
sqlite_metadata_file : str = ""
113
116
dataset_root : Optional [str ] = None
@@ -246,8 +249,8 @@ def _get_item(
246
249
self .frame_annotations_type .frame_number
247
250
== int (frame ), # cast from np.int64
248
251
)
249
- seq_stmt = sa .select (SqlSequenceAnnotation ).where (
250
- SqlSequenceAnnotation .sequence_name == seq
252
+ seq_stmt = sa .select (self . sequence_annotations_type ).where (
253
+ self . sequence_annotations_type .sequence_name == seq
251
254
)
252
255
with Session (self ._sql_engine ) as session :
253
256
entry = session .scalars (stmt ).one ()
@@ -273,9 +276,10 @@ def sequence_names(self) -> Iterable[str]:
273
276
# override
274
277
def category_to_sequence_names (self ) -> Dict [str , List [str ]]:
275
278
stmt = sa .select (
276
- SqlSequenceAnnotation .category , SqlSequenceAnnotation .sequence_name
279
+ self .sequence_annotations_type .category ,
280
+ self .sequence_annotations_type .sequence_name ,
277
281
).where ( # we limit results to sequences that have frames after all filters
278
- SqlSequenceAnnotation .sequence_name .in_ (self .sequence_names ())
282
+ self . sequence_annotations_type .sequence_name .in_ (self .sequence_names ())
279
283
)
280
284
with self ._sql_engine .connect () as connection :
281
285
cat_to_seqs = pd .read_sql (stmt , connection )
@@ -414,14 +418,14 @@ def add_where(stmt):
414
418
return stmt .where (* where_conditions ) if where_conditions else stmt
415
419
416
420
if self .limit_sequences_per_category_to <= 0 :
417
- stmt = add_where (sa .select (SqlSequenceAnnotation .sequence_name ))
421
+ stmt = add_where (sa .select (self . sequence_annotations_type .sequence_name ))
418
422
else :
419
423
subquery = sa .select (
420
- SqlSequenceAnnotation .sequence_name ,
424
+ self . sequence_annotations_type .sequence_name ,
421
425
sa .func .row_number ()
422
426
.over (
423
427
order_by = sa .text ("ROWID" ), # NOTE: ROWID is SQLite-specific
424
- partition_by = SqlSequenceAnnotation .category ,
428
+ partition_by = self . sequence_annotations_type .category ,
425
429
)
426
430
.label ("row_number" ),
427
431
)
@@ -457,21 +461,23 @@ def _get_category_filters(self) -> List[sa.ColumnElement]:
457
461
return []
458
462
459
463
logger .info (f"Limiting dataset to categories: { self .pick_categories } " )
460
- return [SqlSequenceAnnotation .category .in_ (self .pick_categories )]
464
+ return [self . sequence_annotations_type .category .in_ (self .pick_categories )]
461
465
462
466
def _get_pick_filters (self ) -> List [sa .ColumnElement ]:
463
467
if not self .pick_sequences :
464
468
return []
465
469
466
470
logger .info (f"Limiting dataset to sequences: { self .pick_sequences } " )
467
- return [SqlSequenceAnnotation .sequence_name .in_ (self .pick_sequences )]
471
+ return [self . sequence_annotations_type .sequence_name .in_ (self .pick_sequences )]
468
472
469
473
def _get_exclude_filters (self ) -> List [sa .ColumnOperators ]:
470
474
if not self .exclude_sequences :
471
475
return []
472
476
473
477
logger .info (f"Removing sequences from the dataset: { self .exclude_sequences } " )
474
- return [SqlSequenceAnnotation .sequence_name .notin_ (self .exclude_sequences )]
478
+ return [
479
+ self .sequence_annotations_type .sequence_name .notin_ (self .exclude_sequences )
480
+ ]
475
481
476
482
def _load_subsets_from_json (self , subset_lists_path : str ) -> pd .DataFrame :
477
483
subsets = self .subsets
0 commit comments