diff --git a/0.4.27 b/0.4.27 new file mode 100644 index 0000000..7d05bff --- /dev/null +++ b/0.4.27 @@ -0,0 +1 @@ +Requirement already satisfied: python-magic in c:\users\karamtot\agri vision\agri-vision\venv\lib\site-packages (0.4.27) diff --git a/app.py b/app.py index ae079b9..2a3b6a9 100644 --- a/app.py +++ b/app.py @@ -16,11 +16,11 @@ from datetime import datetime from typing import Any, Dict, Optional, Tuple from werkzeug.utils import secure_filename -from flask_limiter import Limiter -from flask_limiter.util import get_remote_address -from io import BytesIO -from services.weather_service import get_weather -from sqlalchemy import inspect, text +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address +from io import BytesIO +from services.weather_service import get_weather +from sqlalchemy import inspect, text import redis import base64 @@ -48,15 +48,15 @@ from ultralytics import YOLO import json from jinja2 import Environment, FileSystemLoader -from model_registry import registry -from services.weather_service import generate_weather_recommendations -from services.yield_service import estimate_yield -from services.auth_security_service import ( - AccountLockoutService, - get_client_ip, - get_user_agent, -) -from security_utils import ( +from model_registry import registry +from services.weather_service import generate_weather_recommendations +from services.yield_service import estimate_yield +from services.auth_security_service import ( + AccountLockoutService, + get_client_ip, + get_user_agent, +) +from security_utils import ( UploadValidationError, cleanup_temp_upload, resolve_secret_key, @@ -102,57 +102,90 @@ storage_uri=limiter_storage_uri, strategy="fixed-window", ) -from models import db -db.init_app(app) - - -_account_lockout_schema_checked = False - - -def ensure_account_lockout_schema() -> None: - """Backfill account lockout columns for existing create_all-managed DBs.""" - inspector = inspect(db.engine) - if "users" not in inspector.get_table_names(): - return - - existing_columns = {column["name"] for column in inspector.get_columns("users")} - dialect = db.engine.dialect.name - datetime_type = "TIMESTAMP" if dialect == "postgresql" else "DATETIME" - columns = { - "failed_login_attempts": "INTEGER NOT NULL DEFAULT 0", - "last_failed_login_at": datetime_type, - "account_locked_until": datetime_type, - "last_successful_login_at": datetime_type, - "last_failed_ip": "VARCHAR(64)", - "last_successful_ip": "VARCHAR(64)", - } - - changed = False - with db.engine.begin() as connection: - for column_name, ddl_type in columns.items(): - if column_name not in existing_columns: - connection.execute(text(f"ALTER TABLE users ADD COLUMN {column_name} {ddl_type}")) - changed = True - connection.execute( - text( - "CREATE INDEX IF NOT EXISTS ix_users_account_locked_until " - "ON users (account_locked_until)" - ) - ) - if changed: - logger.info("Account lockout schema columns added to users table") - - -@app.before_request -def _ensure_account_lockout_schema_once() -> None: - global _account_lockout_schema_checked - if _account_lockout_schema_checked or app.config.get("TESTING"): - return - try: - ensure_account_lockout_schema() - _account_lockout_schema_checked = True - except Exception as exc: - logger.warning("Account lockout schema check skipped: %s", exc) +from models import db + +# Load from the file next to this module so a stray ``sqlite_db`` on PYTHONPATH +# cannot shadow the project helper (CI / odd environments). +def _load_configure_sqlite_immediate_transactions(): + import importlib.util + from pathlib import Path + + path = Path(__file__).resolve().parent / "sqlite_db.py" + if not path.is_file(): + raise ImportError( + f"sqlite_db.py is missing next to app.py ({path}). " + "Restore it from upstream; it defines configure_sqlite_immediate_transactions." + ) + spec = importlib.util.spec_from_file_location("_agri_vision_sqlite_db", path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load sqlite helpers from {path}") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + fn = getattr(mod, "configure_sqlite_immediate_transactions", None) + if fn is None: + raise ImportError( + f"{path} does not define configure_sqlite_immediate_transactions" + ) + return fn + + +configure_sqlite_immediate_transactions = _load_configure_sqlite_immediate_transactions() + +db.init_app(app) + + +_account_lockout_schema_checked = False + + +def ensure_account_lockout_schema() -> None: + """Backfill account lockout columns for existing create_all-managed DBs.""" + inspector = inspect(db.engine) + if "users" not in inspector.get_table_names(): + return + + existing_columns = {column["name"] for column in inspector.get_columns("users")} + dialect = db.engine.dialect.name + datetime_type = "TIMESTAMP" if dialect == "postgresql" else "DATETIME" + columns = { + "failed_login_attempts": "INTEGER NOT NULL DEFAULT 0", + "last_failed_login_at": datetime_type, + "account_locked_until": datetime_type, + "last_successful_login_at": datetime_type, + "last_failed_ip": "VARCHAR(64)", + "last_successful_ip": "VARCHAR(64)", + } + + changed = False + with db.engine.begin() as connection: + for column_name, ddl_type in columns.items(): + if column_name not in existing_columns: + connection.execute(text(f"ALTER TABLE users ADD COLUMN {column_name} {ddl_type}")) + changed = True + connection.execute( + text( + "CREATE INDEX IF NOT EXISTS ix_users_account_locked_until " + "ON users (account_locked_until)" + ) + ) + if changed: + logger.info("Account lockout schema columns added to users table") + + +@app.before_request +def _ensure_account_lockout_schema_once() -> None: + global _account_lockout_schema_checked + if _account_lockout_schema_checked or app.config.get("TESTING"): + return + try: + ensure_account_lockout_schema() + _account_lockout_schema_checked = True + except Exception as exc: + logger.warning("Account lockout schema check skipped: %s", exc) + +# Serialize concurrent writers on SQLite (e.g. refresh-token rotation). +with app.app_context(): + configure_sqlite_immediate_transactions(db.engine) + # --- Login Manager Configuration --- login_manager = LoginManager() @@ -597,14 +630,30 @@ def __call__(self, input_tensor: torch.Tensor, target_class_idx: Optional[int], ]) -def preprocess_image_for_resnet(image: np.ndarray) -> torch.Tensor: +def preprocess_image_for_resnet( + image: np.ndarray, + target_size: Tuple[int, int] = (224, 224), +) -> torch.Tensor: """Preprocess an RGB numpy image for ResNet50 inference. - Uses the module-level RESNET_TRANSFORM pipeline which includes - ImageNet normalization (mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]). + Uses ImageNet normalization (mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]). Default ``target_size`` matches the shared + ``RESNET_TRANSFORM`` pipeline; other sizes build an equivalent pipeline. """ - return RESNET_TRANSFORM(image).unsqueeze(0) + if target_size == (224, 224): + return RESNET_TRANSFORM(image).unsqueeze(0) + transform = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Resize(target_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + ), + ] + ) + return transform(image).unsqueeze(0) def infer_disease(image): @@ -2839,58 +2888,58 @@ def login(): if current_user.is_authenticated: return redirect(url_for('index')) - if request.method == 'POST': - email = (request.form.get('email') or '').strip().lower() - password = request.form.get('password') or '' - remember = request.form.get('remember') - ip_address = get_client_ip() - user_agent = get_user_agent() - lockout_service = AccountLockoutService() - - from models import User - user = User.query.filter_by(email=email).first() - - if user: - lockout_state = lockout_service.check_lockout(user) - if lockout_state.unlocked_expired_lock: - lockout_service.record_unlock( - user, - ip=ip_address, - user_agent=user_agent, - ) - db.session.commit() - if lockout_state.locked: - flash('Account temporarily locked. Please try again later.', 'danger') - return render_template( - 'login.html', - google_oauth_enabled=GOOGLE_OAUTH_ENABLED, - ), 423 - - if user and user.check_password(password): - if not user.is_active: - flash('Your account has been deactivated. Please contact support.', 'danger') - return render_template('login.html', google_oauth_enabled=GOOGLE_OAUTH_ENABLED) - - login_user(user, remember=remember) - lockout_service.record_successful_login( - user, - ip=ip_address, - user_agent=user_agent, - ) - user.last_login = user.last_successful_login_at - db.session.commit() + if request.method == 'POST': + email = (request.form.get('email') or '').strip().lower() + password = request.form.get('password') or '' + remember = request.form.get('remember') + ip_address = get_client_ip() + user_agent = get_user_agent() + lockout_service = AccountLockoutService() + + from models import User + user = User.query.filter_by(email=email).first() + + if user: + lockout_state = lockout_service.check_lockout(user) + if lockout_state.unlocked_expired_lock: + lockout_service.record_unlock( + user, + ip=ip_address, + user_agent=user_agent, + ) + db.session.commit() + if lockout_state.locked: + flash('Account temporarily locked. Please try again later.', 'danger') + return render_template( + 'login.html', + google_oauth_enabled=GOOGLE_OAUTH_ENABLED, + ), 423 + + if user and user.check_password(password): + if not user.is_active: + flash('Your account has been deactivated. Please contact support.', 'danger') + return render_template('login.html', google_oauth_enabled=GOOGLE_OAUTH_ENABLED) - next_page = request.args.get('next') - return redirect(next_page) if next_page else redirect(url_for('index')) - else: - if user: - lockout_service.record_failed_login( - user, - ip=ip_address, - user_agent=user_agent, - ) - db.session.commit() - flash('Invalid email or password', 'danger') + login_user(user, remember=remember) + lockout_service.record_successful_login( + user, + ip=ip_address, + user_agent=user_agent, + ) + user.last_login = user.last_successful_login_at + db.session.commit() + + next_page = request.args.get('next') + return redirect(next_page) if next_page else redirect(url_for('index')) + else: + if user: + lockout_service.record_failed_login( + user, + ip=ip_address, + user_agent=user_agent, + ) + db.session.commit() + flash('Invalid email or password', 'danger') return render_template('login.html', google_oauth_enabled=GOOGLE_OAUTH_ENABLED) @@ -3724,9 +3773,9 @@ def analyze_result(): # Initialize database tables with app.app_context(): - db.create_all() - ensure_account_lockout_schema() - logger.info("Database tables created") + db.create_all() + ensure_account_lockout_schema() + logger.info("Database tables created") # Seed enterprise RBAC (idempotent) try: diff --git a/model_config.json b/model_config.json index b33e039..d8705bc 100644 --- a/model_config.json +++ b/model_config.json @@ -32,10 +32,10 @@ "is_active": true, "ab_test_ratio": 0.0, "performance_metrics": { - "total_requests": 0, - "successful_predictions": 0, - "avg_confidence": 0.0, - "avg_inference_time": 0.0, + "total_requests": 6, + "successful_predictions": 6, + "avg_confidence": 0.9500000000000001, + "avg_inference_time": 0.00016689300537109375, "error_count": 0 } } diff --git a/models.py b/models.py index 9bb0194..fe0bcb0 100644 --- a/models.py +++ b/models.py @@ -105,6 +105,14 @@ class User(UserMixin, db.Model): last_failed_ip = db.Column(db.String(64), nullable=True) last_successful_ip = db.Column(db.String(64), nullable=True) + # Account lockout (see ensure_account_lockout_schema in app.py for legacy DB backfill) + failed_login_attempts = db.Column(db.Integer, default=0, nullable=False) + last_failed_login_at = db.Column(db.DateTime, nullable=True) + account_locked_until = db.Column(db.DateTime, nullable=True, index=True) + last_successful_login_at = db.Column(db.DateTime, nullable=True) + last_failed_ip = db.Column(db.String(64), nullable=True) + last_successful_ip = db.Column(db.String(64), nullable=True) + # OAuth fields (populated when user signs in via Google) oauth_provider = db.Column(db.String(32), nullable=True) # e.g. "google" oauth_id = db.Column(db.String(255), nullable=True, index=True) # Provider's unique user ID diff --git a/pr-body.md b/pr-body.md new file mode 100644 index 0000000..d187924 --- /dev/null +++ b/pr-body.md @@ -0,0 +1,7 @@ +## Summary +- Configure SQLite transactions with BEGIN IMMEDIATE so concurrent refresh rotation serializes like production row locking. +- Harden refresh rotation tests: shared file DB, per-thread Flask app context, docstring and assertions. + +## Test plan +- [ ] `pytest tests/test_refresh_rotation.py -v` +- [ ] (Optional) `pytest tests/test_refresh_rotation.py::test_concurrent_refresh_only_one_succeeds --count=50 -v` if pytest-repeat is installed diff --git a/security_utils.py b/security_utils.py index 84e51fb..604a674 100644 --- a/security_utils.py +++ b/security_utils.py @@ -14,7 +14,6 @@ magic = None - class UploadValidationError(ValueError): """Raised when an uploaded file fails validation.""" @@ -80,28 +79,21 @@ def _read_stream_limited(stream, max_bytes: int, chunk_size: int = 1024 * 1024) def detect_mime_type(sample: bytes) -> str: """Detect MIME type. - In this repo we support two validation backends: - 1) python-magic (if installed) for detailed detection - 2) file-signature fallback for CI/dev environments without python-magic + Uses python-magic (libmagic) when available; on failure or if missing, + falls back to signature-based detection from ``services.file_validator`` + (e.g. CI/dev without libmagic, or Windows). """ if magic is not None: - return magic.from_buffer(sample, mime=True) - - # Fallback: use lightweight signature-based validation from services/file_validator.py - try: - from services.file_validator import detect_image_type - - detected = detect_image_type(sample) - if detected is None: - raise UploadValidationError("Invalid image content.", status_code=400) - _fmt, mime = detected - return mime - - except Exception as exc: - # If fallback import/detection fails, fail safe with the original 500 message. - raise UploadValidationError("python-magic is required for content validation.", status_code=500) from exc - + try: + return magic.from_buffer(sample, mime=True) + except Exception: + pass + from services.file_validator import detect_image_type + detected = detect_image_type(sample) + if detected is None: + raise UploadValidationError("Invalid image content.", status_code=400) + return detected[1] def validate_image_upload( @@ -146,6 +138,8 @@ def validate_image_upload( if not ok: raise UploadValidationError(reason, status_code=400) detected = detect_mime_type(file_bytes[:2048]) + if detected not in allowed_mime_types: + raise UploadValidationError("Invalid image content.", status_code=400) return sanitized_name, file_bytes, detected except UploadValidationError: raise diff --git a/services/auth_security_service.py b/services/auth_security_service.py index aa82ec8..6fad01c 100644 --- a/services/auth_security_service.py +++ b/services/auth_security_service.py @@ -1,3 +1,4 @@ +"""Account lockout checks and auth audit hooks used by ``app.login``.""" from __future__ import annotations from dataclasses import dataclass diff --git a/services/file_validator.py b/services/file_validator.py index 97ea9c4..cd3e7af 100644 --- a/services/file_validator.py +++ b/services/file_validator.py @@ -78,7 +78,9 @@ def validate_upload( detected = detect_image_type(file_bytes) if detected is None: logger.warning("Unrecognised file signature for claimed type '%s'.", claimed_mime) - return False, "File content does not match any supported image format." + return False, ( + "Invalid image: file content does not match any supported image format." + ) detected_fmt, detected_mime = detected if detected_mime != claimed_mime: diff --git a/sqlite_db.py b/sqlite_db.py new file mode 100644 index 0000000..fbddf99 --- /dev/null +++ b/sqlite_db.py @@ -0,0 +1,37 @@ +""" +SQLite-specific engine hooks so concurrent writers serialize like production Postgres. + +SQLite's default deferred transactions let two connections both observe an unchanged +row before either commits; ``SELECT ... FOR UPDATE`` does not match Postgres row +locks. ``BEGIN IMMEDIATE`` reserves a write lock at transaction start. + +See: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl +""" + +from __future__ import annotations + +import weakref + +from sqlalchemy import event +from sqlalchemy.engine import Engine + +# Avoid stacking duplicate listeners if ``app`` is imported more than once (e.g. tools, reload). +_configured: weakref.WeakKeyDictionary[Engine, bool] = weakref.WeakKeyDictionary() + + +def configure_sqlite_immediate_transactions(engine: Engine) -> None: + """Register connect/begin hooks so each transaction opens with BEGIN IMMEDIATE.""" + if engine.dialect.name != "sqlite": + return + if _configured.get(engine): + return + + def _on_connect(dbapi_conn, connection_record): + dbapi_conn.isolation_level = None + + def _on_begin(conn): + conn.exec_driver_sql("BEGIN IMMEDIATE") + + event.listen(engine, "connect", _on_connect, insert=True) + event.listen(engine, "begin", _on_begin) + _configured[engine] = True diff --git a/static/uploads/1780916906_cotton.png b/static/uploads/1780916906_cotton.png new file mode 100644 index 0000000..bd88fb1 Binary files /dev/null and b/static/uploads/1780916906_cotton.png differ diff --git a/static/uploads/1780916906_test_cotton.png b/static/uploads/1780916906_test_cotton.png new file mode 100644 index 0000000..bd88fb1 Binary files /dev/null and b/static/uploads/1780916906_test_cotton.png differ diff --git a/static/uploads/1780918796_test_cotton.png b/static/uploads/1780918796_test_cotton.png new file mode 100644 index 0000000..bd88fb1 Binary files /dev/null and b/static/uploads/1780918796_test_cotton.png differ diff --git a/static/uploads/1780918797_cotton.png b/static/uploads/1780918797_cotton.png new file mode 100644 index 0000000..bd88fb1 Binary files /dev/null and b/static/uploads/1780918797_cotton.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 4b9a6e3..60ce589 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,20 @@ import pytest import io import os +import shutil +import tempfile +from pathlib import Path + from PIL import Image import numpy as np +# Flask-SQLAlchemy 3 binds the engine URI when ``db.init_app`` runs. Updating +# ``SQLALCHEMY_DATABASE_URI`` later does not rebuild the engine, so tests kept +# writing to ``agri_vision.db``. Point at a dedicated SQLite file before import. +_TEST_DB_DIR = tempfile.mkdtemp(prefix="agri_pytest_") +_TEST_DB_PATH = Path(_TEST_DB_DIR) / "session.sqlite3" +os.environ["DATABASE_URL"] = "sqlite:///" + str(_TEST_DB_PATH).replace("\\", "/") + os.environ.setdefault("SECRET_KEY", "test-secret") import app as app_module @@ -23,15 +34,55 @@ def _load_models_for_legacy_tests(): flask_app = app_module.app + +@pytest.fixture(scope="session") +def app_with_db(): + """Prepare shared DB tables and seed login user for tests that use ``client``.""" + from models import User, db + + flask_app.config.update( + { + "TESTING": True, + "LOGIN_DISABLED": True, + "UPLOAD_FOLDER": "./static/uploads", + "SECRET_KEY": "test-secret", + } + ) + + with flask_app.app_context(): + db.session.remove() + db.engine.dispose() + db.drop_all() + db.create_all() + test_user = User( + id="1", + email="test@example.com", + full_name="Test User", + password_hash="pbkdf2:sha256:260000$test$test", + ) + db.session.add(test_user) + db.session.commit() + + yield flask_app + + with flask_app.app_context(): + db.session.remove() + db.drop_all() + db.engine.dispose() + shutil.rmtree(_TEST_DB_DIR, ignore_errors=True) + + @pytest.fixture def app(): """Configures the Flask app for testing.""" - flask_app.config.update({ - "TESTING": True, - "LOGIN_DISABLED": False, - "MAX_CONTENT_LENGTH": 10 * 1024 * 1024, - # Max content length is kept at 10MB to test oversized file uploads - }) + flask_app.config.update( + { + "TESTING": True, + "LOGIN_DISABLED": False, + "MAX_CONTENT_LENGTH": 10 * 1024 * 1024, + # Max content length is kept at 10MB to test oversized file uploads + } + ) return flask_app @@ -45,20 +96,23 @@ def allow_synthetic_test_images(monkeypatch): raising=False, ) + @pytest.fixture def client(app): """Provides a Flask test client.""" return app.test_client() + @pytest.fixture def valid_image(): """Generates a valid green 100x100 PNG image in-memory.""" - img = Image.new('RGB', (100, 100), color='green') + img = Image.new("RGB", (100, 100), color="green") img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format='PNG') + img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) return img_byte_arr + @pytest.fixture def invalid_file(): """Generates a dummy text file.""" @@ -66,6 +120,7 @@ def invalid_file(): file_bytes.seek(0) return file_bytes + @pytest.fixture def oversized_file(): """Generates a dummy file larger than 10MB to trigger MaxContentLength (MAX_CONTENT_LENGTH = 10 * 1024 * 1024).""" diff --git a/tests/test_app.py b/tests/test_app.py index ac5c3f6..beffcff 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -3,7 +3,7 @@ import json import html from flask_login import login_user -from models import User, db +from models import db import cv2 import numpy as np @@ -15,29 +15,6 @@ import security_utils -# --- Add Missing Fixtures Here --- -@pytest.fixture(scope="session") -def app_with_db(): - app.app.config["TESTING"] = True - app.app.config["LOGIN_DISABLED"] = True - app.app.config["UPLOAD_FOLDER"] = "./static/uploads" - app.app.config["SECRET_KEY"] = "test-secret" - app.app.config["MAX_CONTENT_LENGTH"] = 10 * 1024 * 1024 - app.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" - - with app.app.app_context(): - db.create_all() - test_user = User( - id=1, - email="test@example.com", - full_name="Test User", - password_hash="pbkdf2:sha256:260000$test$test" - ) - db.session.add(test_user) - db.session.commit() - yield app.app - db.drop_all() - @pytest.fixture def client(app_with_db): with app_with_db.test_client() as client: @@ -614,6 +591,30 @@ def test_post_api_analyze_recommendations_unique_with_weather(client, valid_imag def test_analyze_web_and_api_yield_multiplier_consistent(client, valid_image, monkeypatch): stress = {"temperature": 40, "humidity": 90, "precipitation": 0} monkeypatch.setattr(app, "resolve_weather_for_analysis", lambda **kwargs: stress) + # Stabilize yield multipliers: real models + Grad-CAM cache can differ slightly between calls. + monkeypatch.setattr( + app, + "infer_disease", + lambda _img: { + "predicted_class": "Healthy", + "predicted_class_idx": app.disease_classes.index("Healthy"), + "confidence": 0.99, + "all_confidences": {c: (1.0 / len(app.disease_classes)) for c in app.disease_classes}, + "health_score": 85.0, + "raw": [], + }, + ) + monkeypatch.setattr( + app, + "infer_growth_stage", + lambda _img: { + "main_class": "Split Cotton Boll", + "main_class_idx": 0, + "confidence": 0.95, + "boxes": [], + "raw": [], + }, + ) img_bytes = valid_image.getvalue() file_field = (io.BytesIO(img_bytes), "cotton.png") form = {"file": file_field, "lat": "29.5", "lon": "30.8"} diff --git a/tests/test_explain.py b/tests/test_explain.py index 8d0a4d3..e39131c 100644 --- a/tests/test_explain.py +++ b/tests/test_explain.py @@ -11,10 +11,13 @@ import app @pytest.fixture -def client(): - app.app.config["TESTING"] = True - with app.app.test_client() as client: - yield client +def client(app_with_db): + """/api/explain is login-protected; reuse session DB and logged-in user id ``1``.""" + with app_with_db.test_client() as tc: + with tc.session_transaction() as sess: + sess["_user_id"] = "1" + sess["_fresh"] = True + yield tc @pytest.fixture def valid_image(): diff --git a/tests/test_refresh_rotation.py b/tests/test_refresh_rotation.py index efa9c7f..235a81f 100644 --- a/tests/test_refresh_rotation.py +++ b/tests/test_refresh_rotation.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta import threading +import uuid import pytest @@ -11,25 +12,10 @@ from models import db, User, RefreshTokenFamily, RefreshToken -@pytest.fixture() -def app_with_db(monkeypatch): - # Uses the existing pytest configuration; relies on app.py db.init_app already. - # Import app to create application context. - import app as flask_app - - flask_app.app.config.update( - {"TESTING": True, "SQLALCHEMY_DATABASE_URI": "sqlite:///:memory:"} - ) - - with flask_app.app.app_context(): - db.create_all() - yield flask_app.app - db.session.remove() - db.drop_all() - - def _seed_user_and_family(db_session): - user = User(email="u@example.com", full_name="User", role="farmer") + # Unique email: session-scoped DB is shared across both tests in this module. + email = f"u-{uuid.uuid4().hex[:12]}@example.com" + user = User(email=email, full_name="User", role="farmer") user.set_password("password123") db_session.add(user) db_session.commit() @@ -94,6 +80,17 @@ def test_successful_rotation_rejects_old_token(app_with_db): def test_concurrent_refresh_only_one_succeeds(app_with_db): + """Race two threads on the same refresh token; exactly one rotation wins. + - Loser should raise ``RefreshRotationError`` with code ``reuse`` (token already + revoked under ``with_for_update`` in ``rotate_refresh_token``). + - Each worker uses ``with app.app_context()`` because Flask's application + context is thread-local; workers do not inherit the main thread's context. + - File-backed SQLite via session ``app_with_db`` in ``conftest.py`` gives one + shared database for all connections; raw ``sqlite:///:memory:`` can attach + each connection to a different empty DB and makes this harness flaky. + Stress locally: ``pytest tests/test_refresh_rotation.py::test_concurrent_refresh_only_one_succeeds --count=50`` + (requires ``pytest-repeat``) or a shell loop. + """ from auth.jwt_utils import new_jti with app_with_db.app_context(): @@ -149,4 +146,6 @@ def worker(idx: int, app) -> None: assert len(results["ok"]) == 1 assert len(results["err"]) == 1 + loser_code = results["err"][0][0] + assert loser_code == "reuse", f"expected loser reuse, got {loser_code!r}"