
Commit b0f1e52

Merge pull request #55 from piercefreeman/feature/batch-upsert
Add batch upsert support
2 parents 045e924 + 22adf91 commit b0f1e52

File tree

2 files changed: +134 −23

  iceaxe/__tests__/test_session.py
  iceaxe/session.py


iceaxe/__tests__/test_session.py

+108
@@ -1275,3 +1275,111 @@ async def test_batch_update_exceeds_parameters():
     assert "SET" in first_call.args[0]
     assert "WHERE" in first_call.args[0]
     assert '"id"' in first_call.args[0]
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_exceeds_parameters():
+    """
+    Test that upsert() correctly batches operations when we exceed Postgres parameter limits.
+    We'll create enough objects with enough fields that a single query would exceed PG_MAX_PARAMETERS.
+    """
+    assert assert_expected_user_fields(UserDemo)
+
+    # Calculate how many objects we need to exceed the parameter limit
+    # Each object has 2 fields (name, email) in UserDemo
+    # So each object uses 2 parameters
+    objects_needed = (PG_MAX_PARAMETERS // 2) + 1
+    users = [
+        UserDemo(name=f"User {i}", email=f"user{i}@example.com")
+        for i in range(objects_needed)
+    ]
+
+    # Mock the connection with dynamic results based on input
+    mock_conn = AsyncMock()
+    mock_conn.fetchmany = AsyncMock(
+        side_effect=lambda query, values_list: [
+            {"id": i, "name": f"User {i}", "email": f"user{i}@example.com"}
+            for i in range(len(values_list))
+        ]
+    )
+    mock_conn.executemany = AsyncMock()
+    mock_conn.transaction = mock_transaction
+
+    db = DBConnection(mock_conn)
+
+    # Upsert the objects with all possible kwargs
+    result = await db.upsert(
+        users,
+        conflict_fields=(UserDemo.email,),
+        update_fields=(UserDemo.name,),
+        returning_fields=(UserDemo.id, UserDemo.name, UserDemo.email),
+    )
+
+    # We should have made at least 2 calls to fetchmany since we exceeded the parameter limit
+    assert len(mock_conn.fetchmany.mock_calls) >= 2
+
+    # Verify the structure of the first call
+    first_call = mock_conn.fetchmany.mock_calls[0]
+    assert "INSERT INTO" in first_call.args[0]
+    assert "ON CONFLICT" in first_call.args[0]
+    assert "DO UPDATE SET" in first_call.args[0]
+    assert "RETURNING" in first_call.args[0]
+
+    # Verify we got back the expected number of results
+    assert result is not None
+    assert len(result) == objects_needed
+    assert all(len(r) == 3 for r in result)  # Each result should have id, name, email
+
+
+@pytest.mark.asyncio
+async def test_batch_upsert_multiple_with_real_db(db_connection: DBConnection):
+    """
+    Integration test for upserting multiple objects at once with a real database connection.
+    Tests both insert and update scenarios in the same batch.
+    """
+    await db_connection.conn.execute(
+        """
+        ALTER TABLE userdemo
+        ADD CONSTRAINT email_unique UNIQUE (email)
+        """
+    )
+
+    # Create initial set of users
+    initial_users = [
+        UserDemo(name="User 1", email="[email protected]"),
+        UserDemo(name="User 2", email="[email protected]"),
+    ]
+    await db_connection.insert(initial_users)
+
+    # Create a mix of new and existing users for upsert
+    users_to_upsert = [
+        # These should update
+        UserDemo(name="Updated User 1", email="[email protected]"),
+        UserDemo(name="Updated User 2", email="[email protected]"),
+        # These should insert
+        UserDemo(name="User 3", email="[email protected]"),
+        UserDemo(name="User 4", email="[email protected]"),
+    ]
+
+    result = await db_connection.upsert(
+        users_to_upsert,
+        conflict_fields=(UserDemo.email,),
+        update_fields=(UserDemo.name,),
+        returning_fields=(UserDemo.name, UserDemo.email),
+    )
+
+    # Verify we got all results back
+    assert result is not None
+    assert len(result) == 4
+
+    # Verify the database state
+    db_result = await db_connection.conn.fetch("SELECT * FROM userdemo ORDER BY email")
+    assert len(db_result) == 4
+
+    # Check that updates worked
+    assert db_result[0]["name"] == "Updated User 1"
+    assert db_result[1]["name"] == "Updated User 2"
+
+    # Check that inserts worked
+    assert db_result[2]["name"] == "User 3"
+    assert db_result[3]["name"] == "User 4"
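Note on the parameter arithmetic behind the first test: objects_needed = (PG_MAX_PARAMETERS // 2) + 1 is exactly one object more than fits in a single statement when each row binds two values. The sketch below (not part of this commit) works through that math; it assumes PG_MAX_PARAMETERS is the usual Postgres wire-protocol cap of 32767 bind parameters per statement, whereas iceaxe's own constant may be set lower.

# Standalone sketch of the batching arithmetic exercised by
# test_batch_upsert_exceeds_parameters. PG_MAX_PARAMETERS is assumed here to be
# the Postgres limit of 32767 bind parameters per statement; the constant
# iceaxe ships may differ.
PG_MAX_PARAMETERS = 32767

params_per_object = 2  # UserDemo binds two values per row: name and email
max_objects_per_batch = PG_MAX_PARAMETERS // params_per_object  # 16383
objects_needed = max_objects_per_batch + 1  # 16384 rows overflow one statement

# Chunk the object indices the way a batching helper would chunk model objects.
batches = [
    list(range(objects_needed))[start : start + max_objects_per_batch]
    for start in range(0, objects_needed, max_objects_per_batch)
]

assert len(batches) == 2         # hence fetchmany must be called at least twice
assert len(batches[0]) == 16383  # first batch is full
assert len(batches[1]) == 1      # second batch carries only the overflow object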

iceaxe/session.py

+26 −23
@@ -416,32 +416,35 @@ async def upsert(
         )
         query += f" RETURNING {returning_string}"
 
-        # Execute for each object
-        for obj in model_objects:
-            obj_values = obj.model_dump()
-            values = [
-                info.to_db_value(obj_values[field])
-                for field, info in fields.items()
-            ]
-
+        # Execute in batches
+        for batch_objects, values_list in self._batch_objects_and_values(
+            model_objects, list(fields.keys()), fields
+        ):
             if returning_fields_cols:
-                result = await self.conn.fetchrow(query, *values)
-                if result:
-                    # Process returned values, deserializing JSON if needed
-                    processed_values = []
-                    for field in returning_fields_cols:
-                        value = result[field.key]
-                        if (
-                            value is not None
-                            and field.root_model.model_fields[field.key].is_json
-                        ):
-                            value = json_loads(value)
-                        processed_values.append(value)
-                    results.append(tuple(processed_values))
+                # For returning queries, we need to use fetchmany to get all results
+                rows = await self.conn.fetchmany(query, values_list)
+                for row in rows:
+                    if row:
+                        # Process returned values, deserializing JSON if needed
+                        processed_values = []
+                        for field in returning_fields_cols:
+                            value = row[field.key]
+                            if (
+                                value is not None
+                                and field.root_model.model_fields[
+                                    field.key
+                                ].is_json
+                            ):
+                                value = json_loads(value)
+                            processed_values.append(value)
+                        results.append(tuple(processed_values))
             else:
-                await self.conn.execute(query, *values)
+                # For non-returning queries, we can use executemany
+                await self.conn.executemany(query, values_list)
 
-            obj.clear_modified_attributes()
+            # Clear modified state for successfully upserted objects
+            for obj in batch_objects:
+                obj.clear_modified_attributes()
 
         self.modification_tracker.clear_status(objects)
 
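The new loop leans on self._batch_objects_and_values, a helper whose implementation is not shown in this diff. The following is only a sketch of what such a helper could look like, written as a free function; its signature and yield shape are inferred from the call site above, and the PG_MAX_PARAMETERS value is an assumption.

# Minimal sketch (not the iceaxe implementation) of a batching helper compatible
# with the call site above: it yields (batch_objects, values_list) pairs whose
# total bound-parameter count stays under PG_MAX_PARAMETERS.
from typing import Any, Iterator

PG_MAX_PARAMETERS = 32767  # assumed per-statement bind-parameter cap


def batch_objects_and_values(
    model_objects: list[Any],
    field_names: list[str],
    fields: dict[str, Any],
) -> Iterator[tuple[list[Any], list[list[Any]]]]:
    params_per_object = max(len(field_names), 1)
    batch_size = max(PG_MAX_PARAMETERS // params_per_object, 1)
    for start in range(0, len(model_objects), batch_size):
        batch = model_objects[start : start + batch_size]
        values_list = []
        for obj in batch:
            # Mirror the removed per-object path: dump the model once, then
            # convert each field to its database representation.
            obj_values = obj.model_dump()
            values_list.append(
                [fields[name].to_db_value(obj_values[name]) for name in field_names]
            )
        yield batch, values_list

Each yielded values_list maps onto a single round trip: fetchmany for queries with a RETURNING clause, executemany otherwise.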
