Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ jobs:
"cockroach:latest-v24.1",
"cockroach:latest-v24.2",
"cockroach:latest-v24.3",
"cockroach:latest-v25.1"
"cockroach:latest-v25.1",
"cockroach:latest-v25.2"
]
db-alias: [
"psycopg2",
Expand Down
22 changes: 22 additions & 0 deletions cockroach_helper.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash
COCKROACHDB=cockroach-v23.1.13.linux-amd64
CROACHDB=~/.cache/$COCKROACHDB/cockroach

quit_cockroachdb() {
OLDPIDNS=$(ps -o pidns -C cockroach | awk 'NR==2 {print $0}')
if [ -n "$OLDPIDNS" ]; then
pkill --ns $$ $OLDPIDNS
fi
return 0
}

[ -n "$HOST" ] || HOST=localhost
mkdir -p $(dirname $CROACHDB)
[[ -f "$CROACHDB" ]] || wget -qO- https://binaries.cockroachdb.com/$COCKROACHDB.tgz | tar xvz --directory ~/.cache
if [ $1 == "start" ]; then
quit_cockroachdb
$CROACHDB start-single-node --background --insecure --store=type=mem,size=10% --log-dir /tmp/ --listen-addr=$HOST:26257 --http-addr=$HOST:26301
#$CROACHDB sql --host=$HOST:26257 --insecure -e "set sql_safe_updates=false; drop database if exists apibuilder; create database if not exists apibuilder; create user if not exists apibuilder; grant all on database apibuilder to apibuilder;"
else
quit_cockroachdb
fi
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[tool.black]
line-length = 100

[tool.pytest.ini_options]
addopts = "--tb native -v -r sfxX --maxfail=250 -p warnings -p logging --strict-markers"
markers = [
Expand Down
118 changes: 76 additions & 42 deletions sqlalchemy_cockroachdb/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@
from .base import savepoint_state


def run_transaction(transactor, callback, max_retries=None, max_backoff=0):
class ChainTransaction:
def __init__(self, transactions=None):
self.results = []
self.transactions = transactions or []

def add_result(self, result):
self.results.append(result)


def run_transaction(transactor, callback, max_retries=None, max_backoff=0, **kwargs):
"""Run a transaction with retries.

``callback()`` will be called with one argument to execute the
Expand All @@ -26,15 +35,18 @@ def run_transaction(transactor, callback, max_retries=None, max_backoff=0):
transaction should be retried before giving up.
``max_backoff`` is an optional integer that specifies the capped number of seconds
for the exponential back-off.
``inject_error`` forces retry loop to run via SET inject_retry_errors_enabled = 'true'
``use_cockroach_restart``, default true, utilizes the special cockroach_restart protocol,
as outlined in: https://www.cockroachlabs.com/blog/nested-transactions-in-cockroachdb-20-1/
"""
if isinstance(transactor, (sqlalchemy.engine.Connection, sqlalchemy.orm.Session)):
return _txn_retry_loop(transactor, callback, max_retries, max_backoff)
return _txn_retry_loop(transactor, callback, max_retries, max_backoff, **kwargs)
elif isinstance(transactor, sqlalchemy.engine.Engine):
with transactor.connect() as connection:
return _txn_retry_loop(connection, callback, max_retries, max_backoff)
return _txn_retry_loop(connection, callback, max_retries, max_backoff, **kwargs)
elif isinstance(transactor, sqlalchemy.orm.sessionmaker):
session = transactor()
return _txn_retry_loop(session, callback, max_retries, max_backoff)
return _txn_retry_loop(session, callback, max_retries, max_backoff, **kwargs)
else:
raise TypeError("don't know how to run a transaction on %s", type(transactor))

Expand All @@ -46,27 +58,32 @@ class _NestedTransaction:
loop to be rewritten by the dialect.
"""

def __init__(self, conn):
def __init__(self, conn, use_cockroach_restart=True):
self.conn = conn
self.use_cockroach_restart = use_cockroach_restart

def __enter__(self):
try:
savepoint_state.cockroach_restart = True
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = True
self.txn = self.conn.begin_nested()
if isinstance(self.conn, sqlalchemy.orm.Session):
if self.use_cockroach_restart and isinstance(self.conn, sqlalchemy.orm.Session):
# Sessions are lazy and don't execute the savepoint
# query until you ask for the connection.
self.conn.connection()
finally:
savepoint_state.cockroach_restart = False
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = False
return self

def __exit__(self, typ, value, tb):
try:
savepoint_state.cockroach_restart = True
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = True
self.txn.__exit__(typ, value, tb)
finally:
savepoint_state.cockroach_restart = False
if self.use_cockroach_restart:
savepoint_state.cockroach_restart = False


def retry_exponential_backoff(retry_count: int, max_backoff: int = 0) -> None:
Expand All @@ -81,45 +98,62 @@ def retry_exponential_backoff(retry_count: int, max_backoff: int = 0) -> None:
:return: None
"""

sleep_secs = uniform(0, min(max_backoff, 0.1 * (2 ** retry_count)))
sleep_secs = uniform(0, min(max_backoff, 0.1 * (2**retry_count)))
sleep(sleep_secs)


def _txn_retry_loop(conn, callback, max_retries, max_backoff):
"""Inner transaction retry loop.

``conn`` may be either a Connection or a Session, but they both
have compatible ``begin()`` and ``begin_nested()`` methods.
"""
def run_in_nested_transaction(
conn, callback, max_retries, max_backoff, inject_error=False, **kwargs
):
if isinstance(conn, sqlalchemy.orm.Session):
dbapi_name = conn.bind.driver
else:
dbapi_name = conn.engine.driver

retry_count = 0
with conn.begin():
while True:
try:
with _NestedTransaction(conn):
ret = callback(conn)
return ret
except sqlalchemy.exc.DatabaseError as e:
if max_retries is not None and retry_count >= max_retries:
raise
do_retry = False
if dbapi_name == "psycopg2":
import psycopg2
import psycopg2.errorcodes
if isinstance(e.orig, psycopg2.OperationalError):
if e.orig.pgcode == psycopg2.errorcodes.SERIALIZATION_FAILURE:
do_retry = True
else:
import psycopg
if isinstance(e.orig, psycopg.errors.SerializationFailure):
do_retry = True
if do_retry:
retry_count += 1
if max_backoff > 0:
retry_exponential_backoff(retry_count, max_backoff)
continue
while True:
if inject_error and retry_count == 0:
conn.execute(sqlalchemy.text("SET inject_retry_errors_enabled = 'true'"))
elif inject_error:
conn.execute(sqlalchemy.text("SET inject_retry_errors_enabled = 'false'"))
try:
with _NestedTransaction(conn, **kwargs):
return callback(conn)
except sqlalchemy.exc.DatabaseError as e:
if max_retries is not None and retry_count >= max_retries:
raise
do_retry = False
if dbapi_name == "psycopg2":
import psycopg2
import psycopg2.errorcodes

if isinstance(e.orig, psycopg2.OperationalError):
if e.orig.pgcode == psycopg2.errorcodes.SERIALIZATION_FAILURE:
do_retry = True
else:
import psycopg

if isinstance(e.orig, psycopg.errors.SerializationFailure):
do_retry = True
if do_retry:
retry_count += 1
if max_backoff > 0:
retry_exponential_backoff(retry_count, max_backoff)
continue
raise


def _txn_retry_loop(conn, callback, max_retries, max_backoff, **kwargs):
"""Inner transaction retry loop.

``conn`` may be either a Connection or a Session, but they both
have compatible ``begin()`` and ``begin_nested()`` methods.
"""
with conn.begin():
result = run_in_nested_transaction(conn, callback, max_retries, max_backoff, **kwargs)
if isinstance(result, ChainTransaction):
for transaction in result.transactions:
result.add_result(
run_in_nested_transaction(conn, transaction, max_retries, max_backoff, **kwargs)
)
return result
37 changes: 36 additions & 1 deletion test/test_run_transaction_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from sqlalchemy.testing import fixtures
from sqlalchemy.types import Integer
import threading
from sqlalchemy.orm import sessionmaker, scoped_session


from sqlalchemy_cockroachdb import run_transaction
from sqlalchemy_cockroachdb.transaction import ChainTransaction

meta = MetaData()

Expand All @@ -25,7 +28,9 @@ def setup_method(self, method):
)

def teardown_method(self, method):
meta.drop_all(testing.db)
session = scoped_session(sessionmaker(bind=testing.db))
session.query(account_table).delete()
session.commit()

def get_balances(self, conn):
"""Returns the balances of the two accounts as a list."""
Expand Down Expand Up @@ -134,3 +139,33 @@ def txn_body(conn):
with testing.db.connect() as conn:
rs = run_transaction(conn, txn_body)
assert rs[0] == (1, 100)

def test_run_transaction_retry_with_nested(self):
def txn_body(conn):
rs = conn.execute(text("select acct, balance from account where acct = 1"))
conn.execute(text("select crdb_internal.force_retry('1s')"))
return [r for r in rs]

with testing.db.connect() as conn:
rs = run_transaction(conn, txn_body, use_cockroach_restart=False)
assert rs[0] == (1, 100)

def test_run_chained_transaction(self):
def txn_body(conn):
# first transaction inserts
conn.execute(account_table.insert(), [dict(acct=99, balance=100)])
conn.execute(text("select crdb_internal.force_retry('1s')"))

def _get_val(s):
rs = s.execute(text("select acct, balance from account where acct = 99"))
return [r for r in rs]

# chain the get into a separate nested transaction, so that the value
# in the previous nested transaction is flushed and available
return ChainTransaction([lambda s: _get_val(s), lambda s: _get_val(s)])

with testing.db.connect() as conn:
rs = run_transaction(conn, txn_body, use_cockroach_restart=False)
assert len(rs.results) == 2
assert rs.results[0][0] == (99, 100)
assert rs.results[1][0] == (99, 100)