diff --git a/hub-server/database.py b/hub-server/database.py index 404b541..d76b4b1 100755 --- a/hub-server/database.py +++ b/hub-server/database.py @@ -5,7 +5,7 @@ from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Optional, Dict, Union, Generator +from typing import Optional, Dict, Union from config import ( SQLITE_TIMEOUT, POWER_FACTOR, MAINS_VOLTAGE, ENERGY_MONTHLY_RESET, SQLITE_RETRIES, SQLITE_RETRY_DELAY ) @@ -38,9 +38,11 @@ def __init__(self, db_path: Union[str, Path]): logging.info(f"Database initialized at path: {self.db_path}") def _connect(self): - if self._conn is None: - if not self.db_path.parent.exists(): - raise RuntimeError("Database directory missing") + if self._conn is not None: + return + + if not self.db_path.parent.exists(): + raise RuntimeError("Database directory missing") self._conn = sqlite3.connect( self.db_path, @@ -57,27 +59,29 @@ def _get_connection(self): @contextmanager - def __get_connection_cm(self) -> Generator[sqlite3.Connection, None, None]: + def __get_connection_cm(self): """ Context manager providing a thread-safe single connection with auto-reconnect. """ for attempt in range(1, SQLITE_RETRIES + 1): try: - self._connect() with self._conn_lock: + self._connect() yield self._conn return except (sqlite3.OperationalError, sqlite3.DatabaseError) as e: logging.warning(f"DB connection error (attempt {attempt}/{SQLITE_RETRIES}): {e}") - if self._conn: - try: - self._conn.close() - except Exception: - pass - self._conn = None + with self._conn_lock: + if self._conn: + try: + self._conn.close() + except Exception: + pass + self._conn = None if attempt < SQLITE_RETRIES: time.sleep(SQLITE_RETRY_DELAY) + raise RuntimeError("Could not acquire DB connection after retries") @@ -98,6 +102,7 @@ def setup(self) -> None: cursor = conn.cursor() # Enable WAL + busy timeout + self._conn.execute("PRAGMA synchronous=NORMAL;") cursor.execute("PRAGMA journal_mode=WAL;") cursor.execute("PRAGMA busy_timeout = 5000;") @@ -293,7 +298,7 @@ def truncate_old_data(self, months: int) -> int: conn.commit() # Reclaim space - cursor.execute("VACUUM") + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") if deleted_count > 0: logging.info(f"Truncated {deleted_count} old records (older than {cutoff_date.strftime('%Y-%m-%d')})")