diff --git a/conftest.py b/conftest.py index 7e51b8e..c9943cd 100644 --- a/conftest.py +++ b/conftest.py @@ -16,6 +16,7 @@ def saml_request_minimal() -> str: @pytest.fixture() @lru_cache() -def sp_metadata_xml() -> str: - with (XML_ROOT / "metadata/sp_metadata.xml").open("r") as f: +def sp_metadata_xml(request) -> str: + file_name = getattr(request, "param", "sp_metadata") + with (XML_ROOT / f"metadata/{file_name}.xml").open("r") as f: return f.read() diff --git a/djangosaml2idp/conf.py b/djangosaml2idp/conf.py new file mode 100644 index 0000000..e52d7a3 --- /dev/null +++ b/djangosaml2idp/conf.py @@ -0,0 +1,37 @@ +import copy +from typing import Callable, Optional, Union + +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest +from django.utils.module_loading import import_string + + +def get_callable(path: Union[Callable, str]) -> Callable: + """ Import the function at a given path and return it + """ + if callable(path): + return path + + try: + config_loader = import_string(path) + except ImportError as e: + raise ImproperlyConfigured(f'Error importing SAML config loader {path}: "{e}"') + + if not callable(config_loader): + raise ImproperlyConfigured("SAML config loader must be a callable object.") + + return config_loader + + +def get_config(config_loader_path: Optional[Union[Callable, str]] = None, request: Optional[HttpRequest] = None) -> dict: + """ Load a config_loader function if necessary, and call that function with the request as argument. + If the config_loader_path is a callable instead of a string, no importing is necessary and it will be used directly. + Return the resulting SPConfig. + """ + static_config = copy.deepcopy(settings.SAML_IDP_CONFIG) + + if config_loader_path is None: + return static_config or {} + else: + return get_callable(config_loader_path)(static_config, request) diff --git a/djangosaml2idp/idp.py b/djangosaml2idp/idp.py index 380d2f0..541f03c 100644 --- a/djangosaml2idp/idp.py +++ b/djangosaml2idp/idp.py @@ -1,53 +1,68 @@ -import copy - from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest from django.utils.translation import gettext as _ from saml2.config import IdPConfig from saml2.metadata import entity_descriptor from saml2.server import Server +from typing import Callable, Dict, Optional, Union + +from .conf import get_callable, get_config class IDP: """ Access point for the IDP Server instance """ - _server_instance: Server = None + _server_instances: Dict[str, Server] = {} @classmethod - def construct_metadata(cls, with_local_sp: bool = True) -> dict: + def construct_metadata(cls, idp_conf: dict, request: Optional[HttpRequest] = None, with_local_sp: bool = True) -> IdPConfig: """ Get the config including the metadata for all the configured service providers. """ + conf = IdPConfig() + from .models import ServiceProvider - idp_config = copy.deepcopy(settings.SAML_IDP_CONFIG) - if idp_config: - idp_config['metadata'] = { # type: ignore - 'local': ( - [sp.metadata_path() for sp in ServiceProvider.objects.filter(active=True)] - if with_local_sp else []), - } - return idp_config + sp_queryset = ServiceProvider.objects.none() + if with_local_sp: + sp_queryset = ServiceProvider.objects.filter(active=True) + if getattr(settings, "SAML_IDP_FILTER_SP_QUERYSET", None) is not None: + sp_queryset = get_callable(settings.SAML_IDP_FILTER_SP_QUERYSET)(sp_queryset, request) + + idp_conf['metadata'] = { # type: ignore + 'local': ( + [sp.metadata_path() for sp in sp_queryset] + if with_local_sp else [] + ), + } + try: + conf.load(idp_conf) + except Exception as e: + raise ImproperlyConfigured(_('Could not instantiate an IDP based on the SAML_IDP_CONFIG settings and configured ServiceProviders: {}').format(str(e))) + return conf + + @classmethod + def load(cls, request: Optional[HttpRequest] = None, config_loader_path: Optional[Union[Callable, str]] = None) -> Server: + idp_conf = get_config(config_loader_path, request) + if "entityid" not in idp_conf: + raise ImproperlyConfigured('The configuration must contain an entityid') + entity_id = idp_conf["entityid"] + + if entity_id not in cls._server_instances: + # actually initialize the IdP server and cache it + conf = cls.construct_metadata(idp_conf, request) + cls._server_instances[entity_id] = Server(config=conf) + + return cls._server_instances[entity_id] @classmethod - def load(cls, force_refresh: bool = False) -> Server: - """ Instantiate a IDP Server instance based on the config defined in the SAML_IDP_CONFIG settings. - Throws an ImproperlyConfigured exception if it could not do so for any reason. - """ - if cls._server_instance is None or force_refresh: - conf = IdPConfig() - md = cls.construct_metadata() - try: - conf.load(md) - cls._server_instance = Server(config=conf) - except Exception as e: - raise ImproperlyConfigured(_('Could not instantiate an IDP based on the SAML_IDP_CONFIG settings and configured ServiceProviders: {}').format(str(e))) - return cls._server_instance + def flush(cls): + cls._server_instances = {} @classmethod - def metadata(cls) -> str: + def metadata(cls, request: Optional[HttpRequest] = None, config_loader_path: Optional[Union[Callable, str]] = None) -> str: """ Get the IDP metadata as a string. """ - conf = IdPConfig() try: - conf.load(cls.construct_metadata(with_local_sp=False)) + conf = cls.construct_metadata(get_config(config_loader_path, request), request, with_local_sp=False) metadata = entity_descriptor(conf) except Exception as e: - raise ImproperlyConfigured(_('Could not instantiate IDP metadata based on the SAML_IDP_CONFIG settings and configured ServiceProviders: {}').format(str(e))) + raise ImproperlyConfigured(_('Could not instantiate IDP metadata: {}').format(str(e))) return str(metadata) diff --git a/djangosaml2idp/models.py b/djangosaml2idp/models.py index 449fa31..79f2771 100644 --- a/djangosaml2idp/models.py +++ b/djangosaml2idp/models.py @@ -173,7 +173,7 @@ def save(self, *args, **kwargs): if not self.metadata_expiration_dt: self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata) super().save(*args, **kwargs) - IDP.load(force_refresh=True) + IDP.flush() @property def attribute_mapping(self) -> Dict[str, str]: @@ -228,14 +228,10 @@ def metadata_path(self) -> str: @property def sign_response(self) -> bool: - if self._sign_response is None: - return getattr(IDP.load().config, "sign_response", False) return self._sign_response @property def sign_assertion(self) -> bool: - if self._sign_assertion is None: - return getattr(IDP.load().config, "sign_assertion", False) return self._sign_assertion @property diff --git a/djangosaml2idp/urls.py b/djangosaml2idp/urls.py index b6cba83..5d375e1 100644 --- a/djangosaml2idp/urls.py +++ b/djangosaml2idp/urls.py @@ -10,5 +10,5 @@ path('login/process/', views.LoginProcessView.as_view(), name='saml_login_process'), path('login/process_multi_factor/', views.get_multifactor, name='saml_multi_factor'), path('slo//', views.LogoutProcessView.as_view(), name="saml_logout_binding"), - path('metadata/', views.metadata, name='saml2_idp_metadata'), + path('metadata/', views.MetadataView.as_view(), name='saml2_idp_metadata'), ] diff --git a/djangosaml2idp/views.py b/djangosaml2idp/views.py index 401ca76..41274f8 100644 --- a/djangosaml2idp/views.py +++ b/djangosaml2idp/views.py @@ -25,6 +25,7 @@ from saml2.authn_context import PASSWORD, AuthnBroker, authn_context_class_ref from saml2.ident import NameID from saml2.saml import NAMEID_FORMAT_UNSPECIFIED +from saml2.server import Server from .error_views import error_cbv from .idp import IDP @@ -83,15 +84,22 @@ def check_access(processor: BaseProcessor, request: HttpRequest) -> None: raise PermissionDenied(_("You do not have access to this resource")) -def get_sp_config(sp_entity_id: str) -> ServiceProvider: - """ Get a dict with the configuration for a SP according to the SAML_IDP_SPCONFIG settings. +def get_sp_config(sp_entity_id: str, idp_server: Server) -> ServiceProvider: + """ Get a dict with the configuration for a SP according to the SAML_IDP_SPCONFIG settings and the SP model. Raises an exception if no SP matching the given entity id can be found. """ try: + if sp_entity_id not in idp_server.metadata.keys(): + raise ObjectDoesNotExist() sp = ServiceProvider.objects.get(entity_id=sp_entity_id, active=True) except ObjectDoesNotExist: - raise ImproperlyConfigured(_("No active Service Provider object matching the entity_id '{}' found").format(sp_entity_id)) - return sp + raise ObjectDoesNotExist( + _("No active Service Provider object matching the entity_id '{}' found for the Identity Provider '{}").format( + sp_entity_id, idp_server.ident.name_qualifier + ) + ) + else: + return sp def get_authn(req_info=None): @@ -101,7 +109,7 @@ def get_authn(req_info=None): return broker.get_authn_by_accr(req_authn_context) -def build_authn_response(user: User, authn, resp_args, service_provider: ServiceProvider) -> list: # type: ignore +def build_authn_response(user: User, authn, resp_args, service_provider: ServiceProvider, idp_server: Server) -> list: # type: ignore """ pysaml2 server.Server.create_authn_response wrapper """ policy = resp_args.get('name_id_policy', None) @@ -110,7 +118,6 @@ def build_authn_response(user: User, authn, resp_args, service_provider: Service else: name_id_format = policy.format - idp_server = IDP.load() idp_name_id_format_list = idp_server.config.getattr("name_id_format", "idp") or [NAMEID_FORMAT_UNSPECIFIED] if name_id_format not in idp_name_id_format_list: @@ -127,8 +134,8 @@ def build_authn_response(user: User, authn, resp_args, service_provider: Service userid=user_id, sp_entity_id=service_provider.entity_id, # Signing - sign_response=service_provider.sign_response, - sign_assertion=service_provider.sign_assertion, + sign_response=service_provider.sign_response if service_provider.sign_response is not None else getattr(idp_server, 'sign_response', False), + sign_assertion=service_provider.sign_assertion if service_provider.sign_assertion is not None else getattr(idp_server, 'sign_assertion', False), sign_alg=service_provider.signing_algorithm, digest_alg=service_provider.digest_algorithm, # Encryption @@ -139,8 +146,18 @@ def build_authn_response(user: User, authn, resp_args, service_provider: Service class IdPHandlerViewMixin: - """ Contains some methods used by multiple views """ + config_loader_path = getattr(settings, 'SAML_IDP_CONFIG_LOADER', None) + + def get_config_loader_path(self, request: HttpRequest): + return self.config_loader_path + + def get_idp_server(self, request: HttpRequest) -> Server: + return IDP.load(request, self.get_config_loader_path(request)) + def get_idp_metadata(self, request: HttpRequest) -> str: + return IDP.metadata(request, self.get_config_loader_path(request)) + + """ Contains some methods used by multiple views """ def render_login_html_to_string(self, context=None, request=None, using=None): """ Render the html response for the login action. Can be using a custom html template if set on the view. """ default_login_template_name = 'djangosaml2idp/login.html' @@ -179,7 +196,7 @@ def create_html_response(self, request: HttpRequest, binding, authn_resp, destin "type": "POST", } else: - idp_server = IDP.load() + idp_server = self.get_idp_server(request) http_args = idp_server.apply_binding( binding=binding, msg_str=authn_resp, @@ -230,7 +247,7 @@ def get(self, request, *args, **kwargs): # TODO: would it be better to store SAML info in request objects? # AuthBackend takes request obj as argument... try: - idp_server = IDP.load() + idp_server = self.get_idp_server(request) # Parse incoming request req_info = idp_server.parse_authn_request(request.session['SAMLRequest'], binding) @@ -245,15 +262,17 @@ def get(self, request, *args, **kwargs): resp_args = idp_server.response_args(req_info.message) # Set SP and Processor sp_entity_id = resp_args.pop('sp_entity_id') - service_provider = get_sp_config(sp_entity_id) + service_provider = get_sp_config(sp_entity_id, idp_server) # Check if user has access try: # Check if user has access to SP check_access(service_provider.processor, request) + except (ObjectDoesNotExist) as excp: + return error_cbv.handle_error(request, exception=excp, status_code=404) except PermissionDenied as excp: return error_cbv.handle_error(request, exception=excp, status_code=403) # Construct SamlResponse message - authn_resp = build_authn_response(request.user, get_authn(), resp_args, service_provider) + authn_resp = build_authn_response(request.user, get_authn(), resp_args, service_provider, idp_server) except Exception as e: return error_cbv.handle_error(request, exception=e, status_code=500) @@ -280,11 +299,15 @@ def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: request_data = request.POST or request.GET passed_data: Dict[str, Union[str, List[str]]] = request_data.copy().dict() + idp_server = self.get_idp_server(request) + try: # get sp information from the parameters sp_entity_id = str(passed_data['sp']) - service_provider = get_sp_config(sp_entity_id) + service_provider = get_sp_config(sp_entity_id, idp_server) processor: BaseProcessor = service_provider.processor # type: ignore + except (ObjectDoesNotExist) as excp: + return error_cbv.handle_error(request, exception=excp, status_code=404) except (KeyError, ImproperlyConfigured) as excp: return error_cbv.handle_error(request, exception=excp, status_code=400) @@ -294,8 +317,6 @@ def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: except PermissionDenied as excp: return error_cbv.handle_error(request, exception=excp, status_code=403) - idp_server = IDP.load() - binding_out, destination = idp_server.pick_binding( service="assertion_consumer_service", entity_id=sp_entity_id) @@ -305,7 +326,7 @@ def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: passed_data['in_response_to'] = "IdP_Initiated_Login" # Construct SamlResponse messages - authn_resp = build_authn_response(request.user, get_authn(), passed_data, service_provider) + authn_resp = build_authn_response(request.user, get_authn(), passed_data, service_provider, idp_server) html_response = self.create_html_response(request, binding_out, authn_resp, destination, passed_data.get('RelayState', "")) return self.render_response(request, html_response, processor) @@ -354,7 +375,7 @@ def get(self, request: HttpRequest, *args, **kwargs): relay_state = request.session['RelayState'] logger.debug("--- {} requested [\n{}] to IDP ---".format(self.__service_name, binding)) - idp_server = IDP.load() + idp_server = self.get_idp_server(request) # adapted from pysaml2 examples/idp2/idp_uwsgi.py try: @@ -414,6 +435,16 @@ def get(self, request: HttpRequest, *args, **kwargs): return self.render_response(request, html_response, None) +@method_decorator(never_cache, name="dispatch") +class MetadataView(IdPHandlerViewMixin, View): + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + """ Returns an XML with the SAML 2.0 metadata for this Idp. + The metadata is constructed on-the-fly based on the config dict in the django settings. + """ + metadata = self.get_idp_metadata(request) + return HttpResponse(content=metadata.encode("utf-8"), content_type="text/xml; charset=utf8",) + + @never_cache def get_multifactor(request: HttpRequest) -> HttpResponse: if hasattr(settings, "SAML_IDP_MULTIFACTOR_VIEW"): @@ -421,11 +452,3 @@ def get_multifactor(request: HttpRequest) -> HttpResponse: else: multifactor_class = ProcessMultiFactorView return multifactor_class.as_view()(request) - - -@never_cache -def metadata(request: HttpRequest) -> HttpResponse: - """ Returns an XML with the SAML 2.0 metadata for this Idp. - The metadata is constructed on-the-fly based on the config dict in the django settings. - """ - return HttpResponse(content=IDP.metadata().encode('utf-8'), content_type="text/xml; charset=utf8") diff --git a/docs/configuration.rst b/docs/configuration.rst index 8702c23..f3bcdf5 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -71,6 +71,20 @@ In your Django settings, configure your IdP. Configuration follows the `PySAML2 Notice the configuration requires a private key and public certificate to be available on the filesystem in order to sign and encrypt messages. +Dynamic IdP configuration +------------------------- + +Aditionaly a callback can be used to customize the IdP settings on a per-request basis. It can be defined either +* as a path set in the `SAML_IDP_CONFIG_LOADER` +* by subclassing the views (and using them in the url config) and overriding their `get_config_loader_path(self, request: HttpRequest)` method, returning a callback or a path to it + +Any of these callbacks will be called when loading the IdP, receiving the static configuration (defined by `SAML_IDP_CONFIG`) and the current request as arguments. It is expected to return a new configuration with the same form as the static one. + +Please note that the resulting IDP objects will be cached with the 'entityid' parameter as a key. + +Service Providers +----------------- + Next the Service Providers and their configuration need to be added, this is done via the Django admin interface. Add an entry for each SP which speaks to thie IdP. Add a copy of the local metadata xml, or set a remote metadata url. Add an attribute mapping for user attributes to SAML fields or leave the default mapping which will be prefilled. @@ -78,9 +92,11 @@ Several attributes can be overriden per SP. If they aren't overridden explicitly If those aren't set, some defaults will be used, as indicated in the admin when you configre a SP. The resulting configuration of a SP, with merged settings of its own and the instance settings and defaults, is shown in the admin as a summary. +The set of SPs available can optionnaly be dynamically defined through the `SAML_IDP_FILTER_SP_QUERYSET` setting, as a path to a callable. It receives the orginal queryset (all SPs with `active=True` field) and the current request as arguments. It is expected to return a queryset. + Further optional configuration options ====================================== - + In the ``SAML_IDP_SPCONFIG`` setting you can define a ``processor``, its value being a string with dotted path to a class. This is a hook to customize some access control checks. By default, the included `BaseProcessor` is used, which allows every user to login on the IdP. You can customize this behaviour by subclassing the `BaseProcessor` and overriding its `has_access(self, request)` method. This method should return true or false, depending if the user has permission to log in for the SP / IdP. @@ -92,7 +108,7 @@ Use this metadata xml to configure your SP. Place the metadata xml from that SP Without custom setting, users will be identified by the ``USERNAME_FIELD`` property on the user Model you use. By Django defaults this will be the username. You can customize which field is used for the identifier by adding ``SAML_IDP_DJANGO_USERNAME_FIELD`` to your settings with as value the attribute to use on your user instance. -Other settings you can set as defaults to be used if not overriden by an SP are `SAML_AUTHN_SIGN_ALG`, `SAML_AUTHN_DIGEST_ALG`, and `SAML_ENCRYPT_AUTHN_RESPONSE`. They can be set if desired in the django settings, in which case they will be used for all ServiceProviders configuration on this instance if they don't override it. E.g.:: +Other settings you can set as defaults to be used if not overriden by an SP are `SAML_AUTHN_SIGN_ALG`, `SAML_AUTHN_DIGEST_ALG`, and `SAML_ENCRYPT_AUTHN_RESPONSE`. They can be set if desired in the django settings, in which case they will be used for all ServiceProviders configuration on this instance if they don't override it. E.g.: SAML_AUTHN_SIGN_ALG = saml2.xmldsig.SIG_RSA_SHA256 SAML_AUTHN_DIGEST_ALG = saml2.xmldsig.DIGEST_SHA256 diff --git a/tests/test_conf.py b/tests/test_conf.py new file mode 100644 index 0000000..97ea163 --- /dev/null +++ b/tests/test_conf.py @@ -0,0 +1,41 @@ +import pytest +from django.core.exceptions import ImproperlyConfigured +from django.http import HttpRequest + +from djangosaml2idp.conf import get_callable, get_config +from djangosaml2idp.utils import repr_saml +from .settings import SAML_IDP_CONFIG + +class TestConf: + def test_get_callable_callable(self): + func_callable = lambda x: x + assert get_callable(func_callable) == func_callable + + def test_get_callable_path(self): + assert get_callable('djangosaml2idp.utils.repr_saml') == repr_saml + + def test_get_callable_path_unokwn(self): + with pytest.raises(ImproperlyConfigured): + get_callable('some.where.else') + + def test_get_callable_path_not_callable(self): + with pytest.raises(ImproperlyConfigured): + get_callable('djangosaml2idp.urls.app_name') + + def test_get_config_static_conf(self): + assert get_config() == SAML_IDP_CONFIG + + def test_get_config_static_conf_empty(self, settings): + settings.SAML_IDP_CONFIG = None + assert get_config() == {} + + def test_get_config(self): + request = HttpRequest() + return_value = "xxx" + def loader(c, r): + called = True + assert c == SAML_IDP_CONFIG + assert r == request + return return_value + + assert get_config(loader, request) == return_value diff --git a/tests/test_idp.py b/tests/test_idp.py index f2c4f78..49e5d4a 100644 --- a/tests/test_idp.py +++ b/tests/test_idp.py @@ -1,30 +1,75 @@ +import copy from unittest.mock import patch, Mock import pytest +from unittest import mock from django.core.exceptions import ImproperlyConfigured from saml2.server import Server from djangosaml2idp.idp import IDP +from djangosaml2idp.models import ServiceProvider +from .settings import SAML_IDP_CONFIG +def conf_loader(c, r): + return { "entityid": SAML_IDP_CONFIG["entityid"] } class TestIDP: + def teardown_method(self): + IDP.flush() @pytest.mark.django_db - def test_idp_load_default_settings_defined_and_valid(self): - IDP._server_instance = None + def test_load_default_settings_defined_and_valid(self): srv = IDP.load() assert isinstance(srv, Server) @pytest.mark.django_db - def test_idp_load_no_settings_defined(self, settings): - IDP._server_instance = None + def test_load_no_settings_defined(self, settings): settings.SAML_IDP_CONFIG = None with pytest.raises(ImproperlyConfigured): IDP.load() + @pytest.mark.django_db + def test_load_cache(self): + s1 = IDP.load() + s2 = IDP.load(config_loader_path=conf_loader) + assert s1 == s2 + + @pytest.mark.django_db + def test_load_constructsp_queryset(self, settings): + called = False + def identity_queryset(queryset, request): + nonlocal called + called = True + return queryset + + settings.SAML_IDP_FILTER_SP_QUERYSET = identity_queryset + IDP.load() + assert called + + @pytest.mark.django_db + @mock.patch('saml2.config.IdPConfig.load') + def test_construct_metadata(self, mock): + conf = { "a": 1, "b": 2 } + IDP.construct_metadata(copy.deepcopy(conf)) + mock.assert_called_with({ **conf, "metadata": { "local": [] } }) + + @pytest.mark.django_db + @mock.patch('saml2.config.IdPConfig.load') + def test_construct_metadata_raise(self, mock): + mock.side_effect = ImproperlyConfigured() + conf = { "a": 1, "b": 2 } + with pytest.raises(ImproperlyConfigured): + IDP.construct_metadata(copy.deepcopy(conf)) + + @pytest.mark.django_db + def test_flush(self): + s1 = IDP.load() + IDP.flush() + s2 = IDP.load(config_loader_path=conf_loader) + assert s1 != s2 + @pytest.mark.django_db def test_metadata_no_sp_defined_valid(self): - IDP._server_instance = None - md = IDP.metadata() + md = IDP.metadata() assert isinstance(md, str) @pytest.mark.django_db @@ -34,14 +79,12 @@ def test_metadata_sp_autoload_idp(self, sp_model_mock): sp_instance_mock = Mock() sp_instance_mock.metadata_path.return_value = '/tmp/djangosaml2idp/1.xml' sp_model_mock.objects.filter.return_value = [sp_instance_mock] - IDP._server_instance = None md = IDP.metadata() sp_instance_mock.metadata_path.assert_not_called() @pytest.mark.django_db def test_metadata_no_settings_defined(self, settings): - IDP._server_instance = None settings.SAML_IDP_CONFIG = None with pytest.raises(ImproperlyConfigured): - IDP.metadata() + IDP.metadata() diff --git a/tests/test_models.py b/tests/test_models.py index 98c5f06..ce258c4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -8,6 +8,7 @@ from saml2 import xmldsig import requests_mock +from djangosaml2idp.conf import get_config from djangosaml2idp.idp import IDP from djangosaml2idp.models import DEFAULT_ATTRIBUTE_MAPPING, ServiceProvider @@ -32,31 +33,22 @@ def test_property_attribute_mapping(self): assert instance.attribute_mapping == DEFAULT_ATTRIBUTE_MAPPING instance = ServiceProvider(_attribute_mapping='{"custom_key": "custom_value"}') assert instance.attribute_mapping == {"custom_key": "custom_value"} - - def test_property_sign_response(self): - instance = ServiceProvider(_sign_response=None) - assert instance.sign_response == getattr(IDP.load().config, "sign_response", False) - instance = ServiceProvider(_sign_response=True) - assert instance.sign_response == True - - def test_property_sign_assertion(self): - instance = ServiceProvider(_sign_assertion=None) - assert instance.sign_assertion == getattr(IDP.load().config, "sign_assertion", False) - instance = ServiceProvider(_sign_assertion=True) - assert instance.sign_assertion == True - + + @pytest.mark.django_db def test_property_encrypt_saml_responses(self): instance = ServiceProvider(_encrypt_saml_responses=None) assert instance.encrypt_saml_responses == getattr(IDP.load().config, "SAML_ENCRYPT_AUTHN_RESPONSE", False) instance = ServiceProvider(_encrypt_saml_responses=True) assert instance.encrypt_saml_responses == True - + + @pytest.mark.django_db def test_property_signing_algorithm(self): instance = ServiceProvider(_signing_algorithm=None) assert instance.signing_algorithm == getattr(IDP.load().config, "SAML_AUTHN_SIGN_ALG", xmldsig.SIG_RSA_SHA256) instance = ServiceProvider(_signing_algorithm='dummy_value') assert instance.signing_algorithm == 'dummy_value' - + + @pytest.mark.django_db def test_property_digest_algorithm(self): instance = ServiceProvider(_digest_algorithm=None) assert instance.digest_algorithm == getattr(IDP.load().config, "SAML_AUTHN_DIGEST_ALG", xmldsig.DIGEST_SHA256) diff --git a/tests/test_processor.py b/tests/test_processor.py index b12320f..200b862 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -7,6 +7,7 @@ from django.http import HttpRequest from saml2.saml import NAMEID_FORMAT_UNSPECIFIED +from djangosaml2idp.conf import get_config from djangosaml2idp.idp import IDP from djangosaml2idp.models import ServiceProvider from djangosaml2idp.processors import (BaseProcessor, NameIdBuilder, diff --git a/tests/test_views.py b/tests/test_views.py index 20456dc..ac45bd0 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -4,8 +4,8 @@ import pytest from django.contrib.auth import authenticate, get_user_model, login, logout from django.contrib.sessions.backends.db import SessionStore -from django.core.exceptions import (ImproperlyConfigured, PermissionDenied, - ValidationError) +from django.core.exceptions import (ImproperlyConfigured, ObjectDoesNotExist, + PermissionDenied, ValidationError) from django.http import HttpRequest, HttpResponse, HttpResponseRedirect from django.template.exceptions import TemplateSyntaxError from django.utils import timezone @@ -15,15 +15,17 @@ from saml2.saml import NAMEID_FORMAT_X509SUBJECTNAME from saml2.samlp import Response +from djangosaml2idp.conf import get_config +from djangosaml2idp.idp import IDP from djangosaml2idp.models import ServiceProvider from djangosaml2idp.processors import BaseProcessor from djangosaml2idp.utils import encode_saml from djangosaml2idp.views import (BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, IdPHandlerViewMixin, LoginProcessView, - LogoutProcessView, ProcessMultiFactorView, + LogoutProcessView, MetadataView, ProcessMultiFactorView, SSOInitView, build_authn_response, check_access, get_authn, get_multifactor, - get_sp_config, metadata, sso_entry, + get_sp_config, sso_entry, store_params_in_session) User = get_user_model() @@ -225,6 +227,9 @@ def mock_get_template(mocker): class TestIdPHandlerViewMixin: + def teardown_method(self): + IDP.flush() + def test_render_login_hto_to_string_returns_result_of_render(self, mock_get_template): mixin = IdPHandlerViewMixin() @@ -288,39 +293,42 @@ def test_fetch_custom_template_returns_default_if_syntax_error(self, mock_get_te @pytest.mark.django_db def test_set_sp_errors_if_sp_not_defined(self): - with pytest.raises(ImproperlyConfigured): - get_sp_config('this_sp_does_not_exist') + with pytest.raises(ObjectDoesNotExist): + get_sp_config('this_sp_does_not_exist', IDP.load()) @pytest.mark.django_db def test_set_sp_works_if_sp_defined(self, settings, sp_metadata_xml, sp_testing_configs): ServiceProvider.objects.create(entity_id='test_generic_sp', local_metadata=sp_metadata_xml) - sp = get_sp_config('test_generic_sp') + sp = get_sp_config('test_generic_sp', IDP.load()) assert sp._processor == sp_testing_configs['test_generic_sp']['processor'] assert sp.attribute_mapping == sp_testing_configs['test_generic_sp']['attribute_mapping'] @pytest.mark.django_db + @pytest.mark.parametrize('sp_metadata_xml', ['sp_with_bad_processor_metadata'], indirect=True) def test_set_processor_errors_if_processor_cannot_be_loaded(self, sp_metadata_xml): - ServiceProvider.objects.create(entity_id='test_sp_with_bad_processor', local_metadata=sp_metadata_xml, _processor='this.does.not.exist') - sp = get_sp_config('test_sp_with_bad_processor') + ServiceProvider.objects.create(entity_id='test_sp_with_bad_processor', active=True, local_metadata=sp_metadata_xml, _processor='this.does.not.exist') + sp = get_sp_config('test_sp_with_bad_processor', IDP.load()) with pytest.raises(Exception): _ = sp.processor @pytest.mark.django_db + @pytest.mark.parametrize('sp_metadata_xml', ['sp_with_no_processor_metadata'], indirect=True) def test_set_processor_defaults_to_base_processor(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_sp_with_no_processor', local_metadata=sp_metadata_xml, _attribute_mapping='{}') - sp = get_sp_config('test_sp_with_no_processor') + sp = get_sp_config('test_sp_with_no_processor', IDP.load()) assert isinstance(sp.processor, BaseProcessor) @pytest.mark.django_db + @pytest.mark.parametrize('sp_metadata_xml', ['sp_with_custom_processor_metadata'], indirect=True) def test_get_processor_loads_custom_processor(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_sp_with_custom_processor', local_metadata=sp_metadata_xml, _processor='tests.test_views.CustomProcessor') - sp = get_sp_config('test_sp_with_custom_processor') + sp = get_sp_config('test_sp_with_custom_processor', IDP.load()) assert isinstance(sp.processor, CustomProcessor) @@ -336,15 +344,16 @@ def test_get_authn_returns_correctly_when_no_req_info(self): def test_check_access_works(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_generic_sp', local_metadata=sp_metadata_xml) - sp = get_sp_config('test_generic_sp') + sp = get_sp_config('test_generic_sp', IDP.load()) processor = sp.processor check_access(processor, HttpRequest()) @pytest.mark.django_db + @pytest.mark.parametrize('sp_metadata_xml', ['sp_with_custom_processor_that_doesnt_allow_access_metadata'], indirect=True) def test_check_access_fails_when_it_should(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_sp_with_custom_processor_that_doesnt_allow_access', local_metadata=sp_metadata_xml, _processor='tests.test_views.CustomProcessorNoAccess') - sp = get_sp_config('test_sp_with_custom_processor_that_doesnt_allow_access') + sp = get_sp_config('test_sp_with_custom_processor_that_doesnt_allow_access', IDP.load()) processor = sp.processor with pytest.raises(PermissionDenied): check_access(processor, HttpRequest()) @@ -353,20 +362,22 @@ def test_check_access_fails_when_it_should(self, sp_metadata_xml): def test_build_authn_response(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_generic_sp', local_metadata=sp_metadata_xml) - sp = get_sp_config('test_generic_sp') + sp = get_sp_config('test_generic_sp', IDP.load()) + idp = IDP.load() user = User() authn = get_authn() resp_args = { "in_response_to": "SP_Initiated_Login", "destination": "https://sp.example.com/SAML2", } - assert isinstance(build_authn_response(user, authn, resp_args, sp), Response) + assert isinstance(build_authn_response(user, authn, resp_args, sp, idp), Response) @pytest.mark.django_db def test_build_authn_response_unsupported_nameidformat(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_generic_sp', local_metadata=sp_metadata_xml) - sp = get_sp_config('test_generic_sp') + sp = get_sp_config('test_generic_sp', IDP.load()) + idp = IDP.load() authn = get_authn() resp_args = { "in_response_to": "SP_Initiated_Login", @@ -375,7 +386,7 @@ def test_build_authn_response_unsupported_nameidformat(self, sp_metadata_xml): } with pytest.raises(ImproperlyConfigured): - build_authn_response(User(), authn, resp_args, sp) + build_authn_response(User(), authn, resp_args, sp, idp) @pytest.mark.django_db def test_create_html_response_with_post(self): @@ -404,7 +415,7 @@ def compile_data_for_render_response(self, sp_metadata_xml): ServiceProvider.objects.create(entity_id='test_generic_sp', local_metadata=sp_metadata_xml) mixin = IdPHandlerViewMixin() - _ = get_sp_config("test_generic_sp") + _ = get_sp_config("test_generic_sp", IDP.load()) user = User.objects.create() user.email = "test@gmail.com", @@ -442,7 +453,7 @@ def test_render_response_constructs_request_session_properly(self, sp_metadata_x "saml_data": html_response } - mixin.render_response(request, html_response, get_sp_config('test_generic_sp').processor) + mixin.render_response(request, html_response, get_sp_config('test_generic_sp', IDP.load()).processor) assert all(item in request.session.items() for item in expected_session.items()) @@ -454,7 +465,7 @@ def multifactor(self, user): return True # Bind enable_multifactor being true to mixin processor. - processor = get_sp_config('test_generic_sp').processor + processor = get_sp_config('test_generic_sp', IDP.load()).processor processor.enable_multifactor = multifactor.__get__(processor) response = mixin.render_response(request, html_response, processor) assert isinstance(response, HttpResponseRedirect) @@ -594,8 +605,10 @@ def test_slo_view_works_properly_redirect(self, sp_metadata_xml, logged_in_reque class TestMetadata: @pytest.mark.django_db - def test_metadata_works_correctly(self): - response = metadata(HttpRequest()) + def test_metadata_works_correctly(self, settings): + request = HttpRequest() + request.method = "GET" + response = MetadataView.as_view()(request) assert isinstance(response, HttpResponse) assert response.charset == 'utf8' assert response.status_code == 200 diff --git a/tests/testing_utilities.py b/tests/testing_utilities.py index ede867f..9bbb364 100644 --- a/tests/testing_utilities.py +++ b/tests/testing_utilities.py @@ -9,3 +9,6 @@ def __init__(self, text, status_code): return MockResponse('not found', 404) return MockResponse('ok', 200) + +def identity_queryset(queryset): + return queryset diff --git a/tests/xml/metadata/sp_with_bad_processor_metadata.xml b/tests/xml/metadata/sp_with_bad_processor_metadata.xml new file mode 100644 index 0000000..91f5a5d --- /dev/null +++ b/tests/xml/metadata/sp_with_bad_processor_metadata.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + +urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + + \ No newline at end of file diff --git a/tests/xml/metadata/sp_with_custom_processor_metadata.xml b/tests/xml/metadata/sp_with_custom_processor_metadata.xml new file mode 100644 index 0000000..ad27ff3 --- /dev/null +++ b/tests/xml/metadata/sp_with_custom_processor_metadata.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + +urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + + \ No newline at end of file diff --git a/tests/xml/metadata/sp_with_custom_processor_that_doesnt_allow_access_metadata.xml b/tests/xml/metadata/sp_with_custom_processor_that_doesnt_allow_access_metadata.xml new file mode 100644 index 0000000..f37ac91 --- /dev/null +++ b/tests/xml/metadata/sp_with_custom_processor_that_doesnt_allow_access_metadata.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + +urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + + \ No newline at end of file diff --git a/tests/xml/metadata/sp_with_no_processor_metadata.xml b/tests/xml/metadata/sp_with_no_processor_metadata.xml new file mode 100644 index 0000000..cfba989 --- /dev/null +++ b/tests/xml/metadata/sp_with_no_processor_metadata.xml @@ -0,0 +1,31 @@ + + + + + + + + + + + + + + + + + + + + + + + +urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + + \ No newline at end of file