Skip to content

Commit 9a0f740

Browse files
authored
Implement Mechanism to Selectively Override Automatic Field Creations (#214)
This implements the overriding mechanism discussed in #209 . Changes include: - Add ORMField class - The main overridable parameters are: type, description, deprecation_reason and required - We can name fields differently using prop_name. This was preferred over name to avoid confusion / collision with graphene.Field parameters. - Add tests for all types of SQLAlchemy properties: columns, relationships, column properties, hybrid properties and and composite properties. - Cleanups and re-organize some tests.
1 parent 8ad1f75 commit 9a0f740

12 files changed

+691
-476
lines changed

graphene_sqlalchemy/converter.py

+59-59
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
ChoiceType = JSONType = ScalarListType = TSVectorType = object
1717

1818

19+
def _get_attr_resolver(attr_name):
20+
return lambda root, _info: getattr(root, attr_name, None)
21+
22+
1923
def get_column_doc(column):
2024
return getattr(column, "doc", None)
2125

@@ -24,43 +28,61 @@ def is_column_nullable(column):
2428
return bool(getattr(column, "nullable", True))
2529

2630

27-
def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory):
28-
direction = relationship.direction
29-
model = relationship.mapper.entity
31+
def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, **field_kwargs):
32+
direction = relationship_prop.direction
33+
model = relationship_prop.mapper.entity
3034

3135
def dynamic_type():
3236
_type = registry.get_type_for_model(model)
37+
3338
if not _type:
3439
return None
35-
if direction == interfaces.MANYTOONE or not relationship.uselist:
36-
return Field(_type)
40+
if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
41+
return Field(
42+
_type,
43+
resolver=_get_attr_resolver(relationship_prop.key),
44+
**field_kwargs
45+
)
3746
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
3847
if _type._meta.connection:
39-
return connection_field_factory(relationship, registry)
40-
return Field(List(_type))
48+
# TODO Add a way to override connection_field_factory
49+
return connection_field_factory(relationship_prop, registry, **field_kwargs)
50+
return Field(
51+
List(_type),
52+
**field_kwargs
53+
)
4154

4255
return Dynamic(dynamic_type)
4356

4457

45-
def convert_sqlalchemy_hybrid_method(hybrid_item):
46-
return String(description=getattr(hybrid_item, "__doc__", None), required=False)
58+
def convert_sqlalchemy_hybrid_method(hybrid_prop, prop_name, **field_kwargs):
59+
if 'type' not in field_kwargs:
60+
# TODO The default type should be dependent on the type of the property propety.
61+
field_kwargs['type'] = String
62+
63+
return Field(
64+
resolver=_get_attr_resolver(prop_name),
65+
**field_kwargs
66+
)
4767

4868

49-
def convert_sqlalchemy_composite(composite, registry):
50-
converter = registry.get_converter_for_composite(composite.composite_class)
69+
def convert_sqlalchemy_composite(composite_prop, registry):
70+
converter = registry.get_converter_for_composite(composite_prop.composite_class)
5171
if not converter:
5272
try:
5373
raise Exception(
5474
"Don't know how to convert the composite field %s (%s)"
55-
% (composite, composite.composite_class)
75+
% (composite_prop, composite_prop.composite_class)
5676
)
5777
except AttributeError:
5878
# handle fields that are not attached to a class yet (don't have a parent)
5979
raise Exception(
6080
"Don't know how to convert the composite field %r (%s)"
61-
% (composite, composite.composite_class)
81+
% (composite_prop, composite_prop.composite_class)
6282
)
63-
return converter(composite, registry)
83+
84+
# TODO Add a way to override composite fields default parameters
85+
return converter(composite_prop, registry)
6486

6587

6688
def _register_composite_class(cls, registry=None):
@@ -78,8 +100,16 @@ def inner(fn):
78100
convert_sqlalchemy_composite.register = _register_composite_class
79101

80102

81-
def convert_sqlalchemy_column(column, registry=None):
82-
return convert_sqlalchemy_type(getattr(column, "type", None), column, registry)
103+
def convert_sqlalchemy_column(column_prop, registry, **field_kwargs):
104+
column = column_prop.columns[0]
105+
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
106+
field_kwargs.setdefault('required', not is_column_nullable(column))
107+
field_kwargs.setdefault('description', get_column_doc(column))
108+
109+
return Field(
110+
resolver=_get_attr_resolver(column_prop.key),
111+
**field_kwargs
112+
)
83113

84114

85115
@singledispatch
@@ -101,93 +131,63 @@ def convert_sqlalchemy_type(type, column, registry=None):
101131
@convert_sqlalchemy_type.register(postgresql.CIDR)
102132
@convert_sqlalchemy_type.register(TSVectorType)
103133
def convert_column_to_string(type, column, registry=None):
104-
return String(
105-
description=get_column_doc(column), required=not (is_column_nullable(column))
106-
)
134+
return String
107135

