|
| 1 | +import base64 |
| 2 | +import hashlib |
| 3 | +import logging |
| 4 | +from datetime import datetime, timedelta, timezone |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +import jwt |
| 8 | +from cryptography.hazmat.primitives import serialization |
| 9 | +from cryptography.hazmat.primitives.asymmetric import types |
| 10 | + |
| 11 | +logger = logging.getLogger(__name__) |
| 12 | + |
| 13 | +ISSUER = "iss" |
| 14 | +EXPIRE_TIME = "exp" |
| 15 | +ISSUE_TIME = "iat" |
| 16 | +SUBJECT = "sub" |
| 17 | + |
| 18 | + |
| 19 | +class JWTGenerator: |
| 20 | + """ |
| 21 | + Creates and signs a JWT with the specified private key file, username, and account identifier. The JWTGenerator |
| 22 | + keeps the generated token and only regenerates the token if a specified period of time has passed. |
| 23 | + """ |
| 24 | + |
| 25 | + _DEFAULT_LIFETIME = timedelta(minutes=59) # The tokens will have a 59-minute lifetime |
| 26 | + _DEFAULT_RENEWAL_DELTA = timedelta(minutes=54) # Tokens will be renewed after 54 minutes |
| 27 | + ALGORITHM = "RS256" # Tokens will be generated using RSA with SHA256 |
| 28 | + |
| 29 | + def __init__( |
| 30 | + self, |
| 31 | + account: str, |
| 32 | + user: str, |
| 33 | + private_key: types.PRIVATE_KEY_TYPES, |
| 34 | + lifetime: Optional[timedelta] = None, |
| 35 | + renewal_delay: Optional[timedelta] = None, |
| 36 | + ) -> None: |
| 37 | + """ |
| 38 | + Create a new JWTGenerator object. |
| 39 | +
|
| 40 | + Args: |
| 41 | + account: The account identifier. |
| 42 | + user: The username. |
| 43 | + private_key: The private key used to sign the JWT. |
| 44 | + lifetime: The lifetime of the token. |
| 45 | + renewal_delay: The time before the token expires to renew it. |
| 46 | + """ |
| 47 | + |
| 48 | + # Construct the fully qualified name of the user in uppercase. |
| 49 | + self.account = JWTGenerator._prepare_account_name_for_jwt(account) |
| 50 | + self.user = user.upper() |
| 51 | + self.qualified_username = self.account + "." + self.user |
| 52 | + self.private_key = private_key |
| 53 | + self.public_key_fp = JWTGenerator._calculate_public_key_fingerprint(self.private_key) |
| 54 | + |
| 55 | + self.issuer = self.qualified_username + "." + self.public_key_fp |
| 56 | + self.lifetime = lifetime or JWTGenerator._DEFAULT_LIFETIME |
| 57 | + self.renewal_delay = renewal_delay or JWTGenerator._DEFAULT_RENEWAL_DELTA |
| 58 | + self.renew_time = datetime.now(timezone.utc) |
| 59 | + self.token: Optional[str] = None |
| 60 | + |
| 61 | + logger.info( |
| 62 | + """Creating JWTGenerator with arguments |
| 63 | + account : %s, user : %s, lifetime : %s, renewal_delay : %s""", |
| 64 | + self.account, |
| 65 | + self.user, |
| 66 | + self.lifetime, |
| 67 | + self.renewal_delay, |
| 68 | + ) |
| 69 | + |
| 70 | + @staticmethod |
| 71 | + def _prepare_account_name_for_jwt(raw_account: str) -> str: |
| 72 | + account = raw_account |
| 73 | + if ".global" not in account: |
| 74 | + # Handle the general case. |
| 75 | + idx = account.find(".") |
| 76 | + if idx > 0: |
| 77 | + account = account[0:idx] |
| 78 | + else: |
| 79 | + # Handle the replication case. |
| 80 | + idx = account.find("-") |
| 81 | + if idx > 0: |
| 82 | + account = account[0:idx] |
| 83 | + # Use uppercase for the account identifier. |
| 84 | + return account.upper() |
| 85 | + |
| 86 | + def get_token(self) -> str: |
| 87 | + now = datetime.now(timezone.utc) # Fetch the current time |
| 88 | + if self.token is not None and self.renew_time > now: |
| 89 | + return self.token |
| 90 | + |
| 91 | + # If the token has expired or doesn't exist, regenerate the token. |
| 92 | + logger.info( |
| 93 | + "Generating a new token because the present time (%s) is later than the renewal time (%s)", |
| 94 | + now, |
| 95 | + self.renew_time, |
| 96 | + ) |
| 97 | + # Calculate the next time we need to renew the token. |
| 98 | + self.renew_time = now + self.renewal_delay |
| 99 | + |
| 100 | + # Create our payload |
| 101 | + payload = { |
| 102 | + # Set the issuer to the fully qualified username concatenated with the public key fingerprint. |
| 103 | + ISSUER: self.issuer, |
| 104 | + # Set the subject to the fully qualified username. |
| 105 | + SUBJECT: self.qualified_username, |
| 106 | + # Set the issue time to now. |
| 107 | + ISSUE_TIME: now, |
| 108 | + # Set the expiration time, based on the lifetime specified for this object. |
| 109 | + EXPIRE_TIME: now + self.lifetime, |
| 110 | + } |
| 111 | + |
| 112 | + # Regenerate the actual token |
| 113 | + token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) |
| 114 | + # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string. |
| 115 | + # If the token is a byte string, convert it to a string. |
| 116 | + if isinstance(token, bytes): |
| 117 | + token = token.decode("utf-8") |
| 118 | + self.token = token |
| 119 | + logger.info( |
| 120 | + "Generated a JWT with the following payload: %s", |
| 121 | + jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]), |
| 122 | + ) |
| 123 | + |
| 124 | + return token |
| 125 | + |
| 126 | + @staticmethod |
| 127 | + def _calculate_public_key_fingerprint(private_key: types.PRIVATE_KEY_TYPES) -> str: |
| 128 | + # Get the raw bytes of public key. |
| 129 | + public_key_raw = private_key.public_key().public_bytes( |
| 130 | + serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo |
| 131 | + ) |
| 132 | + |
| 133 | + # Get the sha256 hash of the raw bytes. |
| 134 | + sha256hash = hashlib.sha256() |
| 135 | + sha256hash.update(public_key_raw) |
| 136 | + |
| 137 | + # Base64-encode the value and prepend the prefix 'SHA256:'. |
| 138 | + public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8") |
| 139 | + logger.info("Public key fingerprint is %s", public_key_fp) |
| 140 | + |
| 141 | + return public_key_fp |
0 commit comments