diff --git a/sqlalchemy_utils/types/pg_composite.py b/sqlalchemy_utils/types/pg_composite.py index a80b95e9..78c66036 100644 --- a/sqlalchemy_utils/types/pg_composite.py +++ b/sqlalchemy_utils/types/pg_composite.py @@ -174,7 +174,10 @@ class CompositeType(UserDefinedType, SchemaType): class comparator_factory(UserDefinedType.Comparator): def __getattr__(self, key): try: - type_ = self.type.typemap[key] + if key in self.type.column_map: + type_ = self.type.column_map[key].type + else: + type_ = self.type.typemap[key] except KeyError: raise KeyError( "Type '%s' doesn't have an attribute: '%s'" % ( @@ -192,6 +195,7 @@ def __init__(self, name, columns): SchemaType.__init__(self) self.name = name self.columns = columns + self.column_map = {col.name: col for col in columns} if name in registered_composites: self.type_cls = registered_composites[name].type_cls else: diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 23933019..eb352a14 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -99,6 +99,18 @@ def test_incomplete_dict(self, session, Account): assert account.balance.currency is None assert account.balance.amount == 15 + def test_composite_attribute_query(self, session, Account): + account_query = session \ + .query(Account) \ + .filter(Account.balance.amount == 15) + + raw_query = str( + account_query.statement + .compile(compile_kwargs={"literal_binds": True}) + ) + assert '(account.balance).amount = 15' in raw_query + + @pytest.mark.skipif('i18n.babel is None') @pytest.mark.usefixtures('postgresql_dsn')