diff --git a/src/django_otp_webauthn/helpers.py b/src/django_otp_webauthn/helpers.py index be58e5d..c2eb459 100644 --- a/src/django_otp_webauthn/helpers.py +++ b/src/django_otp_webauthn/helpers.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import hashlib import json -from typing import Optional from django.contrib.auth import get_user_model -from django.contrib.auth.models import AbstractUser +from django.contrib.auth.base_user import AbstractBaseUser from django.http import HttpRequest from webauthn import ( base64url_to_bytes, @@ -174,7 +175,7 @@ def get_attestation_conveyance_preference(self) -> AttestationConveyancePreferen def get_authenticator_attachment_preference( self, - ) -> Optional[AuthenticatorAttachment]: + ) -> AuthenticatorAttachment | None: """Get the authenticator attachment preference. By default, this is set to None, which means we don't have a preference. @@ -184,7 +185,7 @@ def get_authenticator_attachment_preference( """ return None - def get_credential_display_name(self, user: AbstractUser) -> str: + def get_credential_display_name(self, user: AbstractBaseUser) -> str: """Get the display name for the credential. This is used to display a name during registration and authentication. @@ -193,12 +194,12 @@ def get_credential_display_name(self, user: AbstractUser) -> str: User.get_username(). """ - if user.get_full_name(): + if hasattr(user, "get_full_name"): return f"{user.get_full_name()} ({user.get_username()})" return user.get_username() - def get_credential_name(self, user: AbstractUser) -> str: + def get_credential_name(self, user: AbstractBaseUser) -> str: """Get the name for the credential. This is used to display the user's name during registration and @@ -208,7 +209,7 @@ def get_credential_name(self, user: AbstractUser) -> str: """ return user.get_username() - def get_unique_anonymous_user_id(self, user: AbstractUser) -> bytes: + def get_unique_anonymous_user_id(self, user: AbstractBaseUser) -> bytes: """Get a unique identifier for the user to use during WebAuthn ceremonies. It must be a unique byte sequence no longer than 64 bytes. @@ -239,7 +240,7 @@ def get_unique_anonymous_user_id(self, user: AbstractUser) -> bytes: # bytes instead. return hashlib.sha256(bytes(user.pk)).digest() - def get_user_entity(self, user: AbstractUser) -> PublicKeyCredentialUserEntity: + def get_user_entity(self, user: AbstractBaseUser) -> PublicKeyCredentialUserEntity: """Get information about the user account a credential is being registered for.""" return PublicKeyCredentialUserEntity( id=self.get_unique_anonymous_user_id(user), @@ -270,7 +271,7 @@ def get_supported_key_algorithms(self) -> list[COSEAlgorithmIdentifier] | None: algorithms = [COSEAlgorithmIdentifier(a) for a in raw_algorithms if a in COSEAlgorithmIdentifier] return algorithms - def get_generate_registration_options_kwargs(self, *, user: AbstractUser) -> dict: + def get_generate_registration_options_kwargs(self, *, user: AbstractBaseUser) -> dict: """Get the keyword arguments to pass to `webauthn.generate_registration_options`.""" challenge = self.generate_challenge() rp = self.get_relying_party() @@ -323,7 +324,7 @@ def get_registration_state(self, creation_options: dict) -> dict: == UserVerificationRequirement.REQUIRED.value, } - def register_begin(self, user: AbstractUser) -> tuple[dict, dict]: + def register_begin(self, user: AbstractBaseUser) -> tuple[dict, dict]: """Begin the registration process.""" kwargs = self.get_generate_registration_options_kwargs(user=user) @@ -349,7 +350,7 @@ def get_allowed_origins(self) -> list[str]: origins = app_settings.OTP_WEBAUTHN_ALLOWED_ORIGINS return origins - def register_complete(self, user: AbstractUser, state: dict, data: dict): + def register_complete(self, user: AbstractBaseUser, state: dict, data: dict): """Complete the registration process.""" credential = parse_registration_credential_json(data) @@ -377,7 +378,7 @@ def register_complete(self, user: AbstractUser, state: dict, data: dict): self.create_attestation(device, response.attestation_object, credential.response.client_data_json) return device - def _check_discoverable(self, original_data: dict) -> Optional[bool]: + def _check_discoverable(self, original_data: dict) -> bool | None: """Check the clientExtensionResults to determine if the credential was created as discoverable. @@ -397,7 +398,7 @@ def _check_discoverable(self, original_data: dict) -> Optional[bool]: def create_credential( self, - user: AbstractUser, + user: AbstractBaseUser, response: VerifiedRegistration, parsed_credential: RegistrationCredential, original_data: dict, @@ -454,7 +455,7 @@ def get_authentication_extensions(self) -> dict: return {} def get_generate_authentication_options_kwargs( - self, *, user: Optional[AbstractUser] = None, require_user_verification: bool + self, *, user: AbstractBaseUser | None = None, require_user_verification: bool ) -> dict: """Get the keyword arguments to pass to `webauth.generate_authentication_options`.""" @@ -484,7 +485,7 @@ def get_authentication_state(self, options: dict) -> dict: def authenticate_begin( self, - user: Optional[AbstractUser] = None, + user: AbstractBaseUser | None = None, require_user_verification: bool = True, ): """Begin the authentication process.""" @@ -507,7 +508,7 @@ def authenticate_begin( state = self.get_authentication_state(data) return data, state - def authenticate_complete(self, user: Optional[AbstractUser], state: dict, data: dict): + def authenticate_complete(self, user: AbstractBaseUser | None, state: dict, data: dict): """Complete the authentication process.""" credential = parse_authentication_credential_json(data) diff --git a/src/django_otp_webauthn/models.py b/src/django_otp_webauthn/models.py index 2adcb37..2578517 100644 --- a/src/django_otp_webauthn/models.py +++ b/src/django_otp_webauthn/models.py @@ -1,7 +1,7 @@ import hashlib from django.contrib.auth import get_user_model -from django.contrib.auth.models import AbstractUser +from django.contrib.auth.base_user import AbstractBaseUser from django.db import models from django.db.models import QuerySet from django.http import HttpRequest @@ -324,7 +324,7 @@ def get_credential_id_sha256(cls, credential_id: bytes) -> bytes: return hashlib.sha256(credential_id).digest() @classmethod - def get_credential_descriptors_for_user(cls, user: AbstractUser) -> list[PublicKeyCredentialDescriptor]: + def get_credential_descriptors_for_user(cls, user: AbstractBaseUser) -> list[PublicKeyCredentialDescriptor]: """Return a list of PublicKeyCredentialDescriptor objects for the given user. Each PublicKeyCredentialDescriptor object represents a credential that the diff --git a/src/django_otp_webauthn/utils.py b/src/django_otp_webauthn/utils.py index 8e0db31..ae65e8e 100644 --- a/src/django_otp_webauthn/utils.py +++ b/src/django_otp_webauthn/utils.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from logging import Logger -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.apps import apps from django.core.exceptions import ImproperlyConfigured @@ -23,7 +25,7 @@ class rewrite_exceptions: """ - def __init__(self, logger: Optional[Logger] = None): + def __init__(self, logger: Logger | None = None): self.logger = logger def log_exception(self, exc: Exception): diff --git a/src/django_otp_webauthn/views.py b/src/django_otp_webauthn/views.py index 60478ea..b027168 100644 --- a/src/django_otp_webauthn/views.py +++ b/src/django_otp_webauthn/views.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from functools import lru_cache from logging import getLogger @@ -5,7 +7,8 @@ from django.contrib.auth import authenticate as auth_authenticate from django.contrib.auth import get_user_model from django.contrib.auth import login as auth_login -from django.contrib.auth.models import AbstractUser +from django.contrib.auth.base_user import AbstractBaseUser +from django.contrib.auth.models import AnonymousUser from django.shortcuts import resolve_url from django.utils.decorators import method_decorator from django.utils.http import url_has_allowed_host_and_scheme @@ -40,12 +43,12 @@ def dispatch(self, request, *args, **kwargs): raise exceptions.RegistrationDisabled() return super().dispatch(request, *args, **kwargs) - def get_user(self) -> AbstractUser: + def get_user(self) -> AbstractBaseUser | None: if self.request.user.is_authenticated: return self.request.user return None - def can_register(self, user: AbstractUser) -> bool: + def can_register(self, user: AbstractBaseUser | AnonymousUser) -> bool: if not user.is_active: return False return True @@ -58,15 +61,15 @@ def dispatch(self, request, *args, **kwargs): raise exceptions.AuthenticationDisabled() return super().dispatch(request, *args, **kwargs) - def get_user(self) -> AbstractUser: + def get_user(self) -> AbstractBaseUser | None: if self.request.user.is_authenticated: return self.request.user return None - def can_authenticate(self, user: AbstractUser) -> bool: - if user and not user.is_active: - return False - return True + def can_authenticate(self, user: AbstractBaseUser | AnonymousUser | None) -> bool: + if user and user.is_active: + return True + return False @method_decorator(never_cache, name="dispatch") @@ -187,7 +190,7 @@ def check_login_allowed(self, device: AbstractWebAuthnCredential) -> None: if not device.user.is_active: raise exceptions.UserDisabled() - def complete_auth(self, device: AbstractWebAuthnCredential) -> AbstractUser: + def complete_auth(self, device: AbstractWebAuthnCredential) -> AbstractBaseUser: """Handle the completion of the authentication procedure. This method is called when a credential was successfully used and