108136

109137
@convert_sqlalchemy_type.register(types.DateTime)
110138
def convert_column_to_datetime(type, column, registry=None):
111139
from graphene.types.datetime import DateTime
112-
113-
return DateTime(
114-
description=get_column_doc(column), required=not (is_column_nullable(column))
115-
)
140+
return DateTime
116141

117142

118143
@convert_sqlalchemy_type.register(types.SmallInteger)
119144
@convert_sqlalchemy_type.register(types.Integer)
120145
def convert_column_to_int_or_id(type, column, registry=None):
121-
if column.primary_key:
122-
return ID(
123-
description=get_column_doc(column),
124-
required=not (is_column_nullable(column)),
125-
)
126-
else:
127-
return Int(
128-
description=get_column_doc(column),
129-
required=not (is_column_nullable(column)),
130-
)
146+
return ID if column.primary_key else Int
131147

132148

133149
@convert_sqlalchemy_type.register(types.Boolean)
134150
def convert_column_to_boolean(type, column, registry=None):
135-
return Boolean(
136-
description=get_column_doc(column), required=not (is_column_nullable(column))
137-
)
151+
return Boolean
138152

139153

140154
@convert_sqlalchemy_type.register(types.Float)
141155
@convert_sqlalchemy_type.register(types.Numeric)
142156
@convert_sqlalchemy_type.register(types.BigInteger)
143157
def convert_column_to_float(type, column, registry=None):
144-
return Float(
145-
description=get_column_doc(column), required=not (is_column_nullable(column))
146-
)
158+
return Float
147159

148160

149161
@convert_sqlalchemy_type.register(types.Enum)
150162
def convert_enum_to_enum(type, column, registry=None):
151-
return Field(
152-
lambda: enum_for_sa_enum(type, registry or get_global_registry()),
153-
description=get_column_doc(column),
154-
required=not (is_column_nullable(column)),
155-
)
163+
return lambda: enum_for_sa_enum(type, registry or get_global_registry())
156164

157165

166+
# TODO Make ChoiceType conversion consistent with other enums
158167
@convert_sqlalchemy_type.register(ChoiceType)
159168
def convert_choice_to_enum(type, column, registry=None):
160169
name = "{}_{}".format(column.table.name, column.name).upper()
161-
return Enum(name, type.choices, description=get_column_doc(column))
170+
return Enum(name, type.choices)
162171

163172

164173
@convert_sqlalchemy_type.register(ScalarListType)
165174
def convert_scalar_list_to_list(type, column, registry=None):
166-
return List(String, description=get_column_doc(column))
175+
return List(String)
167176

168177

169178
@convert_sqlalchemy_type.register(postgresql.ARRAY)
170179
def convert_postgres_array_to_list(_type, column, registry=None):
171-
graphene_type = convert_sqlalchemy_type(column.type.item_type, column)
172-
inner_type = type(graphene_type)
173-
return List(
174-
inner_type,
175-
description=get_column_doc(column),
176-
required=not (is_column_nullable(column)),
177-
)
180+
inner_type = convert_sqlalchemy_type(column.type.item_type, column)
181+
return List(inner_type)
178182

179183

180184
@convert_sqlalchemy_type.register(postgresql.HSTORE)
181185
@convert_sqlalchemy_type.register(postgresql.JSON)
182186
@convert_sqlalchemy_type.register(postgresql.JSONB)
183187
def convert_json_to_string(type, column, registry=None):
184-
return JSONString(
185-
description=get_column_doc(column), required=not (is_column_nullable(column))
186-
)
188+
return JSONString
187189

188190

189191
@convert_sqlalchemy_type.register(JSONType)
190192
def convert_json_type_to_string(type, column, registry=None):
191-
return JSONString(
192-
description=get_column_doc(column), required=not (is_column_nullable(column))
193-
)
193+
return JSONString

graphene_sqlalchemy/enums.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import Column
1+
from sqlalchemy.orm import ColumnProperty
22
from sqlalchemy.types import Enum as SQLAlchemyEnumType
33

