Skip to content

Commit 92ae690

Browse files
committed
Tolerate new columns being added to tables during upsert
Up until now, upserting to a table which had a column added not know to Django would make the query crash. This commit introduces a more robust mechanism to constructing model instances from query results that tolerates a column being added at the end of the table.
1 parent 5fcb63f commit 92ae690

File tree

8 files changed

+744
-107
lines changed

8 files changed

+744
-107
lines changed

psqlextra/compiler.py

+11-28
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44

55
from collections.abc import Iterable
6-
from typing import Tuple, Union
6+
from typing import TYPE_CHECKING, Tuple, Union, cast
77

88
import django
99

@@ -12,11 +12,13 @@
1212
from django.db.models import Expression, Model, Q
1313
from django.db.models.fields.related import RelatedField
1414
from django.db.models.sql import compiler as django_compiler
15-
from django.db.utils import ProgrammingError
1615

1716
from .expressions import HStoreValue
1817
from .types import ConflictAction
1918

19+
if TYPE_CHECKING:
20+
from .sql import PostgresInsertQuery
21+
2022

2123
def append_caller_to_sql(sql):
2224
"""Append the caller to SQL queries.
@@ -161,6 +163,8 @@ def as_sql(self, *args, **kwargs):
161163
class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined]
162164
"""Compiler for SQL INSERT statements."""
163165

166+
query: "PostgresInsertQuery"
167+
164168
def __init__(self, *args, **kwargs):
165169
"""Initializes a new instance of
166170
:see:PostgresInsertOnConflictCompiler."""
@@ -169,35 +173,14 @@ def __init__(self, *args, **kwargs):
169173

170174
def as_sql(self, return_id=False, *args, **kwargs):
171175
"""Builds the SQL INSERT statement."""
176+
172177
queries = [
173178
self._rewrite_insert(sql, params, return_id)
174179
for sql, params in super().as_sql(*args, **kwargs)
175180
]
176181

177182
return queries
178183

179-
def execute_sql(self, return_id=False):
180-
# execute all the generate queries
181-
with self.connection.cursor() as cursor:
182-
rows = []
183-
for sql, params in self.as_sql(return_id):
184-
cursor.execute(sql, params)
185-
try:
186-
rows.extend(cursor.fetchall())
187-
except ProgrammingError:
188-
pass
189-
description = cursor.description
190-
191-
# create a mapping between column names and column value
192-
return [
193-
{
194-
column.name: row[column_index]
195-
for column_index, column in enumerate(description)
196-
if row
197-
}
198-
for row in rows
199-
]
200-
201184
def _rewrite_insert(self, sql, params, return_id=False):
202185
"""Rewrites a formed SQL INSERT query to include the ON CONFLICT
203186
clause.
@@ -209,9 +192,9 @@ def _rewrite_insert(self, sql, params, return_id=False):
209192
params:
210193
The parameters passed to the query.
211194
212-
returning:
213-
What to put in the `RETURNING` clause
214-
of the resulting query.
195+
return_id:
196+
Whether to only return the ID or all
197+
columns.
215198
216199
Returns:
217200
A tuple of the rewritten SQL query and new params.
@@ -284,7 +267,7 @@ def _build_set_statement(self) -> Tuple[str, tuple]:
284267
# the compiler and the queries.
285268
from .sql import PostgresUpdateQuery
286269

287-
query = self.query.chain(PostgresUpdateQuery)
270+
query = cast(PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery))
288271
query.add_update_values(self.query.update_values)
289272

290273
sql, params = query.get_compiler(self.connection.alias).as_sql()

psqlextra/introspect/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .fields import inspect_model_local_concrete_fields
2+
from .models import model_from_cursor, models_from_cursor
3+
4+
__all__ = [
5+
"models_from_cursor",
6+
"model_from_cursor",
7+
"inspect_model_local_concrete_fields",
8+
]

psqlextra/introspect/fields.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import List, Type
2+
3+
from django.db.models import Field, Model
4+
5+
6+
def inspect_model_local_concrete_fields(model: Type[Model]) -> List[Field]:
7+
"""Gets a complete list of local and concrete fields on a model, these are
8+
fields that directly map to a database colmn directly on the table backing
9+
the model.
10+
11+
This is similar to Django's `Meta.local_concrete_fields`, which is a
12+
private API. This method utilizes only public APIs.
13+
"""
14+
15+
local_concrete_fields = []
16+
17+
for field in model._meta.get_fields(include_parents=False):
18+
if isinstance(field, Field) and field.column and not field.many_to_many:
19+
local_concrete_fields.append(field)
20+
21+
return local_concrete_fields

psqlextra/introspect/models.py

