Skip to content

Commit

Permalink
feat: implement RFC7523 client JWT authentication and grant
Browse files Browse the repository at this point in the history
  • Loading branch information
azmeuk committed Feb 10, 2025
1 parent 140ef2b commit 5d7438e
Show file tree
Hide file tree
Showing 16 changed files with 450 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Added
^^^^^
- Implement :rfc:`RFC9207 <9207>`. :pr:`227`
- Implement :rfc:`RFC7523 <7523>`. :issue:`112`

Fixed
^^^^^
Expand Down
7 changes: 7 additions & 0 deletions canaille/app/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class RootSettings(BaseSettings):
avoided in production environments.
"""

CACHE_TYPE: str = "SimpleCache"
"""The cache type.
The default ``SimpleCache`` is a lightweight in-memory cache.
See the :ref:`Flask-Caching documentation <flask-caching>` for further details.
"""


def settings_factory(
config=None,
Expand Down
3 changes: 3 additions & 0 deletions canaille/app/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from flask import request
from flask import session
from flask import url_for
from flask_caching import Cache
from flask_wtf.csrf import CSRFProtect
from werkzeug.exceptions import HTTPException
from werkzeug.routing import BaseConverter
Expand All @@ -21,6 +22,7 @@
from canaille.app.templating import render_template

csrf = CSRFProtect()
cache = Cache()


def user_needed(*args):
Expand Down Expand Up @@ -169,6 +171,7 @@ def setup_flask(app):
from canaille.app.templating import render_template

csrf.init_app(app)
cache.init_app(app)

@app.before_request
def make_session_permanent():
Expand Down
3 changes: 3 additions & 0 deletions canaille/oidc/endpoints/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class ClientAddForm(Form):
("hybrid", "hybrid"),
("refresh_token", "refresh_token"),
("client_credentials", "client_credentials"),
("urn:ietf:params:oauth:grant-type:jwt-bearer", "jwt-bearer"),
],
default=["authorization_code", "refresh_token"],
)
Expand All @@ -121,6 +122,8 @@ class ClientAddForm(Form):
choices=[
("client_secret_basic", "client_secret_basic"),
("client_secret_post", "client_secret_post"),
("client_secret_jwt", "client_secret_jwt"),
("private_key_jwt", "private_key_jwt"),
("none", "none"),
],
default="client_secret_basic",
Expand Down
81 changes: 80 additions & 1 deletion canaille/oidc/oauth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import datetime
import json

import requests
from authlib.integrations.flask_oauth2 import AuthorizationServer
from authlib.integrations.flask_oauth2 import ResourceProtector
from authlib.jose import JsonWebKey
from authlib.oauth2.rfc6749 import InvalidClientError
from authlib.oauth2.rfc6749.grants import (
AuthorizationCodeGrant as _AuthorizationCodeGrant,
)
Expand All @@ -14,6 +17,8 @@
)
from authlib.oauth2.rfc6750 import BearerTokenValidator as _BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
from authlib.oauth2.rfc7523 import JWTBearerClientAssertion
from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant
from authlib.oauth2.rfc7591 import (
ClientRegistrationEndpoint as _ClientRegistrationEndpoint,
)
Expand All @@ -36,9 +41,11 @@

from canaille.app import DOCUMENTATION_URL
from canaille.app import models
from canaille.app.flask import cache
from canaille.backends import Backend

AUTHORIZATION_CODE_LIFETIME = 84400
JWT_JTI_CACHE_LIFETIME = 3600


def oauth_authorization_server():
Expand Down Expand Up @@ -168,6 +175,28 @@ def get_jwks():
}


def get_client_jwks(client, kid=None):
"""Get the client JWK set, either stored locally or by downloading them from the URI the client indicated."""

@cache.cached(timeout=50, key_prefix=f"jwks_{client.client_id}")
def get_jwks():
return requests.get(client.jwks_uri).json()

if client.jwks_uri:
raw_jwks = get_jwks()
key_set = JsonWebKey.import_key_set(raw_jwks)
jwk = key_set.find_by_kid(kid)
return jwk

if client.jwks:
raw_jwks = json.loads(client.jwks)
key_set = JsonWebKey.import_key_set(raw_jwks)
jwk = key_set.find_by_kid(kid)
return jwk

return None


def claims_from_scope(scope):
claims = {"sub"}
if "profile" in scope:
Expand Down Expand Up @@ -246,8 +275,30 @@ def save_authorization_code(code, request):
return code.code


class JWTClientAuth(JWTBearerClientAssertion):
def validate_jti(self, claims, jti):
"""Indicate whether the jti was used before."""
key = "jti:{}-{}".format(claims["sub"], jti)
if cache.get(key):
return False
cache.set(key, 1, timeout=JWT_JTI_CACHE_LIFETIME)
return True

def resolve_client_public_key(self, client, headers):
jwk = get_client_jwks(client)
if not jwk:
raise InvalidClientError(description="No matching JWK")

return jwk


class AuthorizationCodeGrant(_AuthorizationCodeGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"]
TOKEN_ENDPOINT_AUTH_METHODS = [
"client_secret_basic",
"client_secret_post",
JWTClientAuth.CLIENT_AUTH_METHOD,
"none",
]

def save_authorization_code(self, code, request):
return save_authorization_code(code, request)
Expand Down Expand Up @@ -314,6 +365,27 @@ def revoke_old_credential(self, credential):
Backend.instance.save(credential)


class JWTBearerGrant(_JWTBearerGrant):
def resolve_issuer_client(self, issuer):
return Backend.instance.get(models.Client, client_id=issuer)

def resolve_client_key(self, client, headers, payload):
jwk = get_client_jwks(client, headers.get("kid"))
if not jwk:
raise InvalidClientError(description="No matching JWK")
return jwk

def authenticate_user(self, subject: str):
return Backend.instance.get_user_from_login(subject)

def has_granted_permission(self, client, user):
grant = Backend.instance.get(models.Consent, client=client, subject=user)
has_permission = (grant and not grant.revoked) or (
not grant and client.preconsent
)
return has_permission


class OpenIDImplicitGrant(_OpenIDImplicitGrant):
def exists_nonce(self, nonce, request):
return exists_nonce(nonce, request)
Expand Down Expand Up @@ -562,8 +634,15 @@ def setup_oauth(app):
authorization.register_grant(PasswordGrant)
authorization.register_grant(ImplicitGrant)
authorization.register_grant(RefreshTokenGrant)
authorization.register_grant(JWTBearerGrant)
authorization.register_grant(ClientCredentialsGrant)

with app.app_context():
authorization.register_client_auth_method(
JWTClientAuth.CLIENT_AUTH_METHOD,
JWTClientAuth(url_for("oidc.endpoints.issue_token", _external=True)),
)

authorization.register_grant(
AuthorizationCodeGrant,
[
Expand Down
2 changes: 2 additions & 0 deletions demo/demoapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def populate(app):
"authorization_code",
"refresh_token",
"client_credentials",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
],
scope=["openid", "profile", "email", "groups", "address", "phone"],
response_types=["code", "id_token"],
Expand Down Expand Up @@ -155,6 +156,7 @@ def populate(app):
"authorization_code",
"refresh_token",
"client_credentials",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
],
scope=["openid", "profile", "email", "groups", "address", "phone"],
response_types=["code", "id_token"],
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __getattr__(cls, name):
"flask": ("https://flask.palletsprojects.com", None),
"flask-alembic": ("https://flask-alembic.readthedocs.io/en/latest", None),
"flask-babel": ("https://python-babel.github.io/flask-babel", None),
"flask-caching": ("https://flask-caching.readthedocs.io/en/latest/", None),
"flask-wtf": ("https://flask-wtf.readthedocs.io", None),
"hypercorn": ("https://hypercorn.readthedocs.io/en/latest", None),
"pydantic": ("https://docs.pydantic.dev/latest", None),
Expand Down
2 changes: 1 addition & 1 deletion doc/development/specifications.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ OAuth2
- ✅ `RFC6749: OAuth 2.0 Framework <https://tools.ietf.org/html/rfc6749>`_
- ✅ `RFC6750: OAuth 2.0 Bearer Tokens <https://tools.ietf.org/html/rfc6750>`_
- ✅ `RFC7009: OAuth 2.0 Token Revocation <https://tools.ietf.org/html/rfc7009>`_
- `RFC7523: JWT Profile for OAuth 2.0 Client Authentication and Authorization Grants <https://tools.ietf.org/html/rfc7523>`_
- `RFC7523: JWT Profile for OAuth 2.0 Client Authentication and Authorization Grants <https://tools.ietf.org/html/rfc7523>`_
- ✅ `RFC7591: OAuth 2.0 Dynamic Client Registration Protocol <https://tools.ietf.org/html/rfc7591>`_
- ✅ `RFC7592: OAuth 2.0 Dynamic Client Registration Management Protocol <https://tools.ietf.org/html/rfc7592>`_
- ✅ `RFC7636: Proof Key for Code Exchange by OAuth Public Clients <https://tools.ietf.org/html/rfc7636>`_
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ maintainers = [{name="Éloi Rivard", email="[email protected]"}]
requires-python = ">=3.10"
dependencies = [
"flask >= 3.0.0",
"flask-caching>=2.3.0",
"flask-wtf >= 1.2.1",
"pydantic-settings >= 2.0.3",
"requests>=2.32.3",
Expand Down
6 changes: 6 additions & 0 deletions tests/app/fixtures/current-app-config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ PREFERRED_URL_SCHEME = "http"
# avoided in production environments.
# DEBUG = false

# The cache type.
#
# The default SimpleCache is a lightweight in-memory cache. See the Flask-Caching
# documentation for further details.
# CACHE_TYPE = "SimpleCache"

[CANAILLE]
# Your organization name.
#
Expand Down
6 changes: 6 additions & 0 deletions tests/app/fixtures/default-config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
# avoided in production environments.
# DEBUG = false

# The cache type.
#
# The default SimpleCache is a lightweight in-memory cache. See the Flask-Caching
# documentation for further details.
# CACHE_TYPE = "SimpleCache"

[CANAILLE]
# Your organization name.
#
Expand Down
2 changes: 2 additions & 0 deletions tests/oidc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def client(testclient, trusted_client, backend, client_jwks):
"hybrid",
"refresh_token",
"client_credentials",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
],
response_types=["code", "token", "id_token"],
scope=["openid", "email", "profile", "groups", "address", "phone"],
Expand Down Expand Up @@ -116,6 +117,7 @@ def trusted_client(testclient, backend, client_jwks):
"hybrid",
"refresh_token",
"client_credentials",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
],
response_types=["code", "token", "id_token"],
scope=["openid", "profile", "groups"],
Expand Down
1 change: 1 addition & 0 deletions tests/oidc/test_dynamic_client_registration_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_get(testclient, backend, client, user, client_jwks):
"hybrid",
"refresh_token",
"client_credentials",
"urn:ietf:params:oauth:grant-type:jwt-bearer",
],
"response_types": ["code", "token", "id_token"],
"client_name": "Some client",
Expand Down
87 changes: 87 additions & 0 deletions tests/oidc/test_jwt_authorization_grant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import time
import uuid

from joserfc import jwt

from canaille.app import models


def test_nominal_case(testclient, logged_user, client, backend, client_jwks):
"""Test JWT grant for a client with consent."""
now = time.time()
client.preconsent = True
backend.save(client)

_, private_key = client_jwks
header = {"alg": "RS256"}
payload = {
"iss": client.client_id,
"sub": logged_user.user_name,
"aud": "http://canaille.test/oauth/token",
"nbf": now - 3600,
"exp": now + 3600,
"iat": now - 1,
"jti": str(uuid.uuid4()),
}
client_jwt = jwt.encode(header, payload, private_key)

res = testclient.post(
"/oauth/token",
params=dict(
grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer",
scope="openid profile email groups address phone",
assertion=client_jwt,
redirect_uri=client.redirect_uris[0],
),
status=200,
)

access_token = res.json["access_token"]
token = backend.get(models.Token, access_token=access_token)
assert token.client == client
assert token.subject == logged_user
assert set(token.scope) == {
"openid",
"profile",
"email",
"groups",
"address",
"phone",
}


def test_no_jwk(testclient, logged_user, client, backend, client_jwks):
"""Test JWT grant for a client without JWKs."""
now = time.time()
client.preconsent = True
client.jwks = None
backend.save(client)

_, private_key = client_jwks
header = {"alg": "RS256"}
payload = {
"iss": client.client_id,
"sub": logged_user.user_name,
"aud": "http://canaille.test/oauth/token",
"nbf": now - 3600,
"exp": now + 3600,
"iat": now - 1,
"jti": str(uuid.uuid4()),
}
client_jwt = jwt.encode(header, payload, private_key)

res = testclient.post(
"/oauth/token",
params=dict(
grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer",
scope="openid profile email groups address phone",
assertion=client_jwt,
redirect_uri=client.redirect_uris[0],
),
status=400,
)

assert res.json == {
"error": "invalid_client",
"error_description": "No matching JWK",
}
Loading

0 comments on commit 5d7438e

Please sign in to comment.