44
from graphene import Argument, Enum, List
@@ -69,11 +69,12 @@ def enum_for_field(obj_type, field_name):
6969
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
7070
if orm_field is None:
7171
raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name))
72-
if not isinstance(orm_field, Column):
72+
if not isinstance(orm_field, ColumnProperty):
7373
raise TypeError(
7474
"{}.{} does not map to model column".format(obj_type._meta.name, field_name)
7575
)
76-
sa_enum = orm_field.type
76+
column = orm_field.columns[0]
77+
sa_enum = column.type
7778
if not isinstance(sa_enum, SQLAlchemyEnumType):
7879
raise TypeError(
7980
"{}.{} does not map to enum column".format(obj_type._meta.name, field_name)
@@ -138,15 +139,16 @@ def sort_enum_for_object_type(
138139
if only_fields and field_name not in only_fields:
139140
continue
140141
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
141-
if not isinstance(orm_field, Column):
142+
if not isinstance(orm_field, ColumnProperty):
142143
continue
143-
if only_indexed and not (orm_field.primary_key or orm_field.index):
144+
column = orm_field.columns[0]
145+
if only_indexed and not (column.primary_key or column.index):
144146
continue
145-
asc_name = get_name(orm_field.name, True)
146-
asc_value = EnumValue(asc_name, orm_field.asc())
147-
desc_name = get_name(orm_field.name, False)
148-
desc_value = EnumValue(desc_name, orm_field.desc())
149-
if orm_field.primary_key:
147+
asc_name = get_name(column.name, True)
148+
asc_value = EnumValue(asc_name, column.asc())
149+
desc_name = get_name(column.name, False)
150+
desc_value = EnumValue(desc_name, column.desc())
151+
if column.primary_key:
150152
default.append(asc_value)
151153
members.extend(((asc_name, asc_value), (desc_name, desc_value)))
152154
enum = Enum(name, members)

graphene_sqlalchemy/fields.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,22 @@ def __init__(self, type, *args, **kwargs):
9797
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
9898

9999

100-
def default_connection_field_factory(relationship, registry):
100+
def default_connection_field_factory(relationship, registry, **field_kwargs):
101101
model = relationship.mapper.entity
102102
model_type = registry.get_type_for_model(model)
103-
return createConnectionField(model_type)
103+
return createConnectionField(model_type, **field_kwargs)
104104

105105

106106
# TODO Remove in next major version
107107
__connectionFactory = UnsortedSQLAlchemyConnectionField
108108

109109

110-
def createConnectionField(_type):
110+
def createConnectionField(_type, **field_kwargs):
111111
log.warning(
112112
'createConnectionField is deprecated and will be removed in the next '
113113
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
114114
)
115-
return __connectionFactory(_type)
115+
return __connectionFactory(_type, **field_kwargs)
116116

117117

118118
def registerConnectionFieldFactory(factoryMethod):

graphene_sqlalchemy/tests/conftest.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
from sqlalchemy import create_engine
33
from sqlalchemy.orm import scoped_session, sessionmaker
44

5+
from ..converter import convert_sqlalchemy_composite
56
from ..registry import reset_global_registry
6-
from .models import Base
7+
from .models import Base, CompositeFullName
78

89
test_db_url = 'sqlite://' # use in-memory database for tests
910

@@ -12,6 +13,12 @@
1213
def reset_registry():
1314
reset_global_registry()
1415

16+
# Prevent tests that implicitly depend on Reporter from raising
17+
# Tests that explicitly depend on this behavior should re-register a converter
18+
@convert_sqlalchemy_composite.register(CompositeFullName)
19+
def convert_composite_class(composite, registry):
20+
pass
21+
1522

1623
@pytest.yield_fixture(scope="function")
1724
def session():

graphene_sqlalchemy/tests/models.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import enum
44

5-
from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table
5+
from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table,
6+
func, select)
67
from sqlalchemy.ext.declarative import declarative_base
7-
from sqlalchemy.orm import mapper, relationship
8+
from sqlalchemy.ext.hybrid import hybrid_property
9+
from sqlalchemy.orm import column_property, composite, mapper, relationship
810

911
PetKind = Enum("cat", "dog", name="pet_kind")
1012

@@ -39,22 +41,39 @@ class Pet(Base):
3941
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
4042

4143

44+
class CompositeFullName(object):
45+
def __init__(self, first_name, last_name):
46+
self.first_name = first_name
47+
self.last_name = last_name
48+
49+
def __composite_values__(self):
50+
return self.first_name, self.last_name
51+
52+
def __repr__(self):
53+
return "{} {}".format(self.first_name, self.last_name)
54+
55+
4256
class Reporter(Base):
4357
__tablename__ = "reporters"
58+
4459
id = Column(Integer(), primary_key=True)
45-
first_name = Column(String(30))
46-
last_name = Column(String(30))
47-
email = Column(String())
60+
first_name = Column(String(30), doc="First name")
61+
last_name = Column(String(30), doc="Last name")
62+
email = Column(String(), doc="Email")
4863
favorite_pet_kind = Column(PetKind)
4964
pets = relationship("Pet", secondary=association_table, backref="reporters")
5065
articles = relationship("Article", backref="reporter")
5166
favorite_article = relationship("Article", uselist=False)
5267

53-
# total = column_property(
54-
# select([
55-
# func.cast(func.count(PersonInfo.id), Float)
56-
# ])
57-
# )
68+
@hybrid_property
69+
def hybrid_prop(self):
70+
return self.first_name
71+
72+
column_prop = column_property(
73+
select([func.cast(func.count(id), Integer)]), doc="Column property"
74+
)
75+
76+
composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite")
5877

5978

6079
class Article(Base):

0 commit comments

Comments
 (0)