+170
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
from typing import (
2+
Any,
3+
Dict,
4+
Generator,
5+
Iterable,
6+
List,
7+
Optional,
8+
Type,
9+
TypeVar,
10+
cast,
11+
)
12+
13+
from django.core.exceptions import FieldDoesNotExist
14+
from django.db import connection, models
15+
from django.db.models import Field, Model
16+
from django.db.models.expressions import Expression
17+
18+
from .fields import inspect_model_local_concrete_fields
19+
20+
TModel = TypeVar("TModel", bound=models.Model)
21+
22+
23+
def _construct_model(
24+
model: Type[TModel],
25+
columns: Iterable[str],
26+
values: Iterable[Any],
27+
*,
28+
apply_converters: bool = True
29+
) -> TModel:
30+
fields_by_name_and_column = {}
31+
for field in inspect_model_local_concrete_fields(model):
32+
fields_by_name_and_column[field.attname] = field
33+
34+
if field.db_column:
35+
fields_by_name_and_column[field.db_column] = field
36+
37+
indexable_columns = list(columns)
38+
39+
row = {}
40+
41+
for index, value in enumerate(values):
42+
column = indexable_columns[index]
43+
try:
44+
field = cast(Field, model._meta.get_field(column))
45+
except FieldDoesNotExist:
46+
field = fields_by_name_and_column[column]
47+
48+
field_column_expression = field.get_col(model._meta.db_table)
49+
50+
if apply_converters:
51+
converters = cast(Expression, field).get_db_converters(
52+
connection
53+
) + connection.ops.get_db_converters(field_column_expression)
54+
55+
converted_value = value
56+
for converter in converters:
57+
converted_value = converter(
58+
converted_value,
59+
field_column_expression,
60+
connection,
61+
)
62+
else:
63+
converted_value = value
64+
65+
row[field.attname] = converted_value
66+
67+
instance = model(**row)
68+
instance._state.adding = False
69+
instance._state.db = connection.alias
70+
71+
return instance
72+
73+
74+
def models_from_cursor(
75+
model: Type[TModel], cursor, *, related_fields: List[str] = []
76+
) -> Generator[TModel, None, None]:
77+
"""Fetches all rows from a cursor and converts the values into model
78+
instances.
79+
80+
This is roughly what Django does internally when you do queries. This
81+
goes further than `Model.from_db` as it also applies converters to make
82+
sure that values are converted into their Python equivalent.
83+
84+
Use this when you've outgrown the ORM and you are writing performant
85+
queries yourself and you need to map the results back into ORM objects.
86+
87+
Arguments:
88+
model:
89+
Model to construct.
90+
91+
cursor:
92+
Cursor to read the rows from.
93+
94+
related_fields:
95+
List of ForeignKey/OneToOneField names that were joined
96+
into the raw query. Use this to achieve the same thing
97+
that Django's `.select_related()` does.
98+
99+
Field names should be specified in the order that they
100+
are SELECT'd in.
101+
"""
102+
103+
columns = [col[0] for col in cursor.description]
104+
field_offset = len(inspect_model_local_concrete_fields(model))
105+
106+
rows = cursor.fetchmany()
107+
108+
while rows:
109+
for values in rows:
110+
instance = _construct_model(
111+
model, columns[:field_offset], values[:field_offset]
112+
)
113+
114+
for index, related_field_name in enumerate(related_fields):
115+
related_model = model._meta.get_field(
116+
related_field_name
117+
).related_model
118+
if not related_model:
119+
continue
120+
121+
related_field_count = len(
122+
inspect_model_local_concrete_fields(related_model)
123+
)
124+
125+
# autopep8: off
126+
related_columns = columns[
127+
field_offset : field_offset + related_field_count # noqa
128+
]
129+
related_values = values[
130+
field_offset : field_offset + related_field_count # noqa
131+
]
132+
# autopep8: one
133+
134+
if (
135+
not related_columns
136+
or not related_values
137+
or all([value is None for value in related_values])
138+
):
139+
continue
140+
141+
related_instance = _construct_model(
142+
cast(Type[Model], related_model),
143+
related_columns,
144+
related_values,
145+
)
146+
instance._state.fields_cache[related_field_name] = related_instance # type: ignore
147+
148+
field_offset += len(
149+
inspect_model_local_concrete_fields(related_model)
150+
)
151+
152+
yield instance
153+
154+
rows = cursor.fetchmany()
155+
156+
157+
def model_from_cursor(
158+
model: Type[TModel], cursor, *, related_fields: List[str] = []
159+
) -> Optional[TModel]:
160+
return next(
161+
models_from_cursor(model, cursor, related_fields=related_fields), None
162+
)
163+
164+
165+
def model_from_dict(
166+
model: Type[TModel], row: Dict[str, Any], *, apply_converters: bool = True
167+
) -> TModel:
168+
return _construct_model(
169+
model, row.keys(), row.values(), apply_converters=apply_converters
170+
)

0 commit comments

Comments
 (0)