Skip to content

Commit a823e52

Browse files
committed
feat: add param merge_fields to bulk_update_from_dicts
1 parent 2991b03 commit a823e52

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

examples/service/routers/account.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ class UpdateIn(BaseModel):
192192
id: int = Field(...)
193193
active: bool = Field(...)
194194
gender: GenderEnum = Field(...)
195+
extend: dict
195196

196197
@router.post("/bulk_update")
197198
async def bulk_update_view(
@@ -201,6 +202,7 @@ async def bulk_update_view(
201202
[d.dict() for d in dicts],
202203
join_fields=["id"],
203204
update_fields=["active", "gender"],
205+
merge_fields=["extend"],
204206
using_values=True,
205207
)
206208
return {"row_cnt": row_cnt}

fastapi_esql/orm/base_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,15 @@ async def bulk_update_from_dicts(
185185
dicts: List[Dict[str, Any]],
186186
join_fields: List[str],
187187
update_fields: List[str],
188+
merge_fields: Optional[List[str]] = None,
188189
using_values: bool = True,
189190
):
190191
sql = SQLizer.bulk_update_from_dicts(
191192
cls.table,
192193
dicts,
193194
join_fields,
194195
update_fields,
196+
merge_fields,
195197
using_values,
196198
)
197199
return await CursorHandler.sum_row_cnt(sql, cls.rw_conn, logger)

fastapi_esql/utils/sqlizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def build_fly_table(
268268
dicts: List[Dict[str, Any]],
269269
fields: List[str],
270270
using_values: bool = True,
271+
log_sql: bool = True,
271272
) -> Optional[str]:
272273
if not all([dicts, fields]):
273274
raise WrongParamsError("Parameters `dicts`, `fields` are required")
@@ -291,7 +292,8 @@ def build_fly_table(
291292
SELECT * FROM (
292293
{values}
293294
) AS {table}"""
294-
logger.debug(sql)
295+
if log_sql:
296+
logger.debug(sql)
295297
return sql
296298

297299
@classmethod
@@ -301,17 +303,22 @@ def bulk_update_from_dicts(
301303
dicts: List[Dict[str, Any]],
302304
join_fields: List[str],
303305
update_fields: List[str],
306+
merge_fields: Optional[List[str]] = None,
304307
using_values: bool = True,
305308
) -> Optional[str]:
306309
if not all([table, dicts, join_fields, update_fields]):
307310
raise WrongParamsError("Parameters `table`, `dicts`, `join_fields`, `update_fields` are required")
308311

309312
joins = [f"{table}.{jf}=tmp.{jf}" for jf in join_fields]
310313
updates = [f"{table}.{uf}=tmp.{uf}" for uf in update_fields]
314+
merge_fields = merge_fields or []
315+
for mf in merge_fields:
316+
dict_obj = f"COALESCE({table}.{mf}, '{{}}')"
317+
updates.append(f"{table}.{mf}=JSON_MERGE_PATCH({dict_obj}, tmp.{mf})")
311318

312319
sql = f"""
313320
UPDATE {table}
314-
JOIN ({SQLizer.build_fly_table(dicts, join_fields + update_fields, using_values)}
321+
JOIN ({SQLizer.build_fly_table(dicts, join_fields + update_fields + merge_fields, using_values, log_sql=False)}
315322
) tmp ON {", ".join(joins)}
316323
SET {", ".join(updates)}
317324
"""

0 commit comments

Comments
 (0)