diff --git a/django_pev/utils/__init__.py b/django_pev/utils/__init__.py index d27cf83..f1d031d 100644 --- a/django_pev/utils/__init__.py +++ b/django_pev/utils/__init__.py @@ -12,6 +12,7 @@ from typing import Any import sqlglot +import sqlglot.errors import sqlglot.expressions as exp from django.db import connection, connections from django.db.backends.utils import CursorWrapper @@ -267,16 +268,19 @@ def explain( logger.debug(f"Captured {len(queries_after)} queries") for index, q in enumerate(queries_after): - result.queries.append( - Explain( - index=index, - duration=float(q["time"]), - sql=sqlglot.transpile(q["sql"], read="postgres", pretty=True)[0], - stack_trace=q["stack_trace"], - db_alias=db_alias, - fingerprint=generate_fingerprint(q["sql"]) or q["sql"], + try: + result.queries.append( + Explain( + index=index, + duration=float(q["time"]), + sql=sqlglot.transpile(q["sql"], read="postgres", pretty=True)[0], + stack_trace=q["stack_trace"], + db_alias=db_alias, + fingerprint=generate_fingerprint(q["sql"]) or q["sql"], + ) ) - ) + except sqlglot.errors.ParseError as e: + logger.warning(f"Error parsing query: {e}") def generate_fingerprint(sql_query: str) -> str | None: diff --git a/tests/test_explain.py b/tests/test_explain.py index c1010ab..e66db82 100644 --- a/tests/test_explain.py +++ b/tests/test_explain.py @@ -1,3 +1,4 @@ +from django.db import connection from django.test import TestCase from django_pev import explain @@ -53,3 +54,16 @@ def test_optimization_prompt(self): list(Student.objects.filter(name="1")) e.slowest.optimization_prompt(analyze=True) + + def test_savepoint_parsing_error_handling(self): + """Test that SAVEPOINT queries that can't be parsed by sqlglot are handled gracefully""" + with explain() as e: + # Execute SAVEPOINT commands that sqlglot can't parse + with connection.cursor() as cursor: + # SQL GLOT can parse + cursor.execute("SAVEPOINT test_savepoint") + + # SQL GLOT can't parse + cursor.execute("RELEASE SAVEPOINT test_savepoint") + + self.assertEqual(len(e.queries), 1)