Skip to content

Commit 5fcb63f

Browse files
committed
Add ability to set update values, fixes #56
1 parent cbf93a3 commit 5fcb63f

File tree

6 files changed

+201
-28
lines changed

6 files changed

+201
-28
lines changed

docs/source/conflict_handling.rst

+36
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,42 @@ Alternatively, with Django 3.1 or newer, :class:`~django:django.db.models.Q` obj
232232
Q(name__gt=ExcludedCol('priority'))
233233
234234
235+
Update values
236+
"""""""""""""
237+
238+
Optionally, the fields to update can be overriden. The default is to update the same fields that were specified in the rows to insert.
239+
240+
Refer to the insert values using the :class:`psqlextra.expressions.ExcludedCol` expression which translates to PostgreSQL's ``EXCLUDED.<column>`` expression. All expressions and features that can be used with Django's :meth:`~django:django.db.models.query.QuerySet.update` can be used here.
241+
242+
.. warning::
243+
244+
Specifying an empty ``update_values`` (``{}``) will transform the query into :attr:`~psqlextra.types.ConflictAction.NOTHING`. Only ``None`` makes the default behaviour kick in of updating all fields that were specified.
245+
246+
.. code-block:: python
247+
248+
from django.db.models import F
249+
250+
from psqlextra.expressions import ExcludedCol
251+
252+
(
253+
MyModel
254+
.objects
255+
.on_conflict(
256+
['name'],
257+
ConflictAction.UPDATE,
258+
update_values=dict(
259+
name=ExcludedCol('name'),
260+
count=F('count') + 1,
261+
),
262+
)
263+
.insert(
264+
name='henk',
265+
count=0,
266+
)
267+
)
268+
269+
270+
235271
ConflictAction.NOTHING
236272
**********************
237273

psqlextra/compiler.py

+31-11
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ def as_sql(self, *args, **kwargs):
104104
return append_caller_to_sql(sql), params
105105

