2
2
from itertools import chain
3
3
from typing import (
4
4
TYPE_CHECKING ,
5
+ Any ,
5
6
Dict ,
6
7
Generic ,
7
8
Iterable ,
17
18
from django .db .models import Expression , Q , QuerySet
18
19
from django .db .models .fields import NOT_PROVIDED
19
20
21
+ from .expressions import ExcludedCol
20
22
from .sql import PostgresInsertQuery , PostgresQuery
21
23
from .types import ConflictAction
22
24
@@ -51,6 +53,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
51
53
self .conflict_action = None
52
54
self .conflict_update_condition = None
53
55
self .index_predicate = None
56
+ self .update_values = None
54
57
55
58
def annotate (self , ** annotations ) -> "Self" : # type: ignore[valid-type, override]
56
59
"""Custom version of the standard annotate function that allows using
@@ -108,6 +111,7 @@ def on_conflict(
108
111
action : ConflictAction ,
109
112
index_predicate : Optional [Union [Expression , Q , str ]] = None ,
110
113
update_condition : Optional [Union [Expression , Q , str ]] = None ,
114
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
111
115
):
112
116
"""Sets the action to take when conflicts arise when attempting to
113
117
insert/create a new row.
@@ -125,12 +129,18 @@ def on_conflict(
125
129
126
130
update_condition:
127
131
Only update if this SQL expression evaluates to true.
132
+
133
+ update_values:
134
+ Optionally, values/expressions to use when rows
135
+ conflict. If not specified, all columns specified
136
+ in the rows are updated with the values you specified.
128
137
"""
129
138
130
139
self .conflict_target = fields
131
140
self .conflict_action = action
132
141
self .conflict_update_condition = update_condition
133
142
self .index_predicate = index_predicate
143
+ self .update_values = update_values
134
144
135
145
return self
136
146
@@ -293,6 +303,7 @@ def upsert(
293
303
index_predicate : Optional [Union [Expression , Q , str ]] = None ,
294
304
using : Optional [str ] = None ,
295
305
update_condition : Optional [Union [Expression , Q , str ]] = None ,
306
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
296
307
) -> int :
297
308
"""Creates a new record or updates the existing one with the specified
298
309
data.
@@ -315,6 +326,11 @@ def upsert(
315
326
update_condition:
316
327
Only update if this SQL expression evaluates to true.
317
328
329
+ update_values:
330
+ Optionally, values/expressions to use when rows
331
+ conflict. If not specified, all columns specified
332
+ in the rows are updated with the values you specified.
333
+
318
334
Returns:
319
335
The primary key of the row that was created/updated.
320
336
"""
@@ -324,6 +340,7 @@ def upsert(
324
340
ConflictAction .UPDATE ,
325
341
index_predicate = index_predicate ,
326
342
update_condition = update_condition ,
343
+ update_values = update_values ,
327
344
)
328
345
329
346
kwargs = {** fields , "using" : using }
@@ -336,6 +353,7 @@ def upsert_and_get(
336
353
index_predicate : Optional [Union [Expression , Q , str ]] = None ,
337
354
using : Optional [str ] = None ,
338
355
update_condition : Optional [Union [Expression , Q , str ]] = None ,
356
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
339
357
):
340
358
"""Creates a new record or updates the existing one with the specified
341
359
data and then gets the row.
@@ -358,6 +376,11 @@ def upsert_and_get(
358
376
update_condition:
359
377
Only update if this SQL expression evaluates to true.
360
378
379
+ update_values:
380
+ Optionally, values/expressions to use when rows
381
+ conflict. If not specified, all columns specified
382
+ in the rows are updated with the values you specified.
383
+
361
384
Returns:
362
385
The model instance representing the row
363
386
that was created/updated.
@@ -368,6 +391,7 @@ def upsert_and_get(
368
391
ConflictAction .UPDATE ,
369
392
index_predicate = index_predicate ,
370
393
update_condition = update_condition ,
394
+ update_values = update_values ,
371
395
)
372
396
373
397
kwargs = {** fields , "using" : using }
@@ -381,6 +405,7 @@ def bulk_upsert(
381
405
return_model : bool = False ,
382
406
using : Optional [str ] = None ,
383
407
update_condition : Optional [Union [Expression , Q , str ]] = None ,
408
+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
384
409
):
385
410
"""Creates a set of new records or updates the existing ones with the
386
411
specified data.
@@ -407,6 +432,11 @@ def bulk_upsert(
407
432
update_condition:
408
433
Only update if this SQL expression evaluates to true.
409
434
435
+ update_values:
436
+ Optionally, values/expressions to use when rows
437
+ conflict. If not specified, all columns specified
438
+ in the rows are updated with the values you specified.
439
+
410
440
Returns:
411
441
A list of either the dicts of the rows upserted, including the pk or
412
442
the models of the rows upserted
@@ -417,7 +447,9 @@ def bulk_upsert(
417
447
ConflictAction .UPDATE ,
418
448
index_predicate = index_predicate ,
419
449
update_condition = update_condition ,
450
+ update_values = update_values ,
420
451
)
452
+
421
453
return self .bulk_insert (rows , return_model , using = using )
422
454
423
455
def _create_model_instance (
@@ -505,15 +537,19 @@ def _build_insert_compiler(
505
537
)
506
538
507
539
# get the fields to be used during update/insert
508
- insert_fields , update_fields = self ._get_upsert_fields (first_row )
540
+ insert_fields , update_values = self ._get_upsert_fields (first_row )
541
+
542
+ # allow the user to override what should happen on update
543
+ if self .update_values is not None :
544
+ update_values = self .update_values
509
545
510
546
# build a normal insert query
511
547
query = PostgresInsertQuery (self .model )
512
548
query .conflict_action = self .conflict_action
513
549
query .conflict_target = self .conflict_target
514
550
query .conflict_update_condition = self .conflict_update_condition
515
551
query .index_predicate = self .index_predicate
516
- query .values (objs , insert_fields , update_fields )
552
+ query .insert_on_conflict_values (objs , insert_fields , update_values )
517
553
518
554
compiler = query .get_compiler (using )
519
555
return compiler
@@ -578,13 +614,13 @@ def _get_upsert_fields(self, kwargs):
578
614
579
615
model_instance = self .model (** kwargs )
580
616
insert_fields = []
581
- update_fields = []
617
+ update_values = {}
582
618
583
619
for field in model_instance ._meta .local_concrete_fields :
584
620
has_default = field .default != NOT_PROVIDED
585
621
if field .name in kwargs or field .column in kwargs :
586
622
insert_fields .append (field )
587
- update_fields . append (field )
623
+ update_values [ field . name ] = ExcludedCol (field . column )
588
624
continue
589
625
elif has_default :
590
626
insert_fields .append (field )
@@ -595,13 +631,13 @@ def _get_upsert_fields(self, kwargs):
595
631
# instead of a concrete field, we have to handle that
596
632
if field .primary_key is True and "pk" in kwargs :
597
633
insert_fields .append (field )
598
- update_fields . append (field )
634
+ update_values [ field . name ] = ExcludedCol (field . column )
599
635
continue
600
636
601
637
if self ._is_magical_field (model_instance , field , is_insert = True ):
602
638
insert_fields .append (field )
603
639
604
640
if self ._is_magical_field (model_instance , field , is_insert = False ):
605
- update_fields . append (field )
641
+ update_values [ field . name ] = ExcludedCol (field . column )
606
642
607
- return insert_fields , update_fields
643
+ return insert_fields , update_values
0 commit comments