|
12 | 12 | from datetime import timedelta
|
13 | 13 | from functools import cached_property
|
14 | 14 | from hashlib import sha256
|
| 15 | +from typing import Set, List, Optional, Tuple, Dict |
15 | 16 |
|
16 | 17 | import dns
|
17 | 18 | import psl_dns
|
18 | 19 | import rest_framework.authtoken.models
|
| 20 | +from cryptography import x509, hazmat |
19 | 21 | from django.conf import settings
|
20 | 22 | from django.contrib.auth.hashers import make_password
|
21 | 23 | from django.contrib.auth.models import AbstractBaseUser, AnonymousUser, BaseUserManager
|
@@ -946,3 +948,209 @@ def verify(self, solution: str):
|
946 | 948 | and
|
947 | 949 | age <= settings.CAPTCHA_VALIDITY_PERIOD # not expired
|
948 | 950 | )
|
| 951 | + |
| 952 | + |
| 953 | +class Identity(models.Model): |
| 954 | + rr_type = None |
| 955 | + |
| 956 | + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) |
| 957 | + name = models.CharField(max_length=24, default="") |
| 958 | + created = models.DateTimeField(auto_now_add=True) |
| 959 | + owner = models.ForeignKey(User, on_delete=models.PROTECT, related_name='identities') |
| 960 | + default_ttl = models.PositiveIntegerField(default=300) |
| 961 | + |
| 962 | + class Meta: |
| 963 | + abstract = True |
| 964 | + |
| 965 | + def get_record_contents(self) -> List[str]: |
| 966 | + raise NotImplementedError |
| 967 | + |
| 968 | + def save_rrs(self): |
| 969 | + raise NotImplementedError |
| 970 | + |
| 971 | + def save(self, *args, **kwargs): |
| 972 | + self.save_rrs() |
| 973 | + return super().save(*args, **kwargs) |
| 974 | + |
| 975 | + def delete_rrs(self): |
| 976 | + raise NotImplementedError |
| 977 | + |
| 978 | + def delete(self, using=None, keep_parents=False): |
| 979 | + # TODO this will delete also RRs that may be covered by other identities |
| 980 | + self.delete_rrs() |
| 981 | + return super().delete(using, keep_parents) |
| 982 | + |
| 983 | + def get_or_create_rr_set(self, domain: Domain, subname: str) -> RRset: |
| 984 | + try: |
| 985 | + return RRset.objects.get(domain=domain, subname=subname, type=self.rr_type) |
| 986 | + except RRset.DoesNotExist: |
| 987 | + # TODO save this RRset? |
| 988 | + return RRset(domain=domain, subname=subname, type=self.rr_type, ttl=self.default_ttl) |
| 989 | + |
| 990 | + @staticmethod |
| 991 | + def get_or_create_rr(rrset: RRset, content: str) -> RR: |
| 992 | + try: |
| 993 | + return RR.objects.get(rrset=rrset, content=content) |
| 994 | + except RR.DoesNotExist: |
| 995 | + return RR(rrset=rrset, content=content) |
| 996 | + |
| 997 | + |
| 998 | +class TLSIdentity(Identity): |
| 999 | + rr_type = 'TLSA' |
| 1000 | + |
| 1001 | + class CertificateUsage(models.IntegerChoices): |
| 1002 | + CA_CONSTRAINT = 0 |
| 1003 | + SERVICE_CERTIFICATE_CONSTRAINT = 1 |
| 1004 | + TRUST_ANCHOR_ASSERTION = 2 |
| 1005 | + DOMAIN_ISSUED_CERTIFICATE = 3 |
| 1006 | + |
| 1007 | + class Selector(models.IntegerChoices): |
| 1008 | + FULL_CERTIFICATE = 0 |
| 1009 | + SUBJECT_PUBLIC_KEY_INFO = 1 |
| 1010 | + |
| 1011 | + class MatchingType(models.IntegerChoices): |
| 1012 | + NO_HASH_USED = 0 |
| 1013 | + SHA256 = 1 |
| 1014 | + SHA512 = 2 |
| 1015 | + |
| 1016 | + class Protocol(models.TextChoices): |
| 1017 | + TCP = 'tcp' |
| 1018 | + UDP = 'udp' |
| 1019 | + SCTP = 'sctp' |
| 1020 | + |
| 1021 | + certificate = models.TextField() |
| 1022 | + |
| 1023 | + tlsa_selector = models.IntegerField(choices=Selector.choices, default=Selector.SUBJECT_PUBLIC_KEY_INFO) |
| 1024 | + tlsa_matching_type = models.IntegerField(choices=MatchingType.choices, default=MatchingType.SHA256) |
| 1025 | + tlsa_certificate_usage = models.IntegerField(choices=CertificateUsage.choices, |
| 1026 | + default=CertificateUsage.DOMAIN_ISSUED_CERTIFICATE) |
| 1027 | + |
| 1028 | + port = models.IntegerField(default=443) |
| 1029 | + protocol = models.TextField(choices=Protocol.choices, default=Protocol.TCP) |
| 1030 | + |
| 1031 | + scheduled_removal = models.DateTimeField(null=True) |
| 1032 | + |
| 1033 | + def __init__(self, *args, **kwargs): |
| 1034 | + super().__init__(*args, **kwargs) |
| 1035 | + if 'not_valid_after' not in kwargs: |
| 1036 | + self.scheduled_removal = self.not_valid_after |
| 1037 | + |
| 1038 | + def get_record_contents(self) -> List[str]: |
| 1039 | + # choose hash function |
| 1040 | + if self.tlsa_matching_type == self.MatchingType.SHA256: |
| 1041 | + hash_function = hazmat.primitives.hashes.SHA256() |
| 1042 | + elif self.tlsa_matching_type == self.MatchingType.SHA512: |
| 1043 | + hash_function = hazmat.primitives.hashes.SHA512() |
| 1044 | + else: |
| 1045 | + raise NotImplementedError |
| 1046 | + |
| 1047 | + # choose data to hash |
| 1048 | + if self.tlsa_selector == self.Selector.SUBJECT_PUBLIC_KEY_INFO: |
| 1049 | + to_be_hashed = self._cert.public_key().public_bytes( |
| 1050 | + hazmat.primitives.serialization.Encoding.DER, |
| 1051 | + hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo |
| 1052 | + ) |
| 1053 | + else: |
| 1054 | + raise NotImplementedError |
| 1055 | + |
| 1056 | + # compute the hash |
| 1057 | + h = hazmat.primitives.hashes.Hash(hash_function) |
| 1058 | + h.update(to_be_hashed) |
| 1059 | + hash = h.finalize().hex() |
| 1060 | + |
| 1061 | + # create TLSA record content |
| 1062 | + return [f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"] |
| 1063 | + |
| 1064 | + @property |
| 1065 | + def _cert(self) -> x509.Certificate: |
| 1066 | + return x509.load_pem_x509_certificate(self.certificate.encode()) |
| 1067 | + |
| 1068 | + @property |
| 1069 | + def fingerprint(self) -> str: |
| 1070 | + return self._cert.fingerprint(hazmat.primitives.hashes.SHA256()).hex() |
| 1071 | + |
| 1072 | + @property |
| 1073 | + def subject_names(self) -> Set[str]: |
| 1074 | + subject_names = { |
| 1075 | + x.value for x in |
| 1076 | + self._cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME) |
| 1077 | + } |
| 1078 | + |
| 1079 | + try: |
| 1080 | + subject_alternative_names = { |
| 1081 | + x for x in |
| 1082 | + self._cert.extensions.get_extension_for_oid( |
| 1083 | + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value.get_values_for_type(x509.DNSName) |
| 1084 | + } |
| 1085 | + except x509.extensions.ExtensionNotFound: |
| 1086 | + subject_alternative_names = set() |
| 1087 | + |
| 1088 | + return subject_names | subject_alternative_names |
| 1089 | + |
| 1090 | + @staticmethod |
| 1091 | + def get_closest_ancestor(domain_name, owner: User) -> Optional[Domain]: |
| 1092 | + # TODO move to Domain? |
| 1093 | + labels = domain_name.split('.') |
| 1094 | + ancestor_names = ['.'.join(labels[i:]) for i in range(len(labels))] |
| 1095 | + for ancestor_name in ancestor_names: # TODO do this with one query |
| 1096 | + try: |
| 1097 | + return Domain.objects.get(name=ancestor_name, owner=owner) |
| 1098 | + except Domain.DoesNotExist: |
| 1099 | + continue |
| 1100 | + return None |
| 1101 | + |
| 1102 | + def domains_subnames(self) -> Set[Tuple[Domain, str]]: |
| 1103 | + domains_subnames = set() |
| 1104 | + for name in self.subject_names: |
| 1105 | + # cut off any wildcard prefix |
| 1106 | + name = name.lstrip('*').lstrip('.') |
| 1107 | + |
| 1108 | + # filter names for valid domain names |
| 1109 | + try: |
| 1110 | + validate_domain_name[1](name) |
| 1111 | + except ValidationError: |
| 1112 | + continue |
| 1113 | + |
| 1114 | + # find user-owned parent domain |
| 1115 | + domain = self.get_closest_ancestor(name, self.owner) |
| 1116 | + if not domain: |
| 1117 | + continue |
| 1118 | + subname = name[:-len(domain.name)].rstrip('.') |
| 1119 | + |
| 1120 | + # return subname, domain pair |
| 1121 | + domains_subnames.add((domain, f"_{self.port:n}._{self.protocol}.{subname}".rstrip('.'))) |
| 1122 | + return domains_subnames |
| 1123 | + |
| 1124 | + def get_rrsets(self) -> List[RRset]: |
| 1125 | + rrsets = [] |
| 1126 | + for domain, subname in self.domains_subnames(): |
| 1127 | + rrsets.append(self.get_or_create_rr_set(domain, subname)) |
| 1128 | + return rrsets |
| 1129 | + |
| 1130 | + def get_rrs(self) -> List[RR]: |
| 1131 | + rrs = [] |
| 1132 | + for domain, subname in self.domains_subnames(): |
| 1133 | + rrset = self.get_or_create_rr_set(domain, subname) |
| 1134 | + for content in self.get_record_contents(): |
| 1135 | + rrs.append(self.get_or_create_rr(rrset=rrset, content=content)) |
| 1136 | + return rrs |
| 1137 | + |
| 1138 | + def save_rrs(self): |
| 1139 | + for rr in self.get_rrs(): |
| 1140 | + rr.rrset.save() |
| 1141 | + rr.save() |
| 1142 | + |
| 1143 | + def delete_rrs(self): |
| 1144 | + for domain, subname in self.domains_subnames(): |
| 1145 | + rrset = self.get_or_create_rr_set(domain, subname) |
| 1146 | + rrset.records.filter(content__in=self.get_record_contents()).delete() |
| 1147 | + if not len(rrset.records.all()): |
| 1148 | + rrset.delete() |
| 1149 | + |
| 1150 | + @property |
| 1151 | + def not_valid_before(self): |
| 1152 | + return self._cert.not_valid_before |
| 1153 | + |
| 1154 | + @property |
| 1155 | + def not_valid_after(self): |
| 1156 | + return self._cert.not_valid_after |
0 commit comments