Skip to content

Commit b219f68

Browse files
committed
feat: add using_values to upsert_on_duplicated and build_fly_table
1 parent 50048d5 commit b219f68

File tree

4 files changed

+41
-15
lines changed

4 files changed

+41
-15
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ await AccountMgr.upsert_on_duplicated(
115115
],
116116
insert_fields=["id", "gender", "name", "locale", "extend"],
117117
upsert_fields=["name", "locale"],
118+
using_values=False,
118119
)
119120
```
120121
Generate sql and execute
@@ -123,8 +124,7 @@ Generate sql and execute
123124
(id, gender, name, locale, extend)
124125
VALUES
125126
(7, 1, '斉藤 修平', 'ja_JP', '{}'), (8, 1, 'Ojas Salvi', 'en_IN', '{}'), (9, 1, '羊淑兰', 'zh_CN', '{}')
126-
AS `new_account`
127-
ON DUPLICATE KEY UPDATE name=`new_account`.name, locale=`new_account`.locale
127+
AS `new_account` ON DUPLICATE KEY UPDATE name=`new_account`.name, locale=`new_account`.locale
128128
```
129129

130130
### **insert_into_select**
@@ -158,6 +158,7 @@ await AccountMgr.bulk_update_with_fly_table(
158158
],
159159
join_fields=["id"],
160160
update_fields=["active", "gender"],
161+
using_values=True,
161162
)
162163
```
163164
Generate sql and execute
@@ -166,7 +167,7 @@ Generate sql and execute
166167
JOIN (
167168
SELECT * FROM (
168169
VALUES
169-
ROW(7, False, 1), ROW(15, True, 0)
170+
ROW(7, False, 1), ROW(15, True, 0)
170171
) AS fly_table (id, active, gender)
171172
) tmp ON account.id=tmp.id
172173
SET account.active=tmp.active, account.gender=tmp.gender

examples/service/routers/account.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ async def bulk_upsert_view():
153153
dicts,
154154
insert_fields=["id", "gender", "name", "locale", "extend"],
155155
upsert_fields=["name", "locale"],
156+
using_values=False,
156157
)
157158
return {"row_cnt": row_cnt}
158159

@@ -187,5 +188,6 @@ async def bulk_update_view(
187188
[d.dict() for d in dicts],
188189
join_fields=["id"],
189190
update_fields=["active", "gender"],
191+
using_values=True,
190192
)
191193
return {"row_cnt": row_cnt}

fastapi_esql/orm/base_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ async def upsert_on_duplicated(
110110
dicts: List[Dict[str, Any]],
111111
insert_fields: List[str],
112112
upsert_fields: List[str],
113+
using_values: bool = False,
113114
):
114115
sql = SQLizer.upsert_on_duplicated(
115116
cls.table,
116117
dicts,
117118
insert_fields,
118119
upsert_fields,
120+
using_values,
119121
)
120122
return await CursorHandler.sum_row_cnt(sql, cls.rw_conn, logger)
121123

@@ -143,11 +145,13 @@ async def bulk_update_with_fly_table(
143145
dicts: List[Dict[str, Any]],
144146
join_fields: List[str],
145147
update_fields: List[str],
148+
using_values: bool = True,
146149
):
147150
sql = SQLizer.bulk_update_with_fly_table(
148151
cls.table,
149152
dicts,
150153
join_fields,
151154
update_fields,
155+
using_values,
152156
)
153157
return await CursorHandler.sum_row_cnt(sql, cls.rw_conn, logger)

fastapi_esql/utils/sqlizer.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def upsert_on_duplicated(
164164
dicts: List[Dict[str, Any]],
165165
insert_fields: List[str],
166166
upsert_fields: List[str],
167+
using_values: bool = False,
167168
) -> Optional[str]:
168169
if not all([table, dicts, insert_fields, upsert_fields]):
169170
raise WrongParamsError("Please check your params")
@@ -172,16 +173,23 @@ def upsert_on_duplicated(
172173
f"({', '.join(cls._sqlize_value(d.get(f)) for f in insert_fields)})"
173174
for d in dicts
174175
]
175-
new_table = f"`new_{table}`"
176-
upserts = [f"{field}={new_table}.{field}" for field in upsert_fields]
176+
# NOTE Beginning with MySQL 8.0.19, it is possible to use an alias for the row
177+
# https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html
178+
if using_values:
179+
upserts = [f"{field}=VALUES({field})" for field in upsert_fields]
180+
on_duplicated = f"ON DUPLICATE KEY UPDATE {', '.join(upserts)}"
181+
else:
182+
new_table = f"`new_{table}`"
183+
upserts = [f"{field}={new_table}.{field}" for field in upsert_fields]
184+
on_duplicated = f"AS {new_table} ON DUPLICATE KEY UPDATE {', '.join(upserts)}"
177185

178186
sql = f"""
179187
INSERT INTO {table}
180188
({", ".join(insert_fields)})
181189
VALUES
182190
{", ".join(values)}
183-
AS {new_table}
184-
ON DUPLICATE KEY UPDATE {", ".join(upserts)}"""
191+
{on_duplicated}
192+
"""
185193
logger.debug(sql)
186194
return sql
187195

@@ -218,20 +226,30 @@ def build_fly_table(
218226
cls,
219227
dicts: List[Dict[str, Any]],
220228
fields: List[str],
229+
using_values: bool = True,
221230
) -> Optional[str]:
222231
if not all([dicts, fields]):
223232
raise WrongParamsError("Please check your params")
224233

225-
rows = [
226-
f"ROW({', '.join(cls._sqlize_value(d.get(f)) for f in fields)})"
227-
for d in dicts
228-
]
234+
if using_values:
235+
rows = [
236+
f"ROW({', '.join(cls._sqlize_value(d.get(f)) for f in fields)})"
237+
for d in dicts
238+
]
239+
values = "VALUES\n " + ", ".join(rows)
240+
table = f"fly_table ({', '.join(fields)})"
241+
else:
242+
rows = [
243+
f"SELECT {', '.join(f'{cls._sqlize_value(d.get(f))} {f}' for f in fields)}"
244+
for d in dicts
245+
]
246+
values = " UNION ".join(rows)
247+
table = "fly_table"
229248

230249
sql = f"""
231250
SELECT * FROM (
232-
VALUES
233-
{", ".join(rows)}
234-
) AS fly_table ({", ".join(fields)})"""
251+
{values}
252+
) AS {table}"""
235253
logger.debug(sql)
236254
return sql
237255

@@ -242,6 +260,7 @@ def bulk_update_with_fly_table(
242260
dicts: List[Dict[str, Any]],
243261
join_fields: List[str],
244262
update_fields: List[str],
263+
using_values: bool = True,
245264
) -> Optional[str]:
246265
if not all([table, dicts, join_fields, update_fields]):
247266
raise WrongParamsError("Please check your params")
@@ -251,7 +270,7 @@ def bulk_update_with_fly_table(
251270

252271
sql = f"""
253272
UPDATE {table}
254-
JOIN ({SQLizer.build_fly_table(dicts, join_fields + update_fields)}
273+
JOIN ({SQLizer.build_fly_table(dicts, join_fields + update_fields, using_values)}
255274
) tmp ON {", ".join(joins)}
256275
SET {", ".join(updates)}"""
257276
logger.debug(sql)

0 commit comments

Comments
 (0)