diff --git a/djangosaml2idp/admin.py b/djangosaml2idp/admin.py index 1ffa5b1..b89d554 100644 --- a/djangosaml2idp/admin.py +++ b/djangosaml2idp/admin.py @@ -8,7 +8,7 @@ class ServiceProviderAdmin(admin.ModelAdmin): list_filter = ['active', '_sign_response', '_sign_assertion', '_signing_algorithm', '_digest_algorithm', '_encrypt_saml_responses'] list_display = ['__str__', 'active', 'description'] - readonly_fields = ('dt_created', 'dt_updated', 'resulting_config', 'metadata_expiration_dt') + readonly_fields = ('entity_id', 'dt_created', 'dt_updated', 'resulting_config', 'metadata_expiration_dt', 'cache_expiration_dt') form = ServiceProviderAdminForm fieldsets = ( @@ -16,7 +16,7 @@ class ServiceProviderAdmin(admin.ModelAdmin): 'fields': ('entity_id', 'pretty_name', 'description') }), ('Metadata', { - 'fields': ('metadata_expiration_dt', 'remote_metadata_url', 'local_metadata') + 'fields': ('metadata_expiration_dt', 'cache_expiration_dt', 'remote_metadata_url', 'local_metadata') }), ('Configuration', { 'fields': ('active', '_processor', '_attribute_mapping', '_nameid_field', '_sign_response', '_sign_assertion', '_signing_algorithm', '_digest_algorithm', '_encrypt_saml_responses'), diff --git a/djangosaml2idp/forms.py b/djangosaml2idp/forms.py index 37cbcb4..ea624c1 100644 --- a/djangosaml2idp/forms.py +++ b/djangosaml2idp/forms.py @@ -6,7 +6,6 @@ from .models import ServiceProvider from .processors import instantiate_processor, validate_processor_path -from .utils import validate_metadata boolean_form_select_choices = ((None, _('--------')), (True, _('Yes')), (False, _('No'))) @@ -40,27 +39,23 @@ def clean__processor(self): validate_processor_path(value) return value - def clean_local_metadata(self): - value = self.cleaned_data['local_metadata'] - validate_metadata(value) - return value - def clean(self): cleaned_data = super().clean() if not (cleaned_data.get('remote_metadata_url') or cleaned_data.get('local_metadata')): raise ValidationError('Either a remote metadata URL, or a local metadata xml needs to be provided.') + # Call the validation methods to catch ValidationErrors here, so they get displayed cleanly in the admin UI + self.instance.local_metadata = cleaned_data.get('local_metadata') + self.instance.remote_metadata_url = cleaned_data.get('remote_metadata_url') + _, updated_fields = self.instance.load_metadata(force_refresh=True) + + for key in updated_fields: + cleaned_data[key] = updated_fields[key] + if '_processor' in cleaned_data: processor_path = cleaned_data['_processor'] entity_id = cleaned_data['entity_id'] processor_cls = validate_processor_path(processor_path) instantiate_processor(processor_cls, entity_id) - - self.instance.local_metadata = cleaned_data.get('local_metadata') - # Call the validation methods to catch ValidationErrors here, so they get displayed cleanly in the admin UI - if cleaned_data.get('remote_metadata_url'): - self.instance.remote_metadata_url = cleaned_data.get('remote_metadata_url') - cleaned_data['local_metadata'] = self.instance.local_metadata - self.instance.refresh_metadata(force_refresh=True) diff --git a/djangosaml2idp/migrations/0003_auto_20200503_1134.py b/djangosaml2idp/migrations/0003_auto_20200503_1134.py new file mode 100644 index 0000000..2e23ffb --- /dev/null +++ b/djangosaml2idp/migrations/0003_auto_20200503_1134.py @@ -0,0 +1,23 @@ +# Generated by Django 3.0.5 on 2020-05-03 11:34 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('djangosaml2idp', '0002_persistent_id'), + ] + + operations = [ + migrations.AddField( + model_name='serviceprovider', + name='cache_expiration_dt', + field=models.DateTimeField(blank=True, null=True, verbose_name='Cache metadata until'), + ), + migrations.AlterField( + model_name='serviceprovider', + name='metadata_expiration_dt', + field=models.DateTimeField(blank=True, null=True, verbose_name='Metadata valid until'), + ), + ] diff --git a/djangosaml2idp/models.py b/djangosaml2idp/models.py index 75def1d..0dace0f 100644 --- a/djangosaml2idp/models.py +++ b/djangosaml2idp/models.py @@ -1,9 +1,10 @@ import datetime import json import logging -import uuid import os -from typing import Dict +import uuid +import xml.etree.ElementTree as ET +from typing import TYPE_CHECKING, Dict, Tuple import pytz from django.conf import settings @@ -15,10 +16,10 @@ from saml2 import xmldsig from .idp import IDP -from .utils import (extract_validuntil_from_metadata, fetch_metadata, +from .utils import (extract_cacheduration_from_metadata, + extract_validuntil_from_metadata, fetch_metadata, validate_metadata) -from typing import TYPE_CHECKING if TYPE_CHECKING: from .processors import BaseProcessor @@ -38,16 +39,17 @@ class ServiceProvider(models.Model): # Bookkeeping - dt_created = models.DateTimeField(verbose_name='Created at', auto_now_add=True) - dt_updated = models.DateTimeField(verbose_name='Updated at', auto_now=True, null=True, blank=True) + dt_created = models.DateTimeField(verbose_name='Created at', auto_now_add=True, help_text='UTC') + dt_updated = models.DateTimeField(verbose_name='Updated at', auto_now=True, null=True, blank=True, help_text='UTC') # Identification - entity_id = models.CharField(verbose_name='Entity ID', max_length=255, unique=True) + entity_id = models.CharField(verbose_name='Entity ID', max_length=255, blank=True, unique=True, help_text='Automatically extracted from the metadata') pretty_name = models.CharField(verbose_name='Pretty Name', blank=True, max_length=255, help_text='For display purposes, can be empty') description = models.TextField(verbose_name='Description', blank=True) # Metadata - metadata_expiration_dt = models.DateTimeField(verbose_name='Metadata valid until') + metadata_expiration_dt = models.DateTimeField(verbose_name='Metadata valid until', blank=True, null=True, help_text='UTC') + cache_expiration_dt = models.DateTimeField(verbose_name='Cache metadata until', blank=True, null=True, help_text='UTC') remote_metadata_url = models.CharField(verbose_name='Remote metadata URL', max_length=512, blank=True, help_text='If set, metadata will be fetched upon saving into the local metadata xml field, and automatically be refreshed after the expiration timestamp.') local_metadata = models.TextField(verbose_name='Local Metadata XML', blank=True, help_text='XML containing the metadata') @@ -63,65 +65,98 @@ def field_value_changed(self, field_name: str) -> bool: return current_value != getattr(self, '_loaded_db_values', {}).get(field_name, current_value) def _should_refresh(self) -> bool: - ''' Returns whether or not a refresh operation is necessary. - ''' + ''' Returns whether or not a refresh operation is necessary. ''' # - Data was not fetched ever before, so local_metadata is empty, or local_metadata has been changed from what it was in the db before if not self.local_metadata or self.field_value_changed('local_metadata'): return True # - The remote url has been updated if self.field_value_changed('remote_metadata_url'): return True + # - The cache duration is set ... + if self.cache_expiration_dt: + # and it has been expired + if now() > self.cache_expiration_dt: + return True + # it hasn't been expired yet + return False # - The expiration timestamp is not set, or it is expired if not self.metadata_expiration_dt or now() > self.metadata_expiration_dt: return True - + # Everything is still valid, no refresh necessary return False - def _refresh_from_remote(self) -> bool: + def _load_from_remote(self) -> Tuple[bool, dict]: + updated_fields = {} try: self.local_metadata = validate_metadata(fetch_metadata(self.remote_metadata_url)) - self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata).replace(tzinfo=None) + # Try to extract the entityID + self.entity_id = ET.fromstring(self.local_metadata).attrib['entityID'] + # Try to extract a valid expiration datetime + self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata) + self.cache_expiration_dt = extract_cacheduration_from_metadata(self.local_metadata) + + updated_fields = { + 'entity_id': self.entity_id, + 'metadata_expiration_dt': self.metadata_expiration_dt, + 'cache_expiration_dt': self.cache_expiration_dt, + 'local_metadata': self.local_metadata, + } + # Return True if it is now valid, False (+ log an error) otherwise - if now() > self.metadata_expiration_dt: + if self.metadata_expiration_dt and now() > self.metadata_expiration_dt: logger.error(f'Remote metadata for SP {self.entity_id} was refreshed, but contains an expired validity datetime.') - return False - return True + return False, updated_fields + + return True, updated_fields except Exception as e: logger.error(f'Metadata for SP {self.entity_id} could not be pulled from remote url {self.remote_metadata_url}.', extra={'exception': str(e)}) - return False + return False, {} - def _refresh_from_local(self) -> bool: + def _load_from_local(self) -> bool: try: + self.local_metadata = validate_metadata(self.local_metadata) + # Try to extract the entityID + self.entity_id = ET.fromstring(self.local_metadata).attrib['entityID'] # Try to extract a valid expiration datetime from the local metadata - self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata).replace(tzinfo=None) + self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata) + self.cache_expiration_dt = extract_cacheduration_from_metadata(self.local_metadata) + + updated_fields = { + 'entity_id': self.entity_id, + 'metadata_expiration_dt': self.metadata_expiration_dt, + 'cache_expiration_dt': self.cache_expiration_dt, + 'local_metadata': self.local_metadata, + } + # Return True if it is now valid, False (+ log an error) otherwise if now() > self.metadata_expiration_dt: logger.error(f'Local metadata for SP {self.entity_id} contains an expired validity datetime or none at all, no remote metadata found to refresh.') - return False - return True + return False, updated_fields + + return True, updated_fields except Exception as e: # Could not extract a valid expiry timestamp, return False (+ log an error) logger.error(f'Metadata expiration dt for SP {self.entity_id} could not be extracted from local metadata.', extra={'exception': str(e)}) - return False + return False, updated_fields - def refresh_metadata(self, force_refresh: bool = False) -> bool: + def load_metadata(self, force_refresh: bool = False) -> Tuple[bool, dict]: ''' If a remote metadata url is set, fetch new metadata if the locally cached one is expired. Returns True if new metadata was set. Sets metadata fields on instance, but does not save to db. If force_refresh = True, the metadata will be refreshed regardless of the currently cached version validity timestamp. ''' - if not self._should_refresh() and not force_refresh: - return False + if not (self._should_refresh() or force_refresh): + return False, {} - if not self.remote_metadata_url and not self.local_metadata: + if not (self.remote_metadata_url or self.local_metadata): logger.error(f'Local metadata for SP {self.entity_id} is not present, and no remote metadata found to refresh.') - return False + return False, {} if self.remote_metadata_url: - return self._refresh_from_remote() + return self._load_from_remote() if force_refresh or (not self.metadata_expiration_dt) or (now() > self.metadata_expiration_dt) or self.field_value_changed('local_metadata'): - return self._refresh_from_local() + return self._load_from_local() - raise Exception('Uncaught case of refresh_metadata') + raise Exception('Uncaught case of load_metadata') # Configuration active = models.BooleanField(verbose_name='Active', default=True) @@ -184,7 +219,7 @@ def metadata_path(self) -> str: Return the location of that file. """ # On access, update the metadata if necessary - refreshed_metadata = self.refresh_metadata() + refreshed_metadata, _ = self.load_metadata() if refreshed_metadata: self.save() diff --git a/djangosaml2idp/utils.py b/djangosaml2idp/utils.py index 1f634ec..1e7d8ba 100644 --- a/djangosaml2idp/utils.py +++ b/djangosaml2idp/utils.py @@ -1,14 +1,17 @@ import base64 import datetime import xml.dom.minidom -from saml2.response import StatusResponse import xml.etree.ElementTree as ET import zlib from xml.parsers.expat import ExpatError -from django.utils.translation import gettext as _ + import arrow +import isodate +import pytz import requests from django.core.exceptions import ValidationError +from django.utils.translation import gettext as _ +from saml2.response import StatusResponse def repr_saml(saml: str, b64: bool = False): @@ -61,11 +64,29 @@ def validate_metadata(metadata: str) -> str: def extract_validuntil_from_metadata(metadata: str) -> datetime.datetime: - ''' Extract the ValidUntil timestamp from the given metadata. Returns that timestamp if successfully, raise a ValidationError otherwise. - ''' - try: - metadata_expiration_dt = arrow.get(ET.fromstring(metadata).attrib['validUntil']).datetime - except Exception as e: - raise ValidationError(f'Could not extra ValidUntil timestamp from metadata: {e}') + ''' Extract the expiration timestamp from the given metadata. Returns a timestamp if successfully, raise a ValidationError otherwise. ''' + + metadata_el = ET.fromstring(metadata) + metadata_expiration_dt = None + if 'validUntil' in metadata_el.attrib: + try: + metadata_expiration_dt = arrow.get(metadata_el.attrib['validUntil']).datetime.replace(tzinfo=pytz.utc) + except Exception as e: + raise ValidationError(f'Error extracting ValidUntil timestamp from metadata: {e}') return metadata_expiration_dt + + +def extract_cacheduration_from_metadata(metadata: str) -> datetime.datetime: + ''' Extract the cache duration expiration timestamp from the given metadata. Returns a timestamp if successfully, raise a ValidationError otherwise. ''' + + metadata_el = ET.fromstring(metadata) + + cache_expiration_dt = None + if 'cacheDuration' in metadata_el.attrib: + try: + time_delta = isodate.parse_duration(metadata_el.attrib['cacheDuration']) + cache_expiration_dt = (arrow.get() + time_delta).datetime.replace(tzinfo=pytz.utc) + except Exception as e: + raise ValidationError(f'Error extracting cacheDuration from metadata: {e}') + return cache_expiration_dt diff --git a/requirements-dev.in b/requirements-dev.in index 2d893ac..7bed088 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -6,3 +6,4 @@ tox pre-commit pytest-cov pytest-django +isodate diff --git a/requirements-dev.txt b/requirements-dev.txt index de08152..4e849a6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -20,6 +20,7 @@ filelock==3.0.12 # via tox, virtualenv identify==1.4.14 # via pre-commit idna==2.9 # via requests importlib-metadata==1.6.0 # via pluggy, pre-commit, pytest, tox, virtualenv +isodate==0.6.0 mccabe==0.6.1 # via pylama more-itertools==8.2.0 # via pytest nodeenv==1.3.5 # via pre-commit @@ -43,7 +44,7 @@ python-dateutil==2.8.1 # via pysaml2 pytz==2019.3 # via pysaml2 pyyaml==5.3.1 # via pre-commit requests==2.23.0 # via codecov, pysaml2 -six==1.14.0 # via cryptography, packaging, pip-tools, pyopenssl, pysaml2, python-dateutil, tox, virtualenv +six==1.14.0 # via cryptography, isodate, packaging, pip-tools, pyopenssl, pysaml2, python-dateutil, tox, virtualenv snowballstemmer==2.0.0 # via pydocstyle toml==0.10.0 # via pre-commit, tox tox==3.14.6 diff --git a/setup.py b/setup.py index 0f70b4a..966cee7 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ 'pysaml2>=5.0.0', 'pytz', 'arrow', + 'isodate', ], extras_require={ "testing": [ diff --git a/tests/test_models.py b/tests/test_models.py index 98c5f06..97768e3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -91,7 +91,7 @@ def test_refresh_meta_data_returns_false_on_model_state(self): local_metadata=timezone.now(), metadata_expiration_dt=timezone.now() + timedelta(hours=1), ) - assert instance.refresh_metadata() is False + assert instance.load_metadata()[0] is False @pytest.mark.django_db def test_should_refresh_on_changed_local_metadata(self, sp_metadata_xml): @@ -135,9 +135,9 @@ def test_refresh_meta_data_succesful_returns_true_on_model_state(self, instance) if instance.remote_metadata_url: with requests_mock.mock() as m: m.get(instance.remote_metadata_url, text=VALID_XML) - refreshed = instance.refresh_metadata() + refreshed, _ = instance.load_metadata() else: - refreshed = instance.refresh_metadata() + refreshed, _ = instance.load_metadata() assert refreshed @@ -178,13 +178,13 @@ def test_refresh_meta_data_failure_returns_false_on_model_state(self, instance): if instance.remote_metadata_url == "http://expired_remote": with requests_mock.mock() as m: m.get(instance.remote_metadata_url, text=EXPIRED_XML) - refreshed = instance.refresh_metadata() + refreshed, _ = instance.load_metadata() if instance.remote_metadata_url == "http://not_found": with requests_mock.mock() as m: m.get(instance.remote_metadata_url, text='Notfound') - refreshed = instance.refresh_metadata() + refreshed, _ = instance.load_metadata() else: - refreshed = instance.refresh_metadata() + refreshed, _ = instance.load_metadata() assert not refreshed @@ -197,19 +197,19 @@ def test_refresh_meta_data_returns_true_on_force_refresh(self): with requests_mock.mock() as m: m.get(sp.remote_metadata_url, text=VALID_XML) - refreshed = sp.refresh_metadata(True) + refreshed, _ = sp.load_metadata(True) assert refreshed assert sp.local_metadata == VALID_XML - def test_refresh_metadata_updates_metadata_expiration_dt_from_remote(self): + def test_load_metadata_updates_metadata_expiration_dt_from_remote(self): sp = ServiceProvider( metadata_expiration_dt=timezone.now(), remote_metadata_url="http://someremote", ) with requests_mock.mock() as m: m.get(sp.remote_metadata_url, text=VALID_XML) - refreshed = sp.refresh_metadata() + refreshed, _ = sp.load_metadata() assert refreshed assert sp.local_metadata == VALID_XML