diff --git a/python/lib/db/decorators/int_bool.py b/python/lib/db/decorators/int_bool.py new file mode 100644 index 000000000..fba0e3580 --- /dev/null +++ b/python/lib/db/decorators/int_bool.py @@ -0,0 +1,33 @@ +from typing import Literal + +from sqlalchemy import Integer +from sqlalchemy.engine import Dialect +from sqlalchemy.types import TypeDecorator + + +class IntBool(TypeDecorator[bool]): + """ + Decorator for a database boolean integer type. + In SQL, the type will appear as 'int'. + In Python, the type will appear as a boolean. + """ + + impl = Integer + + def process_bind_param(self, value: bool | None, dialect: Dialect) -> Literal[0, 1] | None: + match value: + case True: + return 1 + case False: + return 0 + case None: + return None + + def process_result_value(self, value: Literal[0, 1] | None, dialect: Dialect) -> bool | None: + match value: + case 1: + return True + case 0: + return False + case None: + return None diff --git a/python/lib/db/models/cohort.py b/python/lib/db/models/cohort.py index 3968506cf..0df7cde44 100644 --- a/python/lib/db/models/cohort.py +++ b/python/lib/db/models/cohort.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Mapped, mapped_column from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool class DbCohort(Base): @@ -8,6 +9,6 @@ class DbCohort(Base): id : Mapped[int] = mapped_column('CohortID', primary_key=True) name : Mapped[str] = mapped_column('title') - use_edc : Mapped[bool | None] = mapped_column('useEDC') + use_edc : Mapped[bool | None] = mapped_column('useEDC', IntBool) window_difference : Mapped[str | None] = mapped_column('WindowDifference') recruitment_target : Mapped[int | None] = mapped_column('RecruitmentTarget') diff --git a/python/lib/db/models/config_setting.py b/python/lib/db/models/config_setting.py index b2b4cfa28..16c388b3d 100644 --- a/python/lib/db/models/config_setting.py +++ b/python/lib/db/models/config_setting.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Mapped, mapped_column from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool class DbConfigSetting(Base): @@ -9,8 +10,8 @@ class DbConfigSetting(Base): id : Mapped[int] = mapped_column('ID', primary_key=True) name : Mapped[str] = mapped_column('Name') description : Mapped[str | None] = mapped_column('Description') - visible : Mapped[bool | None] = mapped_column('Visible') - allow_multiple : Mapped[bool | None] = mapped_column('AllowMultiple') + visible : Mapped[bool | None] = mapped_column('Visible', IntBool) + allow_multiple : Mapped[bool | None] = mapped_column('AllowMultiple', IntBool) data_type : Mapped[str | None] = mapped_column('DataType') parent_id : Mapped[int | None] = mapped_column('Parent') label : Mapped[str | None] = mapped_column('Label') diff --git a/python/lib/db/models/dicom_archive.py b/python/lib/db/models/dicom_archive.py index f41818008..e5952b321 100644 --- a/python/lib/db/models/dicom_archive.py +++ b/python/lib/db/models/dicom_archive.py @@ -11,6 +11,7 @@ import lib.db.models.mri_violation_log as db_mri_violation_log import lib.db.models.session as db_session from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool class DbDicomArchive(Base): @@ -47,7 +48,7 @@ class DbDicomArchive(Base): create_info : Mapped[str | None] = mapped_column('CreateInfo') acquisition_metadata : Mapped[str] = mapped_column('AcquisitionMetadata') date_sent : Mapped[datetime | None] = mapped_column('DateSent') - pending_transfer : Mapped[bool] = mapped_column('PendingTransfer') + pending_transfer : Mapped[bool] = mapped_column('PendingTransfer', IntBool) series : Mapped[list['db_dicom_archive_series.DbDicomArchiveSeries']] \ = relationship('DbDicomArchiveSeries', back_populates='archive') diff --git a/python/lib/db/models/file.py b/python/lib/db/models/file.py index 287bc96b2..15f2c0b37 100644 --- a/python/lib/db/models/file.py +++ b/python/lib/db/models/file.py @@ -6,6 +6,7 @@ import lib.db.models.file_parameter as db_file_parameter import lib.db.models.session as db_session from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool from lib.db.decorators.int_datetime import IntDatetime @@ -29,7 +30,7 @@ class DbFile(Base): pipeline_date : Mapped[date | None] = mapped_column('PipelineDate') source_file_id : Mapped[int | None] = mapped_column('SourceFileID') process_protocol_id : Mapped[int | None] = mapped_column('ProcessProtocolID') - caveat : Mapped[bool | None] = mapped_column('Caveat') + caveat : Mapped[bool | None] = mapped_column('Caveat', IntBool) dicom_archive_id : Mapped[int | None] = mapped_column('TarchiveSource') hrrt_archive_id : Mapped[int | None] = mapped_column('HrrtArchiveID') scanner_id : Mapped[int | None] = mapped_column('ScannerID') diff --git a/python/lib/db/models/mri_upload.py b/python/lib/db/models/mri_upload.py index b9a8fcdae..bcdf79aec 100644 --- a/python/lib/db/models/mri_upload.py +++ b/python/lib/db/models/mri_upload.py @@ -7,6 +7,7 @@ import lib.db.models.dicom_archive as db_dicom_archive import lib.db.models.session as db_session from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool from lib.db.decorators.y_n_bool import YNBool @@ -18,16 +19,16 @@ class DbMriUpload(Base): upload_date : Mapped[datetime | None] = mapped_column('UploadDate') upload_location : Mapped[str] = mapped_column('UploadLocation') decompressed_location : Mapped[str] = mapped_column('DecompressedLocation') - insertion_complete : Mapped[bool] = mapped_column('InsertionComplete') - inserting : Mapped[bool | None] = mapped_column('Inserting') + insertion_complete : Mapped[bool] = mapped_column('InsertionComplete', IntBool) + inserting : Mapped[bool | None] = mapped_column('Inserting', IntBool) patient_name : Mapped[str] = mapped_column('PatientName') number_of_minc_inserted : Mapped[int | None] = mapped_column('number_of_mincInserted') number_of_minc_created : Mapped[int | None] = mapped_column('number_of_mincCreated') dicom_archive_id : Mapped[int | None] \ = mapped_column('TarchiveID', ForeignKey('tarchive.TarchiveID')) session_id : Mapped[int | None] = mapped_column('SessionID', ForeignKey('session.ID')) - is_candidate_info_validated : Mapped[bool | None] = mapped_column('IsCandidateInfoValidated') - is_dicom_archive_validated : Mapped[bool] = mapped_column('IsTarchiveValidated') + is_candidate_info_validated : Mapped[bool | None] = mapped_column('IsCandidateInfoValidated', IntBool) + is_dicom_archive_validated : Mapped[bool] = mapped_column('IsTarchiveValidated', IntBool) is_phantom : Mapped[bool] = mapped_column('IsPhantom', YNBool) dicom_archive : Mapped[Optional['db_dicom_archive.DbDicomArchive']] \ diff --git a/python/lib/db/models/notification_type.py b/python/lib/db/models/notification_type.py index e5454fdac..3845fcbbe 100644 --- a/python/lib/db/models/notification_type.py +++ b/python/lib/db/models/notification_type.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Mapped, mapped_column from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool class DbNotificationType(Base): @@ -8,5 +9,5 @@ class DbNotificationType(Base): id : Mapped[int] = mapped_column('NotificationTypeID', primary_key=True) name : Mapped[str] = mapped_column('Type') - private : Mapped[bool | None] = mapped_column('private') + private : Mapped[bool | None] = mapped_column('private', IntBool) description: Mapped[str | None] = mapped_column('Description') diff --git a/python/lib/db/models/parameter_type.py b/python/lib/db/models/parameter_type.py index 1d29d3e37..83efc1eb2 100644 --- a/python/lib/db/models/parameter_type.py +++ b/python/lib/db/models/parameter_type.py @@ -2,6 +2,7 @@ import lib.db.models.file_parameter as db_file_parameter from lib.db.base import Base +from lib.db.decorators.int_bool import IntBool class DbParameterType(Base): @@ -17,8 +18,8 @@ class DbParameterType(Base): source_field : Mapped[str | None] = mapped_column('SourceField') source_from : Mapped[str | None] = mapped_column('SourceFrom') source_condition : Mapped[str | None] = mapped_column('SourceCondition') - queryable : Mapped[bool | None] = mapped_column('Queryable') - is_file : Mapped[bool | None] = mapped_column('IsFile') + queryable : Mapped[bool | None] = mapped_column('Queryable', IntBool) + is_file : Mapped[bool | None] = mapped_column('IsFile', IntBool) file_parameters: Mapped[list['db_file_parameter.DbFileParameter']] \ = relationship('DbFileParameter', back_populates='type') diff --git a/python/tests/integration/test_orm_sql_sync.py b/python/tests/integration/test_orm_sql_sync.py index 41803cdfc..8016b1665 100644 --- a/python/tests/integration/test_orm_sql_sync.py +++ b/python/tests/integration/test_orm_sql_sync.py @@ -3,7 +3,7 @@ from typing import Any from sqlalchemy import MetaData -from sqlalchemy.dialects.mysql.types import DOUBLE, TINYINT +from sqlalchemy.dialects.mysql.types import DOUBLE from sqlalchemy.types import TypeDecorator, TypeEngine from lib.db.base import Base @@ -59,9 +59,6 @@ def get_orm_python_type(orm_type: TypeEngine[Any]): def get_sql_python_type(sql_type: TypeEngine[Any]): - if isinstance(sql_type, TINYINT) and sql_type.display_width == 1: # type: ignore - return bool - if isinstance(sql_type, DOUBLE): return float