Skip to content

Commit

Permalink
Updated the type definitions specifying AbstractUser to instead use A…
Browse files Browse the repository at this point in the history
…bstractBaseUser (#9)
  • Loading branch information
Stormheg authored Jul 18, 2024
2 parents bae723b + 5a62ed8 commit 0169de8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 29 deletions.
33 changes: 17 additions & 16 deletions src/django_otp_webauthn/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import hashlib
import json
from typing import Optional

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser
from django.contrib.auth.base_user import AbstractBaseUser
from django.http import HttpRequest
from webauthn import (
base64url_to_bytes,
Expand Down Expand Up @@ -174,7 +175,7 @@ def get_attestation_conveyance_preference(self) -> AttestationConveyancePreferen

def get_authenticator_attachment_preference(
self,
) -> Optional[AuthenticatorAttachment]:
) -> AuthenticatorAttachment | None:
"""Get the authenticator attachment preference.
By default, this is set to None, which means we don't have a preference.
Expand All @@ -184,7 +185,7 @@ def get_authenticator_attachment_preference(
"""
return None

def get_credential_display_name(self, user: AbstractUser) -> str:
def get_credential_display_name(self, user: AbstractBaseUser) -> str:
"""Get the display name for the credential.
This is used to display a name during registration and authentication.
Expand All @@ -193,12 +194,12 @@ def get_credential_display_name(self, user: AbstractUser) -> str:
User.get_username().
"""

if user.get_full_name():
if hasattr(user, "get_full_name"):
return f"{user.get_full_name()} ({user.get_username()})"

return user.get_username()

def get_credential_name(self, user: AbstractUser) -> str:
def get_credential_name(self, user: AbstractBaseUser) -> str:
"""Get the name for the credential.
This is used to display the user's name during registration and
Expand All @@ -208,7 +209,7 @@ def get_credential_name(self, user: AbstractUser) -> str:
"""
return user.get_username()

def get_unique_anonymous_user_id(self, user: AbstractUser) -> bytes:
def get_unique_anonymous_user_id(self, user: AbstractBaseUser) -> bytes:
"""Get a unique identifier for the user to use during WebAuthn
ceremonies. It must be a unique byte sequence no longer than 64 bytes.
Expand Down Expand Up @@ -239,7 +240,7 @@ def get_unique_anonymous_user_id(self, user: AbstractUser) -> bytes:
# bytes instead.
return hashlib.sha256(bytes(user.pk)).digest()

def get_user_entity(self, user: AbstractUser) -> PublicKeyCredentialUserEntity:
def get_user_entity(self, user: AbstractBaseUser) -> PublicKeyCredentialUserEntity:
"""Get information about the user account a credential is being registered for."""
return PublicKeyCredentialUserEntity(
id=self.get_unique_anonymous_user_id(user),
Expand Down Expand Up @@ -270,7 +271,7 @@ def get_supported_key_algorithms(self) -> list[COSEAlgorithmIdentifier] | None:
algorithms = [COSEAlgorithmIdentifier(a) for a in raw_algorithms if a in COSEAlgorithmIdentifier]
return algorithms

def get_generate_registration_options_kwargs(self, *, user: AbstractUser) -> dict:
def get_generate_registration_options_kwargs(self, *, user: AbstractBaseUser) -> dict:
"""Get the keyword arguments to pass to `webauthn.generate_registration_options`."""
challenge = self.generate_challenge()
rp = self.get_relying_party()
Expand Down Expand Up @@ -323,7 +324,7 @@ def get_registration_state(self, creation_options: dict) -> dict:
== UserVerificationRequirement.REQUIRED.value,
}

def register_begin(self, user: AbstractUser) -> tuple[dict, dict]:
def register_begin(self, user: AbstractBaseUser) -> tuple[dict, dict]:
"""Begin the registration process."""

kwargs = self.get_generate_registration_options_kwargs(user=user)
Expand All @@ -349,7 +350,7 @@ def get_allowed_origins(self) -> list[str]:
origins = app_settings.OTP_WEBAUTHN_ALLOWED_ORIGINS
return origins

def register_complete(self, user: AbstractUser, state: dict, data: dict):
def register_complete(self, user: AbstractBaseUser, state: dict, data: dict):
"""Complete the registration process."""
credential = parse_registration_credential_json(data)

Expand Down Expand Up @@ -377,7 +378,7 @@ def register_complete(self, user: AbstractUser, state: dict, data: dict):
self.create_attestation(device, response.attestation_object, credential.response.client_data_json)
return device

def _check_discoverable(self, original_data: dict) -> Optional[bool]:
def _check_discoverable(self, original_data: dict) -> bool | None:
"""Check the clientExtensionResults to determine if the credential was
created as discoverable.
Expand All @@ -397,7 +398,7 @@ def _check_discoverable(self, original_data: dict) -> Optional[bool]:

def create_credential(
self,
user: AbstractUser,
user: AbstractBaseUser,
response: VerifiedRegistration,
parsed_credential: RegistrationCredential,
original_data: dict,
Expand Down Expand Up @@ -454,7 +455,7 @@ def get_authentication_extensions(self) -> dict:
return {}

def get_generate_authentication_options_kwargs(
self, *, user: Optional[AbstractUser] = None, require_user_verification: bool
self, *, user: AbstractBaseUser | None = None, require_user_verification: bool
) -> dict:
"""Get the keyword arguments to pass to `webauth.generate_authentication_options`."""

Expand Down Expand Up @@ -484,7 +485,7 @@ def get_authentication_state(self, options: dict) -> dict:

def authenticate_begin(
self,
user: Optional[AbstractUser] = None,
user: AbstractBaseUser | None = None,
require_user_verification: bool = True,
):
"""Begin the authentication process."""
Expand All @@ -507,7 +508,7 @@ def authenticate_begin(
state = self.get_authentication_state(data)
return data, state

def authenticate_complete(self, user: Optional[AbstractUser], state: dict, data: dict):
def authenticate_complete(self, user: AbstractBaseUser | None, state: dict, data: dict):
"""Complete the authentication process."""

credential = parse_authentication_credential_json(data)
Expand Down
4 changes: 2 additions & 2 deletions src/django_otp_webauthn/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import hashlib

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser
from django.contrib.auth.base_user import AbstractBaseUser
from django.db import models
from django.db.models import QuerySet
from django.http import HttpRequest
Expand Down Expand Up @@ -324,7 +324,7 @@ def get_credential_id_sha256(cls, credential_id: bytes) -> bytes:
return hashlib.sha256(credential_id).digest()

@classmethod
def get_credential_descriptors_for_user(cls, user: AbstractUser) -> list[PublicKeyCredentialDescriptor]:
def get_credential_descriptors_for_user(cls, user: AbstractBaseUser) -> list[PublicKeyCredentialDescriptor]:
"""Return a list of PublicKeyCredentialDescriptor objects for the given user.
Each PublicKeyCredentialDescriptor object represents a credential that the
Expand Down
6 changes: 4 additions & 2 deletions src/django_otp_webauthn/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from logging import Logger
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from django.apps import apps
from django.core.exceptions import ImproperlyConfigured
Expand All @@ -23,7 +25,7 @@ class rewrite_exceptions:
"""

def __init__(self, logger: Optional[Logger] = None):
def __init__(self, logger: Logger | None = None):
self.logger = logger

def log_exception(self, exc: Exception):
Expand Down
21 changes: 12 additions & 9 deletions src/django_otp_webauthn/views.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from functools import lru_cache
from logging import getLogger

from django.conf import settings
from django.contrib.auth import authenticate as auth_authenticate
from django.contrib.auth import get_user_model
from django.contrib.auth import login as auth_login
from django.contrib.auth.models import AbstractUser
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.models import AnonymousUser
from django.shortcuts import resolve_url
from django.utils.decorators import method_decorator
from django.utils.http import url_has_allowed_host_and_scheme
Expand Down Expand Up @@ -40,12 +43,12 @@ def dispatch(self, request, *args, **kwargs):
raise exceptions.RegistrationDisabled()
return super().dispatch(request, *args, **kwargs)

def get_user(self) -> AbstractUser:
def get_user(self) -> AbstractBaseUser | None:
if self.request.user.is_authenticated:
return self.request.user
return None

def can_register(self, user: AbstractUser) -> bool:
def can_register(self, user: AbstractBaseUser | AnonymousUser) -> bool:
if not user.is_active:
return False
return True
Expand All @@ -58,15 +61,15 @@ def dispatch(self, request, *args, **kwargs):
raise exceptions.AuthenticationDisabled()
return super().dispatch(request, *args, **kwargs)

def get_user(self) -> AbstractUser:
def get_user(self) -> AbstractBaseUser | None:
if self.request.user.is_authenticated:
return self.request.user
return None

def can_authenticate(self, user: AbstractUser) -> bool:
if user and not user.is_active:
return False
return True
def can_authenticate(self, user: AbstractBaseUser | AnonymousUser | None) -> bool:
if user and user.is_active:
return True
return False


@method_decorator(never_cache, name="dispatch")
Expand Down Expand Up @@ -187,7 +190,7 @@ def check_login_allowed(self, device: AbstractWebAuthnCredential) -> None:
if not device.user.is_active:
raise exceptions.UserDisabled()

def complete_auth(self, device: AbstractWebAuthnCredential) -> AbstractUser:
def complete_auth(self, device: AbstractWebAuthnCredential) -> AbstractBaseUser:
"""Handle the completion of the authentication procedure.
This method is called when a credential was successfully used and
Expand Down

0 comments on commit 0169de8

Please sign in to comment.