106106
def _prepare_query_values(self):
107-
"""Extra prep on query values by converting dictionaries into.
108-
107+
"""Extra prep on query values by converting dictionaries into
109108
:see:HStoreValue expressions.
110109
111110
This allows putting expressions in a dictionary. The
@@ -234,13 +233,6 @@ def _rewrite_insert_on_conflict(
234233
"""Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
235234
clause."""
236235

237-
update_columns = ", ".join(
238-
[
239-
"{0} = EXCLUDED.{0}".format(self.qn(field.column))
240-
for field in self.query.update_fields # type: ignore[attr-defined]
241-
]
242-
)
243-
244236
# build the conflict target, the columns to watch
245237
# for conflicts
246238
on_conflict_clause = self._build_on_conflict_clause()
@@ -254,10 +246,21 @@ def _rewrite_insert_on_conflict(
254246
rewritten_sql += f" WHERE {expr_sql}"
255247
params += tuple(expr_params)
256248

249+
# Fallback in case the user didn't specify any update values. We can still
250+
# make the query work if we switch to ConflictAction.NOTHING
251+
if (
252+
conflict_action == ConflictAction.UPDATE.value
253+
and not self.query.update_values
254+
):
255+
conflict_action = ConflictAction.NOTHING
256+
257257
rewritten_sql += f" DO {conflict_action}"
258258

259-
if conflict_action == "UPDATE":
260-
rewritten_sql += f" SET {update_columns}"
259+
if conflict_action == ConflictAction.UPDATE.value:
260+
set_sql, sql_params = self._build_set_statement()
261+
262+
rewritten_sql += f" SET {set_sql}"
263+
params += sql_params
261264

262265
if update_condition:
263266
expr_sql, expr_params = self._compile_expression(
@@ -270,6 +273,23 @@ def _rewrite_insert_on_conflict(
270273

271274
return (rewritten_sql, params)
272275

276+
def _build_set_statement(self) -> Tuple[str, tuple]:
277+
"""Builds the SET statement for the ON CONFLICT DO UPDATE clause.
278+
279+
This uses the update compiler to provide full compatibility with
280+
the standard Django's `update(...)`.
281+
"""
282+
283+
# Local import to work around the circular dependency between
284+
# the compiler and the queries.
285+
from .sql import PostgresUpdateQuery
286+
287+
query = self.query.chain(PostgresUpdateQuery)
288+
query.add_update_values(self.query.update_values)
289+
290+
sql, params = query.get_compiler(self.connection.alias).as_sql()
291+
return sql.split("SET")[1].split(" WHERE")[0], tuple(params)
292+
273293
def _build_on_conflict_clause(self):
274294
if django.VERSION >= (2, 2):
275295
from django.db.models.constraints import BaseConstraint

psqlextra/query.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from itertools import chain
33
from typing import (
44
TYPE_CHECKING,
5+
Any,
56
Dict,
67
Generic,
78
Iterable,
@@ -17,6 +18,7 @@
1718
from django.db.models import Expression, Q, QuerySet
1819
from django.db.models.fields import NOT_PROVIDED
1920

21+
from .expressions import ExcludedCol
2022
from .sql import PostgresInsertQuery, PostgresQuery
2123
from .types import ConflictAction
2224

@@ -51,6 +53,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
5153
self.conflict_action = None
5254
self.conflict_update_condition = None
5355
self.index_predicate = None
56+
self.update_values = None
5457

5558
def annotate(self, **annotations) -> "Self": # type: ignore[valid-type, override]
5659
"""Custom version of the standard annotate function that allows using
@@ -108,6 +111,7 @@ def on_conflict(
108111
action: ConflictAction,
109112
index_predicate: Optional[Union[Expression, Q, str]] = None,
110113
update_condition: Optional[Union[Expression, Q, str]] = None,
114+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
111115
):
112116
"""Sets the action to take when conflicts arise when attempting to
113117
insert/create a new row.
@@ -125,12 +129,18 @@ def on_conflict(
125129
126130
update_condition:
127131
Only update if this SQL expression evaluates to true.
132+
133+
update_values:
134+
Optionally, values/expressions to use when rows
135+
conflict. If not specified, all columns specified
136+
in the rows are updated with the values you specified.
128137
"""
129138

130139
self.conflict_target = fields
131140
self.conflict_action = action
132141
self.conflict_update_condition = update_condition
133142
self.index_predicate = index_predicate
143+
self.update_values = update_values
134144

135145
return self
136146

@@ -293,6 +303,7 @@ def upsert(
293303
index_predicate: Optional[Union[Expression, Q, str]] = None,
294304
using: Optional[str] = None,
295305
update_condition: Optional[Union[Expression, Q, str]] = None,
306+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
296307
) -> int:
297308
"""Creates a new record or updates the existing one with the specified
298309
data.
@@ -315,6 +326,11 @@ def upsert(
315326
update_condition:
316327
Only update if this SQL expression evaluates to true.
317328
329+
update_values:
330+
Optionally, values/expressions to use when rows
331+
conflict. If not specified, all columns specified
332+
in the rows are updated with the values you specified.
333+
318334
Returns:
319335
The primary key of the row that was created/updated.
320336
"""
@@ -324,6 +340,7 @@ def upsert(
324340
ConflictAction.UPDATE,
325341
index_predicate=index_predicate,
326342
update_condition=update_condition,
343+
update_values=update_values,
327344
)
328345

329346
kwargs = {**fields, "using": using}
@@ -336,6 +353,7 @@ def upsert_and_get(
336353
index_predicate: Optional[Union[Expression, Q, str]] = None,
337354
using: Optional[str] = None,
338355
update_condition: Optional[Union[Expression, Q, str]] = None,
356+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
339357
):
340358
"""Creates a new record or updates the existing one with the specified
341359
data and then gets the row.
@@ -358,6 +376,11 @@ def upsert_and_get(
358376
update_condition:
359377
Only update if this SQL expression evaluates to true.
360378
379+
update_values:
380+
Optionally, values/expressions to use when rows
381+
conflict. If not specified, all columns specified
382+
in the rows are updated with the values you specified.
383+
361384
Returns:
362385
The model instance representing the row
363386
that was created/updated.
@@ -368,6 +391,7 @@ def upsert_and_get(
368391
ConflictAction.UPDATE,
369392
index_predicate=index_predicate,
370393
update_condition=update_condition,
394+
update_values=update_values,
371395
)
372396

373397
kwargs = {**fields, "using": using}
@@ -381,6 +405,7 @@ def bulk_upsert(
381405
return_model: bool = False,
382406
using: Optional[str] = None,
383407
update_condition: Optional[Union[Expression, Q, str]] = None,
408+
update_values: Optional[Dict[str, Union[Any, Expression]]] = None,
384409
):
385410
"""Creates a set of new records or updates the existing ones with the
386411
specified data.
@@ -407,6 +432,11 @@ def bulk_upsert(
407432
update_condition:
408433
Only update if this SQL expression evaluates to true.
409434
435+
update_values:
436+
Optionally, values/expressions to use when rows
437+
conflict. If not specified, all columns specified
438+
in the rows are updated with the values you specified.
439+
410440
Returns:
411441
A list of either the dicts of the rows upserted, including the pk or
412442
the models of the rows upserted
@@ -417,7 +447,9 @@ def bulk_upsert(
417447
ConflictAction.UPDATE,
418448
index_predicate=index_predicate,
419449
update_condition=update_condition,
450+
update_values=update_values,
420451
)
452+
421453
return self.bulk_insert(rows, return_model, using=using)
422454

423455
def _create_model_instance(
@@ -505,15 +537,19 @@ def _build_insert_compiler(
505537
)
506538

507539
# get the fields to be used during update/insert
508-
insert_fields, update_fields = self._get_upsert_fields(first_row)
540+
insert_fields, update_values = self._get_upsert_fields(first_row)
541+
542+
# allow the user to override what should happen on update
543+
if self.update_values is not None:
544+
update_values = self.update_values
509545

510546
# build a normal insert query
511547
query = PostgresInsertQuery(self.model)
512548
query.conflict_action = self.conflict_action
513549
query.conflict_target = self.conflict_target
514550
query.conflict_update_condition = self.conflict_update_condition
515551
query.index_predicate = self.index_predicate
516-
query.values(objs, insert_fields, update_fields)
552+
query.insert_on_conflict_values(objs, insert_fields, update_values)
517553

518554
compiler = query.get_compiler(using)
519555
return compiler
@@ -578,13 +614,13 @@ def _get_upsert_fields(self, kwargs):
578614

579615
model_instance = self.model(**kwargs)
580616
insert_fields = []
581-
update_fields = []
617+
update_values = {}
582618

583619
for field in model_instance._meta.local_concrete_fields:
584620
has_default = field.default != NOT_PROVIDED
585621
if field.name in kwargs or field.column in kwargs:
586622
insert_fields.append(field)
587-
update_fields.append(field)
623+
update_values[field.name] = ExcludedCol(field.column)
588624
continue
589625
elif has_default:
590626
insert_fields.append(field)
@@ -595,13 +631,13 @@ def _get_upsert_fields(self, kwargs):
595631
# instead of a concrete field, we have to handle that
596632
if field.primary_key is True and "pk" in kwargs:
597633
insert_fields.append(field)
598-
update_fields.append(field)
634+
update_values[field.name] = ExcludedCol(field.column)
599635
continue
600636

601637
if self._is_magical_field(model_instance, field, is_insert=True):
602638
insert_fields.append(field)
603639

604640
if self._is_magical_field(model_instance, field, is_insert=False):
605-
update_fields.append(field)
641+
update_values[field.name] = ExcludedCol(field.column)
606642

607-
return insert_fields, update_fields
643+
return insert_fields, update_values

psqlextra/sql.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from typing import Optional, Tuple
2+
from typing import Any, Dict, List, Optional, Tuple, Union
33

44
import django
55

@@ -154,10 +154,14 @@ def __init__(self, *args, **kwargs):
154154
self.conflict_action = ConflictAction.UPDATE
155155
self.conflict_update_condition = None
156156
self.index_predicate = None
157-
158-
self.update_fields = []
159-
160-
def values(self, objs, insert_fields, update_fields=[]):
157+
self.update_values = {}
158+
159+
def insert_on_conflict_values(
160+
self,
161+
objs: List,
162+
insert_fields: List,
163+
update_values: Dict[str, Union[Any, Expression]] = {},
164+
):
161165
"""Sets the values to be used in this query.
162166
163167
Insert fields are fields that are definitely
@@ -176,12 +180,13 @@ def values(self, objs, insert_fields, update_fields=[]):
176180
insert_fields:
177181
The fields to use in the INSERT statement
178182
179-
update_fields:
180-
The fields to only use in the UPDATE statement.
183+
update_values:
184+
Expressions/values to use when a conflict
185+
occurs and an UPDATE is performed.
181186
"""
182187

183188
self.insert_values(insert_fields, objs, raw=False)
184-
self.update_fields = update_fields
189+
self.update_values = update_values
185190

186191
def get_compiler(self, using=None, connection=None):
187192
if using:

psqlextra/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class ConflictAction(Enum):
2828
def all(cls) -> List["ConflictAction"]:
2929
return [choice for choice in cls]
3030

31+
def __str__(self) -> str:
32+
return self.value
33+
3134

3235
class PostgresPartitioningMethod(StrEnum):
3336
"""Methods of partitioning supported by PostgreSQL 11.x native support for

0 commit comments

Comments
 (0)