Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 30 additions & 90 deletions keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
custom_serialize,
get_json_extract_field,
get_or_create,
insert_update_conflict,
)
from keep.api.core.dependencies import SINGLE_TENANT_UUID

Expand Down Expand Up @@ -166,7 +167,7 @@ def __convert_to_uuid(value: str, should_raise: bool = False) -> UUID | None:

def retry_on_db_error(f):
@retry(
exceptions=(OperationalError, IntegrityError, StaleDataError),
exceptions=(OperationalError, IntegrityError, StaleDataError, NoActiveSqlTransaction),
tries=3,
delay=0.1,
backoff=2,
Expand All @@ -177,7 +178,7 @@ def retry_on_db_error(f):
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except (OperationalError, IntegrityError, StaleDataError) as e:
except (OperationalError, IntegrityError, StaleDataError, NoActiveSqlTransaction) as e:

if hasattr(e, "session") and not e.session.is_active:
e.session.rollback()
Expand All @@ -187,6 +188,11 @@ def wrapper(*args, **kwargs):
"Deadlock detected, retrying transaction", extra={"error": str(e)}
)
raise # retry will catch this
elif "No active SQL transaction" in str(e):
logger.exception(
"No active SQL transaction detected, retrying transaction", extra={"error": str(e)}
)
raise # retry will catch this
else:
logger.exception(
f"Error while executing transaction during {f.__name__}",
Expand Down Expand Up @@ -5683,99 +5689,33 @@ def get_last_alert_by_fingerprint(
query = query.with_for_update()
return session.exec(query).first()


@retry_on_db_error
def set_last_alert(
    tenant_id: str, alert: Alert, session: Optional[Session] = None
) -> None:
    """Upsert the LastAlert row for this alert's fingerprint.

    Inserts a new LastAlert when the fingerprint is unseen; on conflict it
    updates timestamp/alert_id/alert_hash, but only when the incoming alert
    is newer (``update_newer=True``) — this guards against the race where an
    older alert is retried after a newer one has already been processed.

    Args:
        tenant_id: Tenant owning the alert.
        alert: The alert whose fingerprint/timestamp/hash are recorded.
        session: Optional existing session; a new one is opened if omitted.

    Raises:
        Propagates DB errors after the retries applied by @retry_on_db_error.
    """
    fingerprint = alert.fingerprint
    logger.info(f"Setting last alert for `{fingerprint}`")
    with existed_or_new_session(session) as session:
        # Single atomic upsert replaces the previous select-for-update +
        # retry loop, avoiding deadlocks and savepoint errors entirely.
        insert_update_conflict(
            LastAlert,
            session,
            data_to_insert={
                "tenant_id": tenant_id,
                "fingerprint": fingerprint,
                "timestamp": alert.timestamp,
                "first_timestamp": alert.timestamp,
                "alert_id": alert.id,
                "alert_hash": alert.alert_hash,
            },
            data_to_update={
                "timestamp": alert.timestamp,
                "alert_id": alert.id,
                "alert_hash": alert.alert_hash,
            },
            update_newer=True,
        )
        logger.debug(
            f"Successfully updated lastalert for `{fingerprint}`",
            extra={
                "alert_id": alert.id,
                "tenant_id": tenant_id,
                "fingerprint": fingerprint,
            },
        )

def set_maintenance_windows_trace(alert: Alert, maintenance_w: MaintenanceWindowRule, session: Optional[Session] = None):
mw_id = str(maintenance_w.id)
Expand Down
43 changes: 43 additions & 0 deletions keep/api/core/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from sqlalchemy.sql.ddl import CreateColumn
from sqlalchemy.sql.functions import GenericFunction
from sqlmodel import Session, SQLModel, create_engine, select
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy import case

# This import is required to create the tables
from keep.api.consts import RUNNING_IN_CLOUD_RUN
Expand Down Expand Up @@ -199,6 +203,45 @@ def get_aggreated_field(session: Session, column_name: str, alias: str):
return func.array_agg(column_name).label(alias)


def insert_update_conflict(
    table: SQLModel,
    session: Session,
    data_to_insert: dict,
    data_to_update: dict,
    update_newer: bool,
):
    """Perform a dialect-aware upsert (INSERT ... ON CONFLICT/DUPLICATE KEY UPDATE).

    Conflicts are detected on the table's primary-key columns.

    Args:
        table (SQLModel): The model class to upsert into. When
            ``update_newer`` is True the model MUST have a ``timestamp``
            column, which is compared against the incoming row's timestamp.
        session (Session): The SQLModel session (commits on success).
        data_to_insert (dict): Column values for the INSERT.
        data_to_update (dict): Column values applied when the row exists.
        update_newer (bool): If True, only overwrite when the existing row's
            ``timestamp`` is older than the inserted one.

    Raises:
        NotImplementedError: For dialects other than postgresql/mysql/sqlite.
    """
    dialect = session.bind.dialect.name

    if dialect in ("postgresql", "sqlite"):
        # PostgreSQL and SQLite share the same ON CONFLICT DO UPDATE API,
        # so one branch covers both dialects.
        insert_stmt = pg_insert if dialect == "postgresql" else sqlite_insert
        query = insert_stmt(table).values(data_to_insert)
        query = query.on_conflict_do_update(
            index_elements=[col.name for col in table.__table__.primary_key.columns],
            set_=data_to_update,
            # excluded.* refers to the row that failed to insert.
            where=(table.timestamp < query.excluded.timestamp) if update_newer else None,
        )
    elif dialect == "mysql":
        query = mysql_insert(table).values(data_to_insert)
        if update_newer:
            # MySQL has no WHERE on ON DUPLICATE KEY UPDATE; emulate it with
            # per-column CASE expressions that keep the existing value when
            # the incoming row is not newer. `inserted` maps to VALUES().
            data_to_update = {
                k: case(
                    (table.timestamp < query.inserted.timestamp, v),
                    else_=getattr(table, k),
                )
                for k, v in data_to_update.items()
            }
        query = query.on_duplicate_key_update(data_to_update)
    else:
        raise NotImplementedError(f"UPSERT not supported for {dialect}")

    session.exec(query)
    session.commit()

class json_table(GenericFunction):
inherit_cache = True

Expand Down
Loading