diff --git a/flask_jwt/__init__.py b/flask_jwt/__init__.py index f864b78..4f0bd3c 100644 --- a/flask_jwt/__init__.py +++ b/flask_jwt/__init__.py @@ -33,6 +33,7 @@ 'JWT_AUTH_USERNAME_KEY': 'username', 'JWT_AUTH_PASSWORD_KEY': 'password', 'JWT_ALGORITHM': 'HS256', + 'JWT_ROLE': 'role', 'JWT_LEEWAY': timedelta(seconds=10), 'JWT_AUTH_HEADER_PREFIX': 'JWT', 'JWT_EXPIRATION_DELTA': timedelta(seconds=300), @@ -141,7 +142,21 @@ def _default_jwt_error_handler(error): ])), error.status_code, error.headers -def _jwt_required(realm): +def _force_iterable(input): + """If role is just a string, force it to an array. + """ + try: + basestring + except NameError: + basestring = str + if isinstance(input, basestring): + return [input] + if not hasattr(input, "__iter__"): + return [input] + return input + + +def _jwt_required(realm, roles): """Does the actual work of verifying the JWT data in the current request. This is done automatically for you by `jwt_required()` but you could call it manually. Doing so would be useful in the context of optional JWT access in your APIs. @@ -163,17 +178,33 @@ def _jwt_required(realm): if identity is None: raise JWTError('Invalid JWT', 'User does not exist') - - -def jwt_required(realm=None): + if roles: + try: + identity_role = getattr(identity, current_app.config['JWT_ROLE']) + except AttributeError: + try: + identity_role = identity.get(current_app.config['JWT_ROLE']) + except AttributeError: + raise JWTError('Bad Request', 'Invalid credentials') + if not identity_role: + raise JWTError('Bad Request', 'Invalid credentials') + identity_role = _force_iterable(identity_role) + roles = _force_iterable(roles) + if not identity_role or not set(roles).intersection(identity_role): + raise JWTError('Bad Request', 'Invalid credentials') + + +def jwt_required(realm=None, roles=None): """View decorator that requires a valid JWT token to be present in the request :param realm: an optional realm + :param roles: an optional list of roles allowed, + the role is pick in JWT_ROLE field of identity """ def wrapper(fn): @wraps(fn) def decorator(*args, **kwargs): - _jwt_required(realm or current_app.config['JWT_DEFAULT_REALM']) + _jwt_required(realm or current_app.config['JWT_DEFAULT_REALM'], roles) return fn(*args, **kwargs) return decorator return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index fb87ce4..cdad37e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,7 @@ import pytest from flask import Flask +from datetime import datetime, timedelta import flask_jwt @@ -18,10 +19,12 @@ class User(object): - def __init__(self, id, username, password): + def __init__(self, id, username, password, role=None): self.id = id self.username = username self.password = password + if role: + self.role = role def __str__(self): return "User(id='%s')" % self.id @@ -37,6 +40,16 @@ def user(): return User(id=1, username='joe', password='pass') +@pytest.fixture(scope='function') +def user_with_role(): + return User(id=2, username='jane', password='pass', role='user') + + +@pytest.fixture(scope='function') +def user_with_roles(): + return User(id=3, username='alice', password='pass', role=['user', 'foo', 'bar']) + + @pytest.fixture(scope='function') def app(jwt, user): app = Flask(__name__) @@ -64,6 +77,119 @@ def protected(): return app +@pytest.fixture(scope='function') +def app_with_role(jwt, user, user_with_role, user_with_roles): + app = Flask(__name__) + app.debug = True + app.config['SECRET_KEY'] = 'super-secret' + users = [user, user_with_role, user_with_roles] + + @jwt.authentication_handler + def authenticate(username, password): + for u in users: + if username == u.username and password == u.password: + return u + return None + + @jwt.identity_handler + def load_user(payload): + for u in users: + if payload['identity'] == u.id: + return u + + @jwt.jwt_payload_handler + def make_payload(identity): + iat = datetime.utcnow() + exp = iat + timedelta(seconds=300) + nbf = iat + id = getattr(identity, 'id') + try: + role = getattr(identity, 'role') + return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id, 'role': role} + except AttributeError: + return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id} + + jwt.init_app(app) + + @app.route('/protected') + @flask_jwt.jwt_required() + def protected(): + return 'success' + + @app.route('/role/protected/admin') + @flask_jwt.jwt_required(roles='admin') + def admin_protected(): + return 'success' + + @app.route('/role/protected/multi') + @flask_jwt.jwt_required(roles=['admin', 'user']) + def admin_user_protected(): + return 'success' + + @app.route('/role/protected/user') + @flask_jwt.jwt_required(roles='user') + def user_protected(): + return 'success' + + return app + + +@pytest.fixture(scope='function') +def app_with_role_trust_jwt(jwt, user, user_with_role, user_with_roles): + app = Flask(__name__) + app.debug = True + app.config['SECRET_KEY'] = 'super-secret' + app.config['JWT_ROLE'] = 'my_role' + users = [user, user_with_role, user_with_roles] + + @jwt.authentication_handler + def authenticate(username, password): + for u in users: + if username == u.username and password == u.password: + return u + return None + + @jwt.identity_handler + def load_user(payload): + return payload + + @jwt.jwt_payload_handler + def make_payload(identity): + iat = datetime.utcnow() + exp = iat + timedelta(seconds=300) + nbf = iat + id = getattr(identity, 'id') + try: + role = getattr(identity, 'role') + return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id, 'my_role': role} + except AttributeError: + return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id} + + jwt.init_app(app) + + @app.route('/protected') + @flask_jwt.jwt_required() + def protected(): + return 'success' + + @app.route('/role/protected/user') + @flask_jwt.jwt_required(roles='user') + def user_protected(): + return 'success' + + @app.route('/role/protected/multi') + @flask_jwt.jwt_required(roles=['admin', 'user']) + def admin_user_protected(): + return 'success' + + @app.route('/role/protected/admin') + @flask_jwt.jwt_required(roles='admin') + def admin_protected(): + return 'success' + + return app + + @pytest.fixture(scope='function') def client(app): return app.test_client() diff --git a/tests/test_jwt.py b/tests/test_jwt.py index 2157003..01761e9 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -291,3 +291,146 @@ def custom_auth_request_handler(): with app.test_client() as c: resp, jdata = post_json(c, '/auth', {}) assert jdata == {'hello': 'world'} + + +def test_role_required(app_with_role, user_with_role): + with app_with_role.test_client() as c: + resp, jdata = post_json( + c, '/auth', {'username': user_with_role.username, 'password': user_with_role.password}) + token = jdata['access_token'] + + # check if protected works with role set but not asked for this path + resp = c.get('/protected', headers={'authorization': 'JWT ' + token}) + assert resp.status_code == 200 + assert resp.data == b'success' + + # check if protected works wit role set but not asked for this path + resp = c.get('/role/protected/user', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 200 + assert resp.data == b'success' + + +def test_role_required_bad(app_with_role, user, user_with_role): + with app_with_role.test_client() as c: + + # test bad role + resp, jdata = post_json( + c, '/auth', {'username': user_with_role.username, 'password': user_with_role.password}) + + token = jdata['access_token'] + resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 401 + + # test no role + resp, jdata = post_json( + c, '/auth', {'username': user.username, 'password': user.password}) + + token = jdata['access_token'] + resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 401 + + +def test_role_required_multi(app_with_role, user_with_roles): + with app_with_role.test_client() as c: + resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username, + 'password': user_with_roles.password}) + token = jdata['access_token'] + + # check if protected works with role set but not asked for this path + resp = c.get('/protected', headers={'authorization': 'JWT ' + token}) + assert resp.status_code == 200 + assert resp.data == b'success' + + resp = c.get('/role/protected/user', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 200 + assert resp.data == b'success' + + +def test_role_required_multi_bad(app_with_role, user_with_roles): + with app_with_role.test_client() as c: + resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username, + 'password': user_with_roles.password}) + + token = jdata['access_token'] + resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 401 + + +def test_multirole_required_multi(app_with_role, user, user_with_roles): + with app_with_role.test_client() as c: + resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username, + 'password': user_with_roles.password}) + token = jdata['access_token'] + + # check if protected works with role set but not asked for this path + resp = c.get('/protected', headers={'authorization': 'JWT ' + token}) + assert resp.status_code == 200 + assert resp.data == b'success' + + resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 200 + assert resp.data == b'success' + + # test no role + resp, jdata = post_json( + c, '/auth', {'username': user.username, 'password': user.password}) + + token = jdata['access_token'] + resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 401 + + +def test_role_custom(app_with_role_trust_jwt, user, user_with_role, user_with_roles): + with app_with_role_trust_jwt.test_client() as c: + resp, jdata = post_json(c, '/auth', {'username': user_with_role.username, + 'password': user_with_role.password}) + token = jdata['access_token'] + + # check if protected works with role set but not asked for this path + resp = c.get('/protected', headers={'authorization': 'JWT ' + token}) + assert resp.status_code == 200 + assert resp.data == b'success' + + # check unauthorized role protection + resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 401 + + resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 200 + assert resp.data == b'success' + + resp = c.get('/role/protected/user', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 200 + assert resp.data == b'success' + + resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username, + 'password': user_with_roles.password}) + token = jdata['access_token'] + + # check if protected works with role set but not asked for this path + resp = c.get('/protected', headers={'authorization': 'JWT ' + token}) + assert resp.status_code == 200 + assert resp.data == b'success' + + resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 200 + assert resp.data == b'success' + # test no role + resp, jdata = post_json( + c, '/auth', {'username': user.username, 'password': user.password}) + + token = jdata['access_token'] + resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token}) + + assert resp.status_code == 401