Skip to content

Commit 42a4a7d

Browse files
shapovalov, facebook-github-bot
authored and committed
Generalising SqlIndexDataset to support subtypes of SqlSequenceAnnotation
Summary: We did not often extend sequence-level metadata, but now, for applications like text-to-3D/video, we need to store captions and similar data.
Reviewed By: bottler
Differential Revision: D68269926
fbshipit-source-id: f8af308adce51863d719a335d85cd2558943bd4c
1 parent 699bc67 commit 42a4a7d

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

pytorch3d/implicitron/dataset/sql_dataset.py

+16 -10
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
108108
"""
109109

110110
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
111+
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
112+
SqlSequenceAnnotation
113+
)
111114

112115
sqlite_metadata_file: str = ""
113116
dataset_root: Optional[str] = None
@@ -246,8 +249,8 @@ def _get_item(
246249
self.frame_annotations_type.frame_number
247250
== int(frame), # cast from np.int64
248251
)
249-
seq_stmt = sa.select(SqlSequenceAnnotation).where(
250-
SqlSequenceAnnotation.sequence_name == seq
252+
seq_stmt = sa.select(self.sequence_annotations_type).where(
253+
self.sequence_annotations_type.sequence_name == seq
251254
)
252255
with Session(self._sql_engine) as session:
253256
entry = session.scalars(stmt).one()
@@ -273,9 +276,10 @@ def sequence_names(self) -> Iterable[str]:
273276
# override
274277
def category_to_sequence_names(self) -> Dict[str, List[str]]:
275278
stmt = sa.select(
276-
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
279+
self.sequence_annotations_type.category,
280+
self.sequence_annotations_type.sequence_name,
277281
).where( # we limit results to sequences that have frames after all filters
278-
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
282+
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
279283
)
280284
with self._sql_engine.connect() as connection:
281285
cat_to_seqs = pd.read_sql(stmt, connection)
@@ -414,14 +418,14 @@ def add_where(stmt):
414418
return stmt.where(*where_conditions) if where_conditions else stmt
415419

416420
if self.limit_sequences_per_category_to <= 0:
417-
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
421+
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
418422
else:
419423
subquery = sa.select(
420-
SqlSequenceAnnotation.sequence_name,
424+
self.sequence_annotations_type.sequence_name,
421425
sa.func.row_number()
422426
.over(
423427
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
424-
partition_by=SqlSequenceAnnotation.category,
428+
partition_by=self.sequence_annotations_type.category,
425429
)
426430
.label("row_number"),
427431
)
@@ -457,21 +461,23 @@ def _get_category_filters(self) -> List[sa.ColumnElement]:
457461
return []
458462

459463
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
460-
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
464+
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
461465

462466
def _get_pick_filters(self) -> List[sa.ColumnElement]:
463467
if not self.pick_sequences:
464468
return []
465469

466470
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
467-
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
471+
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
468472

469473
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
470474
if not self.exclude_sequences:
471475
return []
472476

473477
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
474-
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
478+
return [
479+
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
480+
]
475481

476482
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
477483
subsets = self.subsets

0 commit comments

Comments
 (0)