From 5864ede7cb3eea95ed88a68b5fc551dc367b40e9 Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Mon, 18 May 2026 20:23:57 -0400 Subject: [PATCH 1/8] Support arbitrary, temporary Entitlement grants Fixes #666 --- squarelet/organizations/admin.py | 35 +++ .../0062_alter_plan_slug_entitlementgrant.py | 110 +++++++ .../organizations/models/organization.py | 9 +- squarelet/organizations/models/payment.py | 81 ++++++ squarelet/organizations/querysets.py | 37 +++ squarelet/organizations/serializers.py | 31 +- squarelet/organizations/tests/factories.py | 20 ++ .../tests/models/test_entitlement_grant.py | 269 ++++++++++++++++++ .../tests/models/test_organization.py | 2 +- .../organizations/tests/test_serializers.py | 72 +++++ 10 files changed, 654 insertions(+), 12 deletions(-) create mode 100644 squarelet/organizations/migrations/0062_alter_plan_slug_entitlementgrant.py create mode 100644 squarelet/organizations/tests/models/test_entitlement_grant.py create mode 100644 squarelet/organizations/tests/test_serializers.py diff --git a/squarelet/organizations/admin.py b/squarelet/organizations/admin.py index 4a1b87343..6aaf72539 100644 --- a/squarelet/organizations/admin.py +++ b/squarelet/organizations/admin.py @@ -24,6 +24,7 @@ Charge, Customer, Entitlement, + EntitlementGrant, Invitation, Invoice, Membership, @@ -516,6 +517,40 @@ class EntitlementAdmin(VersionAdmin): autocomplete_fields = ("client",) +@admin.register(EntitlementGrant) +class EntitlementGrantAdmin(VersionAdmin): + list_display = ( + "name", + "active", + "for_individuals", + "for_groups", + "require_verified", + "require_active_subscription", + ) + list_filter = ( + "active", + "for_individuals", + "for_groups", + "require_verified", + "require_active_subscription", + ) + search_fields = ("name", "description") + autocomplete_fields = ("entitlements", "organizations") + filter_horizontal = ("entitlements", "organizations") + fieldsets = ( + (None, {"fields": ("name", "description", "active", "entitlements")}), + ( + "Eligible organization types", + {"fields": ("for_individuals", "for_groups")}, + ), + ("Explicit grants", {"fields": ("organizations",)}), + ( + "Rule-based grants", + {"fields": ("require_verified", "require_active_subscription")}, + ), + ) + + def make_metadata_filter(field): """Make a dynamic filter class""" diff --git a/squarelet/organizations/migrations/0062_alter_plan_slug_entitlementgrant.py b/squarelet/organizations/migrations/0062_alter_plan_slug_entitlementgrant.py new file mode 100644 index 000000000..38414b58d --- /dev/null +++ b/squarelet/organizations/migrations/0062_alter_plan_slug_entitlementgrant.py @@ -0,0 +1,110 @@ +# Generated by Django 5.2.12 on 2026-05-19 00:18 + +import autoslug.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("organizations", "0061_merge_20260324"), + ] + + operations = [ + migrations.AlterField( + model_name="plan", + name="slug", + field=autoslug.fields.AutoSlugField( + editable=True, + help_text="A unique slug to identify the plan", + populate_from="name", + unique=True, + verbose_name="slug", + ), + ), + migrations.CreateModel( + name="EntitlementGrant", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255, verbose_name="name")), + ( + "description", + models.TextField( + blank=True, default="", verbose_name="description" + ), + ), + ( + "require_verified", + models.BooleanField( + default=False, + help_text="Match organizations whose verified_journalist=True", + verbose_name="require verified", + ), + ), + ( + "require_active_subscription", + models.BooleanField( + default=False, + help_text="Match organizations with at least one active subscription", + verbose_name="require active subscription", + ), + ), + ( + "for_individuals", + models.BooleanField( + default=True, + help_text="Apply this grant to individual organizations", + verbose_name="for individuals", + ), + ), + ( + "for_groups", + models.BooleanField( + default=True, + help_text="Apply this grant to non-individual organizations", + verbose_name="for groups", + ), + ), + ( + "active", + models.BooleanField( + default=True, + help_text="Inactive grants do not apply to any organization", + verbose_name="active", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "entitlements", + models.ManyToManyField( + help_text="Entitlements this grant extends", + related_name="grants", + to="organizations.entitlement", + verbose_name="entitlements", + ), + ), + ( + "organizations", + models.ManyToManyField( + blank=True, + help_text="Organizations explicitly granted these entitlements", + related_name="entitlement_grants", + to="organizations.organization", + verbose_name="organizations", + ), + ), + ], + options={ + "ordering": ("-created_at", "name"), + }, + ), + ] diff --git a/squarelet/organizations/models/organization.py b/squarelet/organizations/models/organization.py index 06e179e3c..cb9499117 100644 --- a/squarelet/organizations/models/organization.py +++ b/squarelet/organizations/models/organization.py @@ -847,7 +847,14 @@ def merge(self, org, user): if self.parent is None: self.parent = org.parent - m2m_relations = ["private_plans", "children", "groups", "members", "subtypes"] + m2m_relations = [ + "private_plans", + "children", + "groups", + "members", + "subtypes", + "entitlement_grants", + ] for m2m in m2m_relations: getattr(self, m2m).add(*getattr(org, m2m).all()) getattr(org, m2m).clear() diff --git a/squarelet/organizations/models/payment.py b/squarelet/organizations/models/payment.py index fcf4c4cd2..bebe8cb3c 100644 --- a/squarelet/organizations/models/payment.py +++ b/squarelet/organizations/models/payment.py @@ -19,6 +19,7 @@ from squarelet.organizations.payments.factory import get_payment_provider from squarelet.organizations.querysets import ( ChargeQuerySet, + EntitlementGrantQuerySet, EntitlementQuerySet, PlanQuerySet, SubscriptionQuerySet, @@ -762,6 +763,86 @@ def public(self): return self.plans.filter(public=True).exists() +class EntitlementGrant(models.Model): + """Grants Entitlements to organizations, explicitly or by rule.""" + + name = models.CharField(_("name"), max_length=255) + description = models.TextField(_("description"), blank=True, default="") + + entitlements = models.ManyToManyField( + verbose_name=_("entitlements"), + to="organizations.Entitlement", + related_name="grants", + help_text=_("Entitlements this grant extends"), + ) + organizations = models.ManyToManyField( + verbose_name=_("organizations"), + to="organizations.Organization", + related_name="entitlement_grants", + blank=True, + help_text=_("Organizations explicitly granted these entitlements"), + ) + + require_verified = models.BooleanField( + _("require verified"), + default=False, + help_text=_("Match organizations whose verified_journalist=True"), + ) + require_active_subscription = models.BooleanField( + _("require active subscription"), + default=False, + help_text=_("Match organizations with at least one active subscription"), + ) + + for_individuals = models.BooleanField( + _("for individuals"), + default=True, + help_text=_("Apply this grant to individual organizations"), + ) + for_groups = models.BooleanField( + _("for groups"), + default=True, + help_text=_("Apply this grant to non-individual organizations"), + ) + + active = models.BooleanField( + _("active"), + default=True, + help_text=_("Inactive grants do not apply to any organization"), + ) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + objects = EntitlementGrantQuerySet.as_manager() + + class Meta: + ordering = ("-created_at", "name") + + def __str__(self): + return self.name + + def matches(self, org): + if not self.active: + return False + # Org-type filter applies to both explicit and rule-based matches. + if org.individual and not self.for_individuals: + return False + if not org.individual and not self.for_groups: + return False + # Uses `.all()` so a prefetched `organizations` relation is reused. + if any(o.pk == org.pk for o in self.organizations.all()): + return True + checks = [] + if self.require_verified: + checks.append(bool(org.verified_journalist)) + if self.require_active_subscription: + checks.append(org.has_active_subscription()) + if not checks: + return False + return all(checks) + + class ReceiptEmail(models.Model): """An email address to send receipts to""" diff --git a/squarelet/organizations/querysets.py b/squarelet/organizations/querysets.py index e6306d468..caf0a7cb7 100644 --- a/squarelet/organizations/querysets.py +++ b/squarelet/organizations/querysets.py @@ -164,6 +164,43 @@ def get_owned(self, user): else: return self.none() + def for_organization(self, org, client=None): + """Return the deduped union of plan-derived and grant-derived entitlements + currently available to `org`. Optionally scoped to a single OIDC client.""" + # Lazy import to avoid a circular import (payment.py imports this module) + # pylint: disable=import-outside-toplevel + from squarelet.organizations.models.payment import EntitlementGrant + + plan_q = Q(plans__organizations=org) + grant_pks = [g.pk for g in EntitlementGrant.objects.matching(org)] + grant_q = Q(grants__in=grant_pks) + + qs = self.filter(plan_q | grant_q) + if client is not None: + qs = qs.filter(client=client) + return qs.distinct() + + +class EntitlementGrantQuerySet(models.QuerySet): + def active(self): + return self.filter(active=True) + + def matching(self, org): + """Return active grants applying to `org`, as a list. + + Prefetches `organizations` and `entitlements` so callers can iterate + related rows without further queries. + """ + org_type_filter = ( + Q(for_individuals=True) if org.individual else Q(for_groups=True) + ) + candidates = ( + self.active() + .filter(org_type_filter) + .prefetch_related("organizations", "entitlements") + ) + return [g for g in candidates if g.matches(org)] + class InvitationQuerySet(models.QuerySet): def get_open(self): diff --git a/squarelet/organizations/serializers.py b/squarelet/organizations/serializers.py index fdaff81eb..972ca9d05 100644 --- a/squarelet/organizations/serializers.py +++ b/squarelet/organizations/serializers.py @@ -1,6 +1,3 @@ -# Django -from django.db.models.expressions import F - # Third Party import stripe from rest_framework import serializers, status @@ -8,7 +5,12 @@ # Squarelet from squarelet.core.utils import format_stripe_error -from squarelet.organizations.models import Charge, Membership, Organization +from squarelet.organizations.models import ( + Charge, + Entitlement, + Membership, + Organization, +) class OrganizationSerializer(serializers.ModelSerializer): @@ -85,19 +87,28 @@ def get_update_on(self, _obj): return None def get_entitlements(self, obj): + # Get the client first client = self.context.get("client") if not client: request = self.context.get("request") if request and hasattr(request, "auth") and request.auth: client = request.auth.client + # If we can't find a client, then no entitlements. + if not client: + return [] - if client: - return list( - client.entitlements.filter(plans__organizations=obj) - .annotate(update_on=F("plans__subscriptions__update_on")) - .values("name", "slug", "description", "resources", "update_on") + # With a client, get entitlements for the org + entitlements = list( + Entitlement.objects.for_organization(obj, client=client).values( + "name", "slug", "description", "resources" ) - return [] + ) + + sub = obj.subscription + sub_update_on = sub.update_on if sub else None + for row in entitlements: + row["update_on"] = sub_update_on + return entitlements def get_card(self, obj): # this can be slow - goes to stripe for customer/card info - cache this diff --git a/squarelet/organizations/tests/factories.py b/squarelet/organizations/tests/factories.py index 428b76490..d2c195387 100644 --- a/squarelet/organizations/tests/factories.py +++ b/squarelet/organizations/tests/factories.py @@ -196,6 +196,26 @@ class Meta: model = "organizations.Entitlement" +class EntitlementGrantFactory(factory.django.DjangoModelFactory): + name = factory.Sequence(lambda n: f"Grant {n}") + description = factory.Sequence(lambda n: f"Grant description {n}") + + class Meta: + model = "organizations.EntitlementGrant" + + @factory.post_generation + def entitlements(self, create, extracted, **kwargs): + if not create or not extracted: + return + self.entitlements.set(extracted) + + @factory.post_generation + def organizations(self, create, extracted, **kwargs): + if not create or not extracted: + return + self.organizations.set(extracted) + + class EmailDomainFactory(factory.django.DjangoModelFactory): organization = factory.SubFactory( "squarelet.organizations.tests.factories.OrganizationFactory" diff --git a/squarelet/organizations/tests/models/test_entitlement_grant.py b/squarelet/organizations/tests/models/test_entitlement_grant.py new file mode 100644 index 000000000..ed70e85dc --- /dev/null +++ b/squarelet/organizations/tests/models/test_entitlement_grant.py @@ -0,0 +1,269 @@ +# Django +from django.db.models import QuerySet + +# Third Party +import pytest + +# Squarelet +from squarelet.oidc.tests.factories import ClientFactory +from squarelet.organizations.models import Entitlement +from squarelet.organizations.tests.factories import ( + EntitlementFactory, + EntitlementGrantFactory, + OrganizationFactory, + PlanFactory, + SubscriptionFactory, +) + + +class TestEntitlementGrant: + """Unit tests for EntitlementGrant.matches""" + + @pytest.mark.django_db() + def test_grant_matches_explicit_org(self): + """ + Entitlements may be granted to individual orgs. + """ + org = OrganizationFactory() + other_org = OrganizationFactory() + grant = EntitlementGrantFactory(organizations=[org]) + assert grant.matches(org) is True + assert grant.matches(other_org) is False + + @pytest.mark.django_db() + def test_grant_no_criteria_no_explicit_does_not_match(self): + """ + We need to be precise with our entitlement grants. + We don't do broad entitlement grants—orgs _must_ meet the criteria. + """ + org = OrganizationFactory() + grant = EntitlementGrantFactory() + assert grant.matches(org) is False + + @pytest.mark.django_db() + def test_grant_matches_verified_org(self): + """ + We can give entitlements to all verified organizations. + """ + verified = OrganizationFactory(verified_journalist=True) + unverified = OrganizationFactory(verified_journalist=False) + grant = EntitlementGrantFactory(require_verified=True) + assert grant.matches(verified) is True + assert grant.matches(unverified) is False + + @pytest.mark.django_db() + def test_grant_matches_active_subscription_org(self): + """ + We can grant entitlements to our active subscribers. + """ + subscribed = OrganizationFactory() + SubscriptionFactory(organization=subscribed) + unsubscribed = OrganizationFactory() + grant = EntitlementGrantFactory(require_active_subscription=True) + assert grant.matches(subscribed) is True + assert grant.matches(unsubscribed) is False + + @pytest.mark.django_db() + def test_grant_requires_all_checked_criteria(self): + """ + Grant rules are combined with "AND" logic. + """ + verified_and_sub = OrganizationFactory(verified_journalist=True) + SubscriptionFactory(organization=verified_and_sub) + verified_only = OrganizationFactory(verified_journalist=True) + sub_only = OrganizationFactory(verified_journalist=False) + SubscriptionFactory(organization=sub_only) + + grant = EntitlementGrantFactory( + require_verified=True, require_active_subscription=True + ) + assert grant.matches(verified_and_sub) is True + assert grant.matches(verified_only) is False + assert grant.matches(sub_only) is False + + @pytest.mark.django_db() + def test_grant_explicit_overrides_criteria(self): + """ + Granting an entitlement to an org ignores our rules. + """ + unverified = OrganizationFactory(verified_journalist=False) + grant = EntitlementGrantFactory( + organizations=[unverified], require_verified=True + ) + assert grant.matches(unverified) is True + + @pytest.mark.django_db() + def test_inactive_grant_does_not_match(self): + """ + When a grant is inactive, nobody gets it. + """ + org = OrganizationFactory(verified_journalist=True) + grant = EntitlementGrantFactory( + organizations=[org], require_verified=True, active=False + ) + assert grant.matches(org) is False + + @pytest.mark.django_db() + def test_grant_for_individuals_only(self): + """ + Entitlements may just be granted to individuals. + """ + individual = OrganizationFactory(individual=True, verified_journalist=True) + group = OrganizationFactory(individual=False, verified_journalist=True) + grant = EntitlementGrantFactory( + require_verified=True, for_individuals=True, for_groups=False + ) + assert grant.matches(individual) is True + assert grant.matches(group) is False + + @pytest.mark.django_db() + def test_grant_for_groups_only(self): + """ + Entitlements may also be granted to groups. + """ + individual = OrganizationFactory(individual=True, verified_journalist=True) + group = OrganizationFactory(individual=False, verified_journalist=True) + grant = EntitlementGrantFactory( + require_verified=True, for_individuals=False, for_groups=True + ) + assert grant.matches(individual) is False + assert grant.matches(group) is True + + @pytest.mark.django_db() + def test_grant_for_both_by_default(self): + """ + Entitlements may be granted to individuals _and_ groups. + """ + individual = OrganizationFactory(individual=True, verified_journalist=True) + group = OrganizationFactory(individual=False, verified_journalist=True) + grant = EntitlementGrantFactory(require_verified=True) + assert grant.matches(individual) is True + assert grant.matches(group) is True + + @pytest.mark.django_db() + def test_org_type_filter_applies_to_explicit_grants(self): + """An individual explicitly listed on a groups-only grant is still excluded.""" + individual = OrganizationFactory(individual=True) + grant = EntitlementGrantFactory( + organizations=[individual], for_individuals=False, for_groups=True + ) + assert grant.matches(individual) is False + + +class TestEntitlementForOrganization: + """Tests for Entitlement.objects.for_organization manager method""" + + @pytest.mark.django_db() + def test_includes_plan_entitlements(self): + plan = PlanFactory() + entitlement = EntitlementFactory() + entitlement.plans.set([plan]) + org = OrganizationFactory(plans=[plan]) + assert entitlement in Entitlement.objects.for_organization(org) + + @pytest.mark.django_db() + def test_includes_explicit_grant_entitlements(self): + org = OrganizationFactory() + entitlement = EntitlementFactory() + EntitlementGrantFactory(organizations=[org], entitlements=[entitlement]) + assert entitlement in Entitlement.objects.for_organization(org) + + @pytest.mark.django_db() + def test_dedupes_plan_and_grant(self): + plan = PlanFactory() + entitlement = EntitlementFactory() + entitlement.plans.set([plan]) + org = OrganizationFactory(plans=[plan]) + EntitlementGrantFactory(organizations=[org], entitlements=[entitlement]) + qs = Entitlement.objects.for_organization(org) + assert qs.filter(pk=entitlement.pk).count() == 1 + + @pytest.mark.django_db() + def test_filters_by_client(self): + client1 = ClientFactory() + client2 = ClientFactory() + e1 = EntitlementFactory(client=client1) + e2 = EntitlementFactory(client=client2) + org = OrganizationFactory() + EntitlementGrantFactory(organizations=[org], entitlements=[e1, e2]) + + result = Entitlement.objects.for_organization(org, client=client1) + assert e1 in result + assert e2 not in result + + @pytest.mark.django_db() + def test_verified_rule_matches(self): + verified = OrganizationFactory(verified_journalist=True) + unverified = OrganizationFactory(verified_journalist=False) + entitlement = EntitlementFactory() + EntitlementGrantFactory(require_verified=True, entitlements=[entitlement]) + assert entitlement in Entitlement.objects.for_organization(verified) + assert entitlement not in Entitlement.objects.for_organization(unverified) + + @pytest.mark.django_db() + def test_active_subscription_rule_matches(self): + subscribed = OrganizationFactory() + SubscriptionFactory(organization=subscribed) + unsubscribed = OrganizationFactory() + entitlement = EntitlementFactory() + EntitlementGrantFactory( + require_active_subscription=True, entitlements=[entitlement] + ) + assert entitlement in Entitlement.objects.for_organization(subscribed) + assert entitlement not in Entitlement.objects.for_organization(unsubscribed) + + @pytest.mark.django_db() + def test_both_rules_require_both_at_db_level(self): + verified_and_sub = OrganizationFactory(verified_journalist=True) + SubscriptionFactory(organization=verified_and_sub) + verified_only = OrganizationFactory(verified_journalist=True) + entitlement = EntitlementFactory() + EntitlementGrantFactory( + require_verified=True, + require_active_subscription=True, + entitlements=[entitlement], + ) + assert entitlement in Entitlement.objects.for_organization(verified_and_sub) + assert entitlement not in Entitlement.objects.for_organization(verified_only) + + @pytest.mark.django_db() + def test_explicit_membership_bypasses_criteria_at_db_level(self): + unverified = OrganizationFactory(verified_journalist=False) + entitlement = EntitlementFactory() + EntitlementGrantFactory( + organizations=[unverified], + require_verified=True, + entitlements=[entitlement], + ) + assert entitlement in Entitlement.objects.for_organization(unverified) + + @pytest.mark.django_db() + def test_excludes_inactive_grants(self): + org = OrganizationFactory() + entitlement = EntitlementFactory() + EntitlementGrantFactory( + organizations=[org], entitlements=[entitlement], active=False + ) + assert entitlement not in Entitlement.objects.for_organization(org) + + @pytest.mark.django_db() + def test_excludes_grants_for_wrong_org_type_at_db_level(self): + individual = OrganizationFactory(individual=True, verified_journalist=True) + group = OrganizationFactory(individual=False, verified_journalist=True) + entitlement = EntitlementFactory() + EntitlementGrantFactory( + require_verified=True, + for_individuals=False, + for_groups=True, + entitlements=[entitlement], + ) + assert entitlement in Entitlement.objects.for_organization(group) + assert entitlement not in Entitlement.objects.for_organization(individual) + + @pytest.mark.django_db() + def test_returns_queryset_instance(self): + org = OrganizationFactory() + result = Entitlement.objects.for_organization(org) + assert isinstance(result, QuerySet) + # Chainable + assert not list(result.values_list("pk", flat=True)) diff --git a/squarelet/organizations/tests/models/test_organization.py b/squarelet/organizations/tests/models/test_organization.py index 0ce8963bc..a06d7fedb 100644 --- a/squarelet/organizations/tests/models/test_organization.py +++ b/squarelet/organizations/tests/models/test_organization.py @@ -599,7 +599,7 @@ def test_merge_fks(self): if f.is_relation and f.auto_created ] ) - == 17 + == 18 ) # Many to many relations defined on the Organization model assert ( diff --git a/squarelet/organizations/tests/test_serializers.py b/squarelet/organizations/tests/test_serializers.py new file mode 100644 index 000000000..ad7922b1d --- /dev/null +++ b/squarelet/organizations/tests/test_serializers.py @@ -0,0 +1,72 @@ +# Django +from django.utils import timezone + +# Standard Library +from datetime import timedelta + +# Third Party +import pytest + +# Squarelet +from squarelet.oidc.tests.factories import ClientFactory +from squarelet.organizations.serializers import OrganizationDetailSerializer +from squarelet.organizations.tests.factories import ( + EntitlementFactory, + EntitlementGrantFactory, + OrganizationFactory, + PlanFactory, + SubscriptionFactory, +) + + +class TestSerializerEntitlements: + """Integration tests for OrganizationDetailSerializer.get_entitlements""" + + @pytest.mark.django_db() + def test_serializer_returns_grant_entitlements(self): + client = ClientFactory() + entitlement = EntitlementFactory(client=client) + org = OrganizationFactory() + EntitlementGrantFactory(organizations=[org], entitlements=[entitlement]) + + serializer = OrganizationDetailSerializer(org, context={"client": client}) + rows = serializer.get_entitlements(org) + slugs = [e["slug"] for e in rows] + assert entitlement.slug in slugs + + @pytest.mark.django_db() + def test_serializer_grant_entitlement_update_on_is_none(self): + client = ClientFactory() + entitlement = EntitlementFactory(client=client) + org = OrganizationFactory() + EntitlementGrantFactory(organizations=[org], entitlements=[entitlement]) + + serializer = OrganizationDetailSerializer(org, context={"client": client}) + rows = [ + e for e in serializer.get_entitlements(org) if e["slug"] == entitlement.slug + ] + assert len(rows) == 1 + assert rows[0]["update_on"] is None + + @pytest.mark.django_db() + def test_serializer_update_on_uses_subscription(self): + client = ClientFactory() + entitlement = EntitlementFactory(client=client) + plan = PlanFactory() + entitlement.plans.set([plan]) + org = OrganizationFactory() + sub_update_on = timezone.now().date() + timedelta(days=7) + SubscriptionFactory(organization=org, plan=plan, update_on=sub_update_on) + + serializer = OrganizationDetailSerializer(org, context={"client": client}) + rows = [ + e for e in serializer.get_entitlements(org) if e["slug"] == entitlement.slug + ] + assert len(rows) == 1 + assert rows[0]["update_on"] == sub_update_on + + @pytest.mark.django_db() + def test_serializer_returns_empty_without_client(self): + org = OrganizationFactory() + serializer = OrganizationDetailSerializer(org, context={}) + assert not serializer.get_entitlements(org) From e87d75e36021a44943843a31c742c05488b4881e Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Tue, 19 May 2026 15:30:27 -0400 Subject: [PATCH 2/8] Replenishes entitlement grants on the same cadences as subscription entitlements (1 month cadence) --- squarelet/organizations/admin.py | 11 ++ .../0063_entitlementgrant_update_on.py | 37 +++++ squarelet/organizations/models/payment.py | 52 +++++++ squarelet/organizations/querysets.py | 6 + squarelet/organizations/serializers.py | 19 ++- squarelet/organizations/signals.py | 131 +++++++++++++++++- squarelet/organizations/tasks.py | 35 +++-- .../tests/models/test_entitlement_grant.py | 91 ++++++++++++ .../organizations/tests/test_serializers.py | 30 +++- squarelet/organizations/tests/test_signals.py | 120 +++++++++++++++- squarelet/organizations/tests/test_tasks.py | 61 +++++++- 11 files changed, 575 insertions(+), 18 deletions(-) create mode 100644 squarelet/organizations/migrations/0063_entitlementgrant_update_on.py diff --git a/squarelet/organizations/admin.py b/squarelet/organizations/admin.py index 6aaf72539..6f2e72fa4 100644 --- a/squarelet/organizations/admin.py +++ b/squarelet/organizations/admin.py @@ -526,6 +526,7 @@ class EntitlementGrantAdmin(VersionAdmin): "for_groups", "require_verified", "require_active_subscription", + "update_on", ) list_filter = ( "active", @@ -548,6 +549,16 @@ class EntitlementGrantAdmin(VersionAdmin): "Rule-based grants", {"fields": ("require_verified", "require_active_subscription")}, ), + ( + "Refresh", + { + "fields": ("update_on",), + "description": ( + "Leave blank to default to one month from creation. " + "Resources tied to this grant refresh on this date." + ), + }, + ), ) diff --git a/squarelet/organizations/migrations/0063_entitlementgrant_update_on.py b/squarelet/organizations/migrations/0063_entitlementgrant_update_on.py new file mode 100644 index 000000000..77748827d --- /dev/null +++ b/squarelet/organizations/migrations/0063_entitlementgrant_update_on.py @@ -0,0 +1,37 @@ +# Generated by Django 5.2.12 on 2026-05-19 17:51 + +from datetime import date + +from dateutil.relativedelta import relativedelta +from django.db import migrations, models + + +def backfill_update_on(apps, schema_editor): + """Populate update_on for any pre-existing grants so admins see consistent + values. NULL <= today is false in SQL, so the celery task wouldn't sweep + these without backfill, but admins inspecting the table would see blanks.""" + EntitlementGrant = apps.get_model("organizations", "EntitlementGrant") + EntitlementGrant.objects.filter(update_on__isnull=True).update( + update_on=date.today() + relativedelta(months=1) + ) + + +class Migration(migrations.Migration): + + dependencies = [ + ("organizations", "0062_alter_plan_slug_entitlementgrant"), + ] + + operations = [ + migrations.AddField( + model_name="entitlementgrant", + name="update_on", + field=models.DateField( + blank=True, + help_text="Date when this grant's resources next refresh", + null=True, + verbose_name="date update", + ), + ), + migrations.RunPython(backfill_update_on, migrations.RunPython.noop), + ] diff --git a/squarelet/organizations/models/payment.py b/squarelet/organizations/models/payment.py index bebe8cb3c..fc001ec30 100644 --- a/squarelet/organizations/models/payment.py +++ b/squarelet/organizations/models/payment.py @@ -1,7 +1,9 @@ # Django from django.conf import settings from django.db import models, transaction +from django.db.models import Q from django.urls import reverse +from django.utils import timezone from django.utils.translation import gettext_lazy as _ # Standard Library @@ -11,6 +13,7 @@ # Third Party import stripe from autoslug import AutoSlugField +from dateutil.relativedelta import relativedelta from memoize import mproperty # Squarelet @@ -811,6 +814,13 @@ class EntitlementGrant(models.Model): help_text=_("Inactive grants do not apply to any organization"), ) + update_on = models.DateField( + _("date update"), + null=True, + blank=True, + help_text=_("Date when this grant's resources next refresh"), + ) + created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -822,6 +832,11 @@ class Meta: def __str__(self): return self.name + def save(self, *args, **kwargs): + if self.update_on is None: + self.update_on = timezone.now().date() + relativedelta(months=1) + super().save(*args, **kwargs) + def matches(self, org): if not self.active: return False @@ -842,6 +857,43 @@ def matches(self, org): return False return all(checks) + def matching_organizations(self): + """Return queryset of organizations this grant currently matches. + + Reverse of `matches(org)`. Used by the celery refresh task and by signal + handlers to compute the set of orgs whose cache must be invalidated. + """ + # pylint: disable=import-outside-toplevel + from squarelet.organizations.models.organization import Organization + + if not self.active: + return Organization.objects.none() + + if self.for_individuals and self.for_groups: + eligible = Organization.objects.all() + elif self.for_individuals: + eligible = Organization.objects.filter(individual=True) + elif self.for_groups: + eligible = Organization.objects.filter(individual=False) + else: + return Organization.objects.none() + + explicit_q = Q(entitlement_grants=self) + + rule_clauses = [] + if self.require_verified: + rule_clauses.append(Q(verified_journalist=True)) + if self.require_active_subscription: + # Mirrors org.has_active_subscription() = bool(subscriptions.first()) + rule_clauses.append(Q(subscriptions__isnull=False)) + + if rule_clauses: + rule_q = rule_clauses[0] + for clause in rule_clauses[1:]: + rule_q &= clause + return eligible.filter(explicit_q | rule_q).distinct() + return eligible.filter(explicit_q).distinct() + class ReceiptEmail(models.Model): """An email address to send receipts to""" diff --git a/squarelet/organizations/querysets.py b/squarelet/organizations/querysets.py index caf0a7cb7..5ed3b8739 100644 --- a/squarelet/organizations/querysets.py +++ b/squarelet/organizations/querysets.py @@ -185,6 +185,12 @@ class EntitlementGrantQuerySet(models.QuerySet): def active(self): return self.filter(active=True) + def expired(self, on_date=None): + """Grants whose update_on has arrived or passed.""" + if on_date is None: + on_date = timezone.now().date() + return self.filter(update_on__lte=on_date) + def matching(self, org): """Return active grants applying to `org`, as a list. diff --git a/squarelet/organizations/serializers.py b/squarelet/organizations/serializers.py index 972ca9d05..b2a62273d 100644 --- a/squarelet/organizations/serializers.py +++ b/squarelet/organizations/serializers.py @@ -11,6 +11,7 @@ Membership, Organization, ) +from squarelet.organizations.models.payment import EntitlementGrant class OrganizationSerializer(serializers.ModelSerializer): @@ -100,14 +101,28 @@ def get_entitlements(self, obj): # With a client, get entitlements for the org entitlements = list( Entitlement.objects.for_organization(obj, client=client).values( - "name", "slug", "description", "resources" + "pk", "name", "slug", "description", "resources" ) ) sub = obj.subscription sub_update_on = sub.update_on if sub else None + + # For grant-derived entitlements (no subscription), report the soonest + # matching grant's update_on per entitlement. + grant_update_on = {} + if sub_update_on is None: + for grant in EntitlementGrant.objects.matching(obj): + if grant.update_on is None: + continue + for ent in grant.entitlements.all(): + existing = grant_update_on.get(ent.pk) + if existing is None or grant.update_on < existing: + grant_update_on[ent.pk] = grant.update_on + for row in entitlements: - row["update_on"] = sub_update_on + row["update_on"] = sub_update_on or grant_update_on.get(row["pk"]) + del row["pk"] return entitlements def get_card(self, obj): diff --git a/squarelet/organizations/signals.py b/squarelet/organizations/signals.py index dcec3e4d8..c4113c2e8 100644 --- a/squarelet/organizations/signals.py +++ b/squarelet/organizations/signals.py @@ -7,13 +7,14 @@ from actstream import registry # Squarelet +from squarelet.oidc.middleware import send_cache_invalidations from squarelet.organizations.models import ( Invitation, Organization, Plan, ProfileChangeRequest, ) -from squarelet.organizations.models.payment import Charge +from squarelet.organizations.models.payment import Charge, EntitlementGrant from squarelet.organizations.tasks import sync_wix_for_group_member # Register models with django-activity-stream @@ -160,3 +161,131 @@ def charge_created(sender, instance, created, **kwargs): if created and instance.organization.hidden: instance.organization.hidden = False instance.organization.save(update_fields=["hidden"]) + + +# --- EntitlementGrant cache invalidation ------------------------------------- +# +# Admin actions on grants (create, edit, toggle active, delete, M2M edits) +# change the set of orgs that match a grant. Each change broadcasts cache +# invalidations for the affected orgs so OIDC clients re-fetch entitlements +# immediately. The monthly `restore_organization` task handles the scheduled +# refresh cycle; these signals handle the interactive path. + + +def _invalidate_orgs(uuids): + """Defer a cache-invalidation broadcast for the given org UUIDs.""" + uuid_list = list({str(u) for u in uuids}) + if not uuid_list: + return + transaction.on_commit(lambda: send_cache_invalidations("organization", uuid_list)) + + +@receiver( + signals.pre_save, + sender=EntitlementGrant, + dispatch_uid="squarelet.organizations.signals.entitlementgrant_stash_pre_save", +) +def entitlementgrant_stash_pre_save(sender, instance, **kwargs): + """Stash the orgs this grant matched *before* the save. + + Needed because toggling active=False or flipping rules can shrink the + matching set — post_save alone wouldn't see who used to match. + """ + # pylint: disable=unused-argument,protected-access + if instance.pk is None: + instance._pre_save_match_uuids = [] + return + try: + old = EntitlementGrant.objects.get(pk=instance.pk) + except EntitlementGrant.DoesNotExist: + instance._pre_save_match_uuids = [] + return + instance._pre_save_match_uuids = list( + old.matching_organizations().values_list("uuid", flat=True) + ) + + +@receiver( + signals.post_save, + sender=EntitlementGrant, + dispatch_uid="squarelet.organizations.signals.entitlementgrant_invalidate_on_save", +) +def entitlementgrant_invalidate_on_save(sender, instance, **kwargs): + """Broadcast for the union of pre-save and post-save matches.""" + # pylint: disable=unused-argument + pre = getattr(instance, "_pre_save_match_uuids", []) or [] + post = list(instance.matching_organizations().values_list("uuid", flat=True)) + _invalidate_orgs(set(pre) | set(post)) + + +@receiver( + signals.pre_delete, + sender=EntitlementGrant, + dispatch_uid="squarelet.organizations.signals.entitlementgrant_stash_pre_delete", +) +def entitlementgrant_stash_pre_delete(sender, instance, **kwargs): + """Stash the orgs this grant matched before delete cascades the M2M.""" + # pylint: disable=unused-argument,protected-access + instance._pre_delete_match_uuids = list( + instance.matching_organizations().values_list("uuid", flat=True) + ) + + +@receiver( + signals.post_delete, + sender=EntitlementGrant, + dispatch_uid=( + "squarelet.organizations.signals.entitlementgrant_invalidate_on_delete" + ), +) +def entitlementgrant_invalidate_on_delete(sender, instance, **kwargs): + """Broadcast for orgs that matched immediately before the delete.""" + # pylint: disable=unused-argument + _invalidate_orgs(getattr(instance, "_pre_delete_match_uuids", []) or []) + + +@receiver( + signals.m2m_changed, + sender=EntitlementGrant.organizations.through, + dispatch_uid=( + "squarelet.organizations.signals.entitlementgrant_organizations_m2m_changed" + ), +) +def entitlementgrant_organizations_m2m_changed( + sender, instance, action, pk_set, reverse, **kwargs +): + """Broadcast when orgs are added to or removed from a grant's M2M.""" + # pylint: disable=unused-argument + if action not in {"post_add", "post_remove", "post_clear"}: + return + if reverse: + # Reverse: instance is an Organization that just gained/lost a grant. + _invalidate_orgs([instance.uuid]) + return + if not pk_set: + # v1: bare `clear()` is not handled — the admin UI uses add/remove. + return + uuids = list( + Organization.objects.filter(pk__in=pk_set).values_list("uuid", flat=True) + ) + _invalidate_orgs(uuids) + + +@receiver( + signals.m2m_changed, + sender=EntitlementGrant.entitlements.through, + dispatch_uid=( + "squarelet.organizations.signals.entitlementgrant_entitlements_m2m_changed" + ), +) +def entitlementgrant_entitlements_m2m_changed( + sender, instance, action, reverse, **kwargs +): + """Broadcast for currently-matching orgs when a grant's entitlements change.""" + # pylint: disable=unused-argument + if action not in {"post_add", "post_remove", "post_clear"}: + return + if reverse: + return # v1: skip reverse path + uuids = list(instance.matching_organizations().values_list("uuid", flat=True)) + _invalidate_orgs(uuids) diff --git a/squarelet/organizations/tasks.py b/squarelet/organizations/tasks.py index f1368e2b3..2b63c1e56 100644 --- a/squarelet/organizations/tasks.py +++ b/squarelet/organizations/tasks.py @@ -22,7 +22,12 @@ from squarelet.organizations import wix from squarelet.organizations.models.invoice import Invoice from squarelet.organizations.models.organization import Organization -from squarelet.organizations.models.payment import Charge, Plan, Subscription +from squarelet.organizations.models.payment import ( + Charge, + EntitlementGrant, + Plan, + Subscription, +) from squarelet.organizations.payments.factory import get_payment_provider from squarelet.users.models import User @@ -31,17 +36,29 @@ @shared_task def restore_organization(): - """Monthly update of organizations subscriptions""" - subscriptions = Subscription.objects.filter(update_on__lte=date.today()) - - # convert to a list so it can be serialized by celery - uuids = list(subscriptions.values_list("organization__uuid", flat=True)) + """Monthly refresh of subscriptions and entitlement grants""" + today = date.today() - # delete cancelled subscriptions first + # --- Subscriptions --- + subscriptions = Subscription.objects.filter(update_on__lte=today) + sub_uuids = list(subscriptions.values_list("organization__uuid", flat=True)) subscriptions.filter(cancelled=True).delete() - subscriptions.update(update_on=date.today() + Interval("1 month")) + subscriptions.update(update_on=today + Interval("1 month")) + + # --- Entitlement grants --- + # Snapshot the matching orgs for each active expired grant before bumping + # update_on, then bulk-bump in one query. Inactive expired grants get the + # bump too but contribute no UUIDs to the broadcast. + expired_grants = EntitlementGrant.objects.expired(today) + grant_uuids = set() + for grant in expired_grants.filter(active=True): + grant_uuids.update( + grant.matching_organizations().values_list("uuid", flat=True) + ) + expired_grants.update(update_on=today + Interval("1 month")) - send_cache_invalidations("organization", uuids) + all_uuids = list({*sub_uuids, *grant_uuids}) + send_cache_invalidations("organization", all_uuids) @shared_task( diff --git a/squarelet/organizations/tests/models/test_entitlement_grant.py b/squarelet/organizations/tests/models/test_entitlement_grant.py index ed70e85dc..b5376e29c 100644 --- a/squarelet/organizations/tests/models/test_entitlement_grant.py +++ b/squarelet/organizations/tests/models/test_entitlement_grant.py @@ -1,8 +1,13 @@ # Django from django.db.models import QuerySet +from django.utils import timezone + +# Standard Library +from datetime import timedelta # Third Party import pytest +from dateutil.relativedelta import relativedelta # Squarelet from squarelet.oidc.tests.factories import ClientFactory @@ -149,6 +154,92 @@ def test_org_type_filter_applies_to_explicit_grants(self): ) assert grant.matches(individual) is False + @pytest.mark.django_db() + def test_update_on_defaults_to_one_month_out(self): + """A grant saved without update_on defaults to one month from today.""" + grant = EntitlementGrantFactory() + expected = timezone.now().date() + relativedelta(months=1) + assert grant.update_on == expected + + @pytest.mark.django_db() + def test_update_on_admin_provided_value_is_preserved(self): + """An admin-provided update_on is left alone on save.""" + future = timezone.now().date() + timedelta(days=90) + grant = EntitlementGrantFactory(update_on=future) + assert grant.update_on == future + + +class TestMatchingOrganizations: + """Tests for EntitlementGrant.matching_organizations""" + + @pytest.mark.django_db() + def test_explicit_membership_only(self): + org = OrganizationFactory() + other = OrganizationFactory() + grant = EntitlementGrantFactory(organizations=[org]) + matched = list(grant.matching_organizations()) + assert org in matched + assert other not in matched + + @pytest.mark.django_db() + def test_verified_rule(self): + verified = OrganizationFactory(verified_journalist=True) + unverified = OrganizationFactory(verified_journalist=False) + grant = EntitlementGrantFactory(require_verified=True) + matched = list(grant.matching_organizations()) + assert verified in matched + assert unverified not in matched + + @pytest.mark.django_db() + def test_active_subscription_rule(self): + subscribed = OrganizationFactory() + SubscriptionFactory(organization=subscribed) + unsubscribed = OrganizationFactory() + grant = EntitlementGrantFactory(require_active_subscription=True) + matched = list(grant.matching_organizations()) + assert subscribed in matched + assert unsubscribed not in matched + + @pytest.mark.django_db() + def test_org_type_filter(self): + individual = OrganizationFactory(individual=True, verified_journalist=True) + group = OrganizationFactory(individual=False, verified_journalist=True) + grant = EntitlementGrantFactory( + require_verified=True, for_individuals=False, for_groups=True + ) + matched = list(grant.matching_organizations()) + assert group in matched + assert individual not in matched + + @pytest.mark.django_db() + def test_inactive_grant_matches_nobody(self): + org = OrganizationFactory(verified_journalist=True) + grant = EntitlementGrantFactory( + organizations=[org], require_verified=True, active=False + ) + assert not list(grant.matching_organizations()) + + @pytest.mark.django_db() + def test_explicit_membership_respects_org_type_filter(self): + """An individual explicitly listed on a groups-only grant should not match.""" + individual = OrganizationFactory(individual=True) + grant = EntitlementGrantFactory( + organizations=[individual], for_individuals=False, for_groups=True + ) + assert individual not in grant.matching_organizations() + + @pytest.mark.django_db() + def test_both_rules_and_logic(self): + verified_and_sub = OrganizationFactory(verified_journalist=True) + SubscriptionFactory(organization=verified_and_sub) + verified_only = OrganizationFactory(verified_journalist=True) + grant = EntitlementGrantFactory( + require_verified=True, require_active_subscription=True + ) + matched = list(grant.matching_organizations()) + assert verified_and_sub in matched + assert verified_only not in matched + class TestEntitlementForOrganization: """Tests for Entitlement.objects.for_organization manager method""" diff --git a/squarelet/organizations/tests/test_serializers.py b/squarelet/organizations/tests/test_serializers.py index ad7922b1d..24fed1ce6 100644 --- a/squarelet/organizations/tests/test_serializers.py +++ b/squarelet/organizations/tests/test_serializers.py @@ -35,18 +35,42 @@ def test_serializer_returns_grant_entitlements(self): assert entitlement.slug in slugs @pytest.mark.django_db() - def test_serializer_grant_entitlement_update_on_is_none(self): + def test_serializer_returns_grant_update_on_when_no_subscription(self): client = ClientFactory() entitlement = EntitlementFactory(client=client) org = OrganizationFactory() - EntitlementGrantFactory(organizations=[org], entitlements=[entitlement]) + update_on = timezone.now().date() + timedelta(days=14) + EntitlementGrantFactory( + organizations=[org], entitlements=[entitlement], update_on=update_on + ) + + serializer = OrganizationDetailSerializer(org, context={"client": client}) + rows = [ + e for e in serializer.get_entitlements(org) if e["slug"] == entitlement.slug + ] + assert len(rows) == 1 + assert rows[0]["update_on"] == update_on + + @pytest.mark.django_db() + def test_serializer_picks_soonest_grant_update_on(self): + client = ClientFactory() + entitlement = EntitlementFactory(client=client) + org = OrganizationFactory() + sooner = timezone.now().date() + timedelta(days=7) + later = timezone.now().date() + timedelta(days=30) + EntitlementGrantFactory( + organizations=[org], entitlements=[entitlement], update_on=later + ) + EntitlementGrantFactory( + organizations=[org], entitlements=[entitlement], update_on=sooner + ) serializer = OrganizationDetailSerializer(org, context={"client": client}) rows = [ e for e in serializer.get_entitlements(org) if e["slug"] == entitlement.slug ] assert len(rows) == 1 - assert rows[0]["update_on"] is None + assert rows[0]["update_on"] == sooner @pytest.mark.django_db() def test_serializer_update_on_uses_subscription(self): diff --git a/squarelet/organizations/tests/test_signals.py b/squarelet/organizations/tests/test_signals.py index 52cedbdd7..ed0efeddd 100644 --- a/squarelet/organizations/tests/test_signals.py +++ b/squarelet/organizations/tests/test_signals.py @@ -3,8 +3,12 @@ # Squarelet from squarelet.organizations import signals -from squarelet.organizations.models.payment import Charge -from squarelet.organizations.tests.factories import OrganizationFactory +from squarelet.organizations.models.payment import Charge, EntitlementGrant +from squarelet.organizations.tests.factories import ( + EntitlementFactory, + EntitlementGrantFactory, + OrganizationFactory, +) @pytest.mark.django_db @@ -33,3 +37,115 @@ def test_charge_unhides_group_org(): org.refresh_from_db() assert org.hidden is False + + +def _broadcast_uuids(mock_send): + """Collect all UUIDs broadcast across all calls to send_cache_invalidations.""" + uuids = set() + for call in mock_send.call_args_list: + assert call.args[0] == "organization" + uuids.update(str(u) for u in call.args[1]) + return uuids + + +class TestEntitlementGrantSignals: + """Tests for admin-driven cache invalidation on EntitlementGrant changes.""" + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_create_with_explicit_orgs( + self, mocker, django_capture_on_commit_callbacks + ): + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + org = OrganizationFactory() + with django_capture_on_commit_callbacks(execute=True): + EntitlementGrantFactory(organizations=[org]) + assert str(org.uuid) in _broadcast_uuids(mock_send) + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_active_toggle_off( + self, mocker, django_capture_on_commit_callbacks + ): + org = OrganizationFactory() + with django_capture_on_commit_callbacks(execute=True): + grant = EntitlementGrantFactory(organizations=[org]) + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + with django_capture_on_commit_callbacks(execute=True): + grant.active = False + grant.save() + # Toggling off must still broadcast the org that previously matched. + assert str(org.uuid) in _broadcast_uuids(mock_send) + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_active_toggle_on( + self, mocker, django_capture_on_commit_callbacks + ): + org = OrganizationFactory() + with django_capture_on_commit_callbacks(execute=True): + grant = EntitlementGrantFactory(organizations=[org], active=False) + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + with django_capture_on_commit_callbacks(execute=True): + grant.active = True + grant.save() + assert str(org.uuid) in _broadcast_uuids(mock_send) + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_org_added_to_m2m( + self, mocker, django_capture_on_commit_callbacks + ): + with django_capture_on_commit_callbacks(execute=True): + grant = EntitlementGrantFactory() + org = OrganizationFactory() + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + with django_capture_on_commit_callbacks(execute=True): + grant.organizations.add(org) + assert str(org.uuid) in _broadcast_uuids(mock_send) + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_org_removed_from_m2m( + self, mocker, django_capture_on_commit_callbacks + ): + org = OrganizationFactory() + with django_capture_on_commit_callbacks(execute=True): + grant = EntitlementGrantFactory(organizations=[org]) + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + with django_capture_on_commit_callbacks(execute=True): + grant.organizations.remove(org) + assert str(org.uuid) in _broadcast_uuids(mock_send) + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_entitlements_changed( + self, mocker, django_capture_on_commit_callbacks + ): + org = OrganizationFactory() + entitlement = EntitlementFactory() + with django_capture_on_commit_callbacks(execute=True): + grant = EntitlementGrantFactory(organizations=[org]) + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + with django_capture_on_commit_callbacks(execute=True): + grant.entitlements.add(entitlement) + assert str(org.uuid) in _broadcast_uuids(mock_send) + + @pytest.mark.django_db(transaction=True) + def test_invalidates_on_delete(self, mocker, django_capture_on_commit_callbacks): + org = OrganizationFactory() + with django_capture_on_commit_callbacks(execute=True): + grant = EntitlementGrantFactory(organizations=[org]) + mock_send = mocker.patch( + "squarelet.organizations.signals.send_cache_invalidations" + ) + with django_capture_on_commit_callbacks(execute=True): + grant.delete() + assert str(org.uuid) in _broadcast_uuids(mock_send) + assert not EntitlementGrant.objects.filter(pk=grant.pk).exists() diff --git a/squarelet/organizations/tests/test_tasks.py b/squarelet/organizations/tests/test_tasks.py index 6fec16f91..4fc72b77e 100644 --- a/squarelet/organizations/tests/test_tasks.py +++ b/squarelet/organizations/tests/test_tasks.py @@ -14,7 +14,12 @@ # Squarelet from squarelet.organizations import tasks from squarelet.organizations.models import Charge, Invoice, Subscription -from squarelet.organizations.tests.factories import InvoiceFactory, SubscriptionFactory +from squarelet.organizations.tests.factories import ( + EntitlementGrantFactory, + InvoiceFactory, + OrganizationFactory, + SubscriptionFactory, +) # pylint:disable=too-many-lines # TODO: Refactor this file and `tasks.py` into smaller files @@ -59,6 +64,60 @@ def test_restore_organization(organization_plan_factory, mocker): ) +class TestRestoreOrganizationGrants: + """Coverage for grant refresh in restore_organization.""" + + @pytest.mark.django_db() + def test_bumps_expired_active_grant_and_broadcasts(self, mocker): + mock_send = mocker.patch( + "squarelet.organizations.tasks.send_cache_invalidations" + ) + today = date.today() + org = OrganizationFactory(verified_journalist=True) + grant = EntitlementGrantFactory( + require_verified=True, update_on=today - timedelta(days=1) + ) + + tasks.restore_organization() + + grant.refresh_from_db() + assert grant.update_on == today + relativedelta(months=1) + assert mock_send.call_args[0][0] == "organization" + assert str(org.uuid) in {str(u) for u in mock_send.call_args[0][1]} + + @pytest.mark.django_db() + def test_future_grant_is_unchanged(self, mocker): + mocker.patch("squarelet.organizations.tasks.send_cache_invalidations") + today = date.today() + future = today + timedelta(days=14) + grant = EntitlementGrantFactory(require_verified=True, update_on=future) + + tasks.restore_organization() + + grant.refresh_from_db() + assert grant.update_on == future + + @pytest.mark.django_db() + def test_inactive_expired_grant_bumped_but_not_broadcast(self, mocker): + mock_send = mocker.patch( + "squarelet.organizations.tasks.send_cache_invalidations" + ) + today = date.today() + org = OrganizationFactory(verified_journalist=True) + grant = EntitlementGrantFactory( + organizations=[org], + active=False, + update_on=today - timedelta(days=1), + ) + + tasks.restore_organization() + + grant.refresh_from_db() + assert grant.update_on == today + relativedelta(months=1) + broadcast_uuids = {str(u) for u in mock_send.call_args[0][1]} + assert str(org.uuid) not in broadcast_uuids + + class TestHandleChargeSucceeded: """Unit tests for the handle_charge_succeeded task""" From 6fe2bad7a7b3030c6195d884a71115cbbf4ebaee Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Tue, 19 May 2026 15:52:49 -0400 Subject: [PATCH 3/8] Fix test --- squarelet/users/tests/test_oidc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/squarelet/users/tests/test_oidc.py b/squarelet/users/tests/test_oidc.py index fff64b09f..32caf11eb 100644 --- a/squarelet/users/tests/test_oidc.py +++ b/squarelet/users/tests/test_oidc.py @@ -39,7 +39,8 @@ def test_scope_organizations(user_factory, mocker): default_source=None, ) user = user_factory() - token = MagicMock(user=user) + # client=None so both serializer calls take the same (no-entitlements) path. + token = MagicMock(user=user, client=None) claims = oidc.CustomScopeClaims(token) info = claims.scope_organizations() assert info["organizations"] == [ From 4db80fac91054f20174f0612210f7a75e36875a2 Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Tue, 19 May 2026 16:03:42 -0400 Subject: [PATCH 4/8] sort imports --- squarelet/organizations/models/payment.py | 1 + squarelet/organizations/querysets.py | 1 + squarelet/organizations/serializers.py | 7 +------ 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/squarelet/organizations/models/payment.py b/squarelet/organizations/models/payment.py index fc001ec30..cef564e88 100644 --- a/squarelet/organizations/models/payment.py +++ b/squarelet/organizations/models/payment.py @@ -864,6 +864,7 @@ def matching_organizations(self): handlers to compute the set of orgs whose cache must be invalidated. """ # pylint: disable=import-outside-toplevel + # Squarelet from squarelet.organizations.models.organization import Organization if not self.active: diff --git a/squarelet/organizations/querysets.py b/squarelet/organizations/querysets.py index 5ed3b8739..e2006892c 100644 --- a/squarelet/organizations/querysets.py +++ b/squarelet/organizations/querysets.py @@ -169,6 +169,7 @@ def for_organization(self, org, client=None): currently available to `org`. Optionally scoped to a single OIDC client.""" # Lazy import to avoid a circular import (payment.py imports this module) # pylint: disable=import-outside-toplevel + # Squarelet from squarelet.organizations.models.payment import EntitlementGrant plan_q = Q(plans__organizations=org) diff --git a/squarelet/organizations/serializers.py b/squarelet/organizations/serializers.py index b2a62273d..67faf8f4c 100644 --- a/squarelet/organizations/serializers.py +++ b/squarelet/organizations/serializers.py @@ -5,12 +5,7 @@ # Squarelet from squarelet.core.utils import format_stripe_error -from squarelet.organizations.models import ( - Charge, - Entitlement, - Membership, - Organization, -) +from squarelet.organizations.models import Charge, Entitlement, Membership, Organization from squarelet.organizations.models.payment import EntitlementGrant From 9309511ab3f40fc9a315504e95776b4e8c091001 Mon Sep 17 00:00:00 2001 From: Mitchell Kotler Date: Thu, 21 May 2026 13:56:08 -0400 Subject: [PATCH 5/8] optimize sql --- squarelet/organizations/querysets.py | 40 ++++++++++++++++---------- squarelet/organizations/serializers.py | 16 ++++++++--- squarelet/organizations/tasks.py | 11 ++++--- 3 files changed, 44 insertions(+), 23 deletions(-) diff --git a/squarelet/organizations/querysets.py b/squarelet/organizations/querysets.py index e2006892c..590464ccb 100644 --- a/squarelet/organizations/querysets.py +++ b/squarelet/organizations/querysets.py @@ -172,11 +172,8 @@ def for_organization(self, org, client=None): # Squarelet from squarelet.organizations.models.payment import EntitlementGrant - plan_q = Q(plans__organizations=org) - grant_pks = [g.pk for g in EntitlementGrant.objects.matching(org)] - grant_q = Q(grants__in=grant_pks) - - qs = self.filter(plan_q | grant_q) + matching_grants = EntitlementGrant.objects.for_org(org) + qs = self.filter(Q(plans__organizations=org) | Q(grants__in=matching_grants)) if client is not None: qs = qs.filter(client=client) return qs.distinct() @@ -192,21 +189,34 @@ def expired(self, on_date=None): on_date = timezone.now().date() return self.filter(update_on__lte=on_date) - def matching(self, org): - """Return active grants applying to `org`, as a list. + def for_org(self, org): + """Active grants that apply to `org`. Pure SQL — no Python-level loop. - Prefetches `organizations` and `entitlements` so callers can iterate - related rows without further queries. + A grant matches when the org type is compatible AND either: + - the org is explicitly listed in `organizations`, OR + - the grant has at least one rule flag set and every active rule is + satisfied by this org's attributes. """ - org_type_filter = ( + org_type_q = ( Q(for_individuals=True) if org.individual else Q(for_groups=True) ) - candidates = ( - self.active() - .filter(org_type_filter) - .prefetch_related("organizations", "entitlements") + explicit_q = org_type_q & Q(organizations=org) + + at_least_one_rule = ( + Q(require_verified=True) | Q(require_active_subscription=True) + ) + # If the org fails a requirement, exclude grants that set that flag. + verified_ok = ( + Q() if org.verified_journalist else Q(require_verified=False) ) - return [g for g in candidates if g.matches(org)] + sub_ok = ( + Q() + if org.has_active_subscription() + else Q(require_active_subscription=False) + ) + rule_q = org_type_q & at_least_one_rule & verified_ok & sub_ok + + return self.active().filter(explicit_q | rule_q).distinct() class InvitationQuerySet(models.QuerySet): diff --git a/squarelet/organizations/serializers.py b/squarelet/organizations/serializers.py index 67faf8f4c..cbadc2ba3 100644 --- a/squarelet/organizations/serializers.py +++ b/squarelet/organizations/serializers.py @@ -1,5 +1,6 @@ # Third Party import stripe +from django.db.models import Q from rest_framework import serializers, status from rest_framework.exceptions import APIException @@ -93,11 +94,18 @@ def get_entitlements(self, obj): if not client: return [] - # With a client, get entitlements for the org + # Compute matching grants once; reused for both the entitlement query + # and the grant_update_on map below (avoids a second DB round-trip). + matching_grants = EntitlementGrant.objects.for_org(obj).prefetch_related( + "entitlements" + ) entitlements = list( - Entitlement.objects.for_organization(obj, client=client).values( - "pk", "name", "slug", "description", "resources" + Entitlement.objects.filter( + Q(plans__organizations=obj) | Q(grants__in=matching_grants), + client=client, ) + .distinct() + .values("pk", "name", "slug", "description", "resources") ) sub = obj.subscription @@ -107,7 +115,7 @@ def get_entitlements(self, obj): # matching grant's update_on per entitlement. grant_update_on = {} if sub_update_on is None: - for grant in EntitlementGrant.objects.matching(obj): + for grant in matching_grants: if grant.update_on is None: continue for ent in grant.entitlements.all(): diff --git a/squarelet/organizations/tasks.py b/squarelet/organizations/tasks.py index 2b63c1e56..ad2b3645b 100644 --- a/squarelet/organizations/tasks.py +++ b/squarelet/organizations/tasks.py @@ -51,10 +51,13 @@ def restore_organization(): # bump too but contribute no UUIDs to the broadcast. expired_grants = EntitlementGrant.objects.expired(today) grant_uuids = set() - for grant in expired_grants.filter(active=True): - grant_uuids.update( - grant.matching_organizations().values_list("uuid", flat=True) - ) + active_expired = list(expired_grants.filter(active=True)) + if active_expired: + qs_list = [ + g.matching_organizations().values("uuid") for g in active_expired + ] + union_qs = qs_list[0].union(*qs_list[1:]) + grant_uuids = {row["uuid"] for row in union_qs} expired_grants.update(update_on=today + Interval("1 month")) all_uuids = list({*sub_uuids, *grant_uuids}) From f4a6f2dbf52a6230e6f99b9c5624c39e4d83251b Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Tue, 26 May 2026 11:39:43 -0400 Subject: [PATCH 6/8] Format --- squarelet/organizations/querysets.py | 14 +++++--------- squarelet/organizations/tasks.py | 4 +--- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/squarelet/organizations/querysets.py b/squarelet/organizations/querysets.py index 590464ccb..611ea92ac 100644 --- a/squarelet/organizations/querysets.py +++ b/squarelet/organizations/querysets.py @@ -190,25 +190,21 @@ def expired(self, on_date=None): return self.filter(update_on__lte=on_date) def for_org(self, org): - """Active grants that apply to `org`. Pure SQL — no Python-level loop. + """Active grants that apply to `org`. A grant matches when the org type is compatible AND either: - the org is explicitly listed in `organizations`, OR - the grant has at least one rule flag set and every active rule is satisfied by this org's attributes. """ - org_type_q = ( - Q(for_individuals=True) if org.individual else Q(for_groups=True) - ) + org_type_q = Q(for_individuals=True) if org.individual else Q(for_groups=True) explicit_q = org_type_q & Q(organizations=org) - at_least_one_rule = ( - Q(require_verified=True) | Q(require_active_subscription=True) + at_least_one_rule = Q(require_verified=True) | Q( + require_active_subscription=True ) # If the org fails a requirement, exclude grants that set that flag. - verified_ok = ( - Q() if org.verified_journalist else Q(require_verified=False) - ) + verified_ok = Q() if org.verified_journalist else Q(require_verified=False) sub_ok = ( Q() if org.has_active_subscription() diff --git a/squarelet/organizations/tasks.py b/squarelet/organizations/tasks.py index ad2b3645b..47135476c 100644 --- a/squarelet/organizations/tasks.py +++ b/squarelet/organizations/tasks.py @@ -53,9 +53,7 @@ def restore_organization(): grant_uuids = set() active_expired = list(expired_grants.filter(active=True)) if active_expired: - qs_list = [ - g.matching_organizations().values("uuid") for g in active_expired - ] + qs_list = [g.matching_organizations().values("uuid") for g in active_expired] union_qs = qs_list[0].union(*qs_list[1:]) grant_uuids = {row["uuid"] for row in union_qs} expired_grants.update(update_on=today + Interval("1 month")) From 94e18de8965c2c5a7df012874d617f58328145a0 Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Tue, 26 May 2026 11:58:39 -0400 Subject: [PATCH 7/8] isort --- squarelet/organizations/serializers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/squarelet/organizations/serializers.py b/squarelet/organizations/serializers.py index cbadc2ba3..71cc4245d 100644 --- a/squarelet/organizations/serializers.py +++ b/squarelet/organizations/serializers.py @@ -1,6 +1,8 @@ +# Django +from django.db.models import Q + # Third Party import stripe -from django.db.models import Q from rest_framework import serializers, status from rest_framework.exceptions import APIException From bfb8ed73384e1f6a6ee24e75689751c6b80a8402 Mon Sep 17 00:00:00 2001 From: Allan Lasser Date: Tue, 26 May 2026 14:46:46 -0400 Subject: [PATCH 8/8] Address PR feedback --- .../0064_alter_entitlementgrant_update_on.py | 25 +++++++++++++++++++ squarelet/organizations/models/payment.py | 14 +++++------ 2 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 squarelet/organizations/migrations/0064_alter_entitlementgrant_update_on.py diff --git a/squarelet/organizations/migrations/0064_alter_entitlementgrant_update_on.py b/squarelet/organizations/migrations/0064_alter_entitlementgrant_update_on.py new file mode 100644 index 000000000..156451f85 --- /dev/null +++ b/squarelet/organizations/migrations/0064_alter_entitlementgrant_update_on.py @@ -0,0 +1,25 @@ +# Generated by Django 5.2.12 on 2026-05-26 18:43 + +import squarelet.organizations.models.payment +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("organizations", "0063_entitlementgrant_update_on"), + ] + + operations = [ + migrations.AlterField( + model_name="entitlementgrant", + name="update_on", + field=models.DateField( + blank=True, + default=squarelet.organizations.models.payment.default_grant_update_on, + help_text="Date when this grant's resources next refresh", + null=True, + verbose_name="date update", + ), + ), + ] diff --git a/squarelet/organizations/models/payment.py b/squarelet/organizations/models/payment.py index cef564e88..b9be299ed 100644 --- a/squarelet/organizations/models/payment.py +++ b/squarelet/organizations/models/payment.py @@ -766,6 +766,10 @@ def public(self): return self.plans.filter(public=True).exists() +def default_grant_update_on(): + return timezone.now().date() + relativedelta(months=1) + + class EntitlementGrant(models.Model): """Grants Entitlements to organizations, explicitly or by rule.""" @@ -818,6 +822,7 @@ class EntitlementGrant(models.Model): _("date update"), null=True, blank=True, + default=default_grant_update_on, help_text=_("Date when this grant's resources next refresh"), ) @@ -832,11 +837,6 @@ class Meta: def __str__(self): return self.name - def save(self, *args, **kwargs): - if self.update_on is None: - self.update_on = timezone.now().date() + relativedelta(months=1) - super().save(*args, **kwargs) - def matches(self, org): if not self.active: return False @@ -846,7 +846,7 @@ def matches(self, org): if not org.individual and not self.for_groups: return False # Uses `.all()` so a prefetched `organizations` relation is reused. - if any(o.pk == org.pk for o in self.organizations.all()): + if self.organizations.filter(pk=org.pk).exists(): return True checks = [] if self.require_verified: @@ -893,7 +893,7 @@ def matching_organizations(self): for clause in rule_clauses[1:]: rule_q &= clause return eligible.filter(explicit_q | rule_q).distinct() - return eligible.filter(explicit_q).distinct() + return eligible.filter(explicit_q) class ReceiptEmail(models.Model):