|
| 1 | +""" |
| 2 | + Copyright 2021 Brian Jinwright <[email protected]> |
| 3 | +
|
| 4 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + you may not use this file except in compliance with the License. |
| 6 | + You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | + Unless required by applicable law or agreed to in writing, software |
| 11 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + See the License for the specific language governing permissions and |
| 14 | + limitations under the License. |
| 15 | +
|
| 16 | + Imported from https://github.com/capless/warrant to reduce external dependencies required by this library and just |
| 17 | + use the SRP functions. |
| 18 | +""" |
| 19 | + |
| 20 | +import base64 |
| 21 | +import binascii |
| 22 | +import datetime |
| 23 | +import hashlib |
| 24 | +import hmac |
| 25 | +import os |
| 26 | +import re |
| 27 | + |
| 28 | +import boto3 |
| 29 | +import six |
| 30 | + |
| 31 | + |
| 32 | +class WarrantException(Exception): |
| 33 | + """Base class for all Warrant exceptions""" |
| 34 | + |
| 35 | + |
| 36 | +class ForceChangePasswordException(WarrantException): |
| 37 | + """Raised when the user is forced to change their password""" |
| 38 | + |
| 39 | + |
| 40 | +# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L22 |
| 41 | +n_hex = ( |
| 42 | + "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1" |
| 43 | + + "29024E088A67CC74020BBEA63B139B22514A08798E3404DD" |
| 44 | + + "EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245" |
| 45 | + + "E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED" |
| 46 | + + "EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D" |
| 47 | + + "C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F" |
| 48 | + + "83655D23DCA3AD961C62F356208552BB9ED529077096966D" |
| 49 | + + "670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B" |
| 50 | + + "E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9" |
| 51 | + + "DE2BCBF6955817183995497CEA956AE515D2261898FA0510" |
| 52 | + + "15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64" |
| 53 | + + "ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7" |
| 54 | + + "ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B" |
| 55 | + + "F12FFA06D98A0864D87602733EC86A64521F2B18177B200C" |
| 56 | + + "BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31" |
| 57 | + + "43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF" |
| 58 | +) |
| 59 | +# https://github.com/aws/amazon-cognito-identity-js/blob/master/src/AuthenticationHelper.js#L49 |
| 60 | +g_hex = "2" |
| 61 | +info_bits = bytearray("Caldera Derived Key", "utf-8") |
| 62 | + |
| 63 | + |
| 64 | +def hash_sha256(buf): |
| 65 | + """AuthenticationHelper.hash""" |
| 66 | + a = hashlib.sha256(buf).hexdigest() |
| 67 | + return (64 - len(a)) * "0" + a |
| 68 | + |
| 69 | + |
| 70 | +def hex_hash(hex_string): |
| 71 | + return hash_sha256(bytearray.fromhex(hex_string)) |
| 72 | + |
| 73 | + |
| 74 | +def hex_to_long(hex_string): |
| 75 | + return int(hex_string, 16) |
| 76 | + |
| 77 | + |
| 78 | +def long_to_hex(long_num): |
| 79 | + return "%x" % long_num |
| 80 | + |
| 81 | + |
| 82 | +def get_random(nbytes): |
| 83 | + random_hex = binascii.hexlify(os.urandom(nbytes)) |
| 84 | + return hex_to_long(random_hex) |
| 85 | + |
| 86 | + |
| 87 | +def pad_hex(long_int): |
| 88 | + """ |
| 89 | + Converts a Long integer (or hex string) to hex format padded with zeroes for hashing |
| 90 | + :param {Long integer|String} long_int Number or string to pad. |
| 91 | + :return {String} Padded hex string. |
| 92 | + """ |
| 93 | + if not isinstance(long_int, six.string_types): |
| 94 | + hash_str = long_to_hex(long_int) |
| 95 | + else: |
| 96 | + hash_str = long_int |
| 97 | + if len(hash_str) % 2 == 1: |
| 98 | + hash_str = "0%s" % hash_str |
| 99 | + elif hash_str[0] in "89ABCDEFabcdef": |
| 100 | + hash_str = "00%s" % hash_str |
| 101 | + return hash_str |
| 102 | + |
| 103 | + |
| 104 | +def compute_hkdf(ikm, salt): |
| 105 | + """ |
| 106 | + Standard hkdf algorithm |
| 107 | + :param {Buffer} ikm Input key material. |
| 108 | + :param {Buffer} salt Salt value. |
| 109 | + :return {Buffer} Strong key material. |
| 110 | + @private |
| 111 | + """ |
| 112 | + prk = hmac.new(salt, ikm, hashlib.sha256).digest() |
| 113 | + info_bits_update = info_bits + bytearray(chr(1), "utf-8") |
| 114 | + hmac_hash = hmac.new(prk, info_bits_update, hashlib.sha256).digest() |
| 115 | + return hmac_hash[:16] |
| 116 | + |
| 117 | + |
| 118 | +def calculate_u(big_a, big_b): |
| 119 | + """ |
| 120 | + Calculate the client's value U which is the hash of A and B |
| 121 | + :param {Long integer} big_a Large A value. |
| 122 | + :param {Long integer} big_b Server B value. |
| 123 | + :return {Long integer} Computed U value. |
| 124 | + """ |
| 125 | + u_hex_hash = hex_hash(pad_hex(big_a) + pad_hex(big_b)) |
| 126 | + return hex_to_long(u_hex_hash) |
| 127 | + |
| 128 | + |
| 129 | +class AWSSRP(object): |
| 130 | + |
| 131 | + NEW_PASSWORD_REQUIRED_CHALLENGE = "NEW_PASSWORD_REQUIRED" |
| 132 | + PASSWORD_VERIFIER_CHALLENGE = "PASSWORD_VERIFIER" |
| 133 | + |
| 134 | + def __init__( |
| 135 | + self, |
| 136 | + username, |
| 137 | + password, |
| 138 | + pool_id, |
| 139 | + client_id, |
| 140 | + pool_region=None, |
| 141 | + client=None, |
| 142 | + client_secret=None, |
| 143 | + ): |
| 144 | + if pool_region is not None and client is not None: |
| 145 | + raise ValueError( |
| 146 | + "pool_region and client should not both be specified " |
| 147 | + "(region should be passed to the boto3 client instead)" |
| 148 | + ) |
| 149 | + |
| 150 | + self.username = username |
| 151 | + self.password = password |
| 152 | + self.pool_id = pool_id |
| 153 | + self.client_id = client_id |
| 154 | + self.client_secret = client_secret |
| 155 | + self.client = ( |
| 156 | + client if client else boto3.client("cognito-idp", region_name=pool_region) |
| 157 | + ) |
| 158 | + self.big_n = hex_to_long(n_hex) |
| 159 | + self.g = hex_to_long(g_hex) |
| 160 | + self.k = hex_to_long(hex_hash("00" + n_hex + "0" + g_hex)) |
| 161 | + self.small_a_value = self.generate_random_small_a() |
| 162 | + self.large_a_value = self.calculate_a() |
| 163 | + |
| 164 | + def generate_random_small_a(self): |
| 165 | + """ |
| 166 | + helper function to generate a random big integer |
| 167 | + :return {Long integer} a random value. |
| 168 | + """ |
| 169 | + random_long_int = get_random(128) |
| 170 | + return random_long_int % self.big_n |
| 171 | + |
| 172 | + def calculate_a(self): |
| 173 | + """ |
| 174 | + Calculate the client's public value A = g^a%N |
| 175 | + with the generated random number a |
| 176 | + :param {Long integer} a Randomly generated small A. |
| 177 | + :return {Long integer} Computed large A. |
| 178 | + """ |
| 179 | + big_a = pow(self.g, self.small_a_value, self.big_n) |
| 180 | + # safety check |
| 181 | + if (big_a % self.big_n) == 0: |
| 182 | + raise ValueError("Safety check for A failed") |
| 183 | + return big_a |
| 184 | + |
| 185 | + def get_password_authentication_key(self, username, password, server_b_value, salt): |
| 186 | + """ |
| 187 | + Calculates the final hkdf based on computed S value, and computed U value and the key |
| 188 | + :param {String} username Username. |
| 189 | + :param {String} password Password. |
| 190 | + :param {Long integer} server_b_value Server B value. |
| 191 | + :param {Long integer} salt Generated salt. |
| 192 | + :return {Buffer} Computed HKDF value. |
| 193 | + """ |
| 194 | + u_value = calculate_u(self.large_a_value, server_b_value) |
| 195 | + username_password = "%s%s:%s" % (self.pool_id.split("_")[1], username, password) |
| 196 | + username_password_hash = hash_sha256(username_password.encode("utf-8")) |
| 197 | + |
| 198 | + x_value = hex_to_long(hex_hash(pad_hex(salt) + username_password_hash)) |
| 199 | + g_mod_pow_xn = pow(self.g, x_value, self.big_n) |
| 200 | + int_value2 = server_b_value - self.k * g_mod_pow_xn |
| 201 | + s_value = pow(int_value2, self.small_a_value + u_value * x_value, self.big_n) |
| 202 | + hkdf = compute_hkdf( |
| 203 | + bytearray.fromhex(pad_hex(s_value)), |
| 204 | + bytearray.fromhex(pad_hex(long_to_hex(u_value))), |
| 205 | + ) |
| 206 | + return hkdf |
| 207 | + |
| 208 | + def get_auth_params(self): |
| 209 | + auth_params = { |
| 210 | + "USERNAME": self.username, |
| 211 | + "SRP_A": long_to_hex(self.large_a_value), |
| 212 | + } |
| 213 | + if self.client_secret is not None: |
| 214 | + auth_params.update( |
| 215 | + { |
| 216 | + "SECRET_HASH": self.get_secret_hash( |
| 217 | + self.username, self.client_id, self.client_secret |
| 218 | + ) |
| 219 | + } |
| 220 | + ) |
| 221 | + return auth_params |
| 222 | + |
| 223 | + @staticmethod |
| 224 | + def get_secret_hash(username, client_id, client_secret): |
| 225 | + message = bytearray(username + client_id, "utf-8") |
| 226 | + hmac_obj = hmac.new(bytearray(client_secret, "utf-8"), message, hashlib.sha256) |
| 227 | + return base64.standard_b64encode(hmac_obj.digest()).decode("utf-8") |
| 228 | + |
| 229 | + def process_challenge(self, challenge_parameters): |
| 230 | + user_id_for_srp = challenge_parameters["USER_ID_FOR_SRP"] |
| 231 | + salt_hex = challenge_parameters["SALT"] |
| 232 | + srp_b_hex = challenge_parameters["SRP_B"] |
| 233 | + secret_block_b64 = challenge_parameters["SECRET_BLOCK"] |
| 234 | + # re strips leading zero from a day number (required by AWS Cognito) |
| 235 | + timestamp = re.sub( |
| 236 | + r" 0(\d) ", |
| 237 | + r" \1 ", |
| 238 | + datetime.datetime.utcnow().strftime("%a %b %d %H:%M:%S UTC %Y"), |
| 239 | + ) |
| 240 | + hkdf = self.get_password_authentication_key( |
| 241 | + user_id_for_srp, self.password, hex_to_long(srp_b_hex), salt_hex |
| 242 | + ) |
| 243 | + secret_block_bytes = base64.standard_b64decode(secret_block_b64) |
| 244 | + msg = ( |
| 245 | + bytearray(self.pool_id.split("_")[1], "utf-8") |
| 246 | + + bytearray(user_id_for_srp, "utf-8") |
| 247 | + + bytearray(secret_block_bytes) |
| 248 | + + bytearray(timestamp, "utf-8") |
| 249 | + ) |
| 250 | + hmac_obj = hmac.new(hkdf, msg, digestmod=hashlib.sha256) |
| 251 | + signature_string = base64.standard_b64encode(hmac_obj.digest()) |
| 252 | + response = { |
| 253 | + "TIMESTAMP": timestamp, |
| 254 | + "USERNAME": user_id_for_srp, |
| 255 | + "PASSWORD_CLAIM_SECRET_BLOCK": secret_block_b64, |
| 256 | + "PASSWORD_CLAIM_SIGNATURE": signature_string.decode("utf-8"), |
| 257 | + } |
| 258 | + if self.client_secret is not None: |
| 259 | + response.update( |
| 260 | + { |
| 261 | + "SECRET_HASH": self.get_secret_hash( |
| 262 | + self.username, self.client_id, self.client_secret |
| 263 | + ) |
| 264 | + } |
| 265 | + ) |
| 266 | + return response |
| 267 | + |
| 268 | + def authenticate_user(self, client=None): |
| 269 | + boto_client = self.client or client |
| 270 | + auth_params = self.get_auth_params() |
| 271 | + response = boto_client.initiate_auth( |
| 272 | + AuthFlow="USER_SRP_AUTH", |
| 273 | + AuthParameters=auth_params, |
| 274 | + ClientId=self.client_id, |
| 275 | + ) |
| 276 | + if response["ChallengeName"] == self.PASSWORD_VERIFIER_CHALLENGE: |
| 277 | + challenge_response = self.process_challenge(response["ChallengeParameters"]) |
| 278 | + tokens = boto_client.respond_to_auth_challenge( |
| 279 | + ClientId=self.client_id, |
| 280 | + ChallengeName=self.PASSWORD_VERIFIER_CHALLENGE, |
| 281 | + ChallengeResponses=challenge_response, |
| 282 | + ) |
| 283 | + |
| 284 | + if tokens.get("ChallengeName") == self.NEW_PASSWORD_REQUIRED_CHALLENGE: |
| 285 | + raise ForceChangePasswordException( |
| 286 | + "Change password before authenticating" |
| 287 | + ) |
| 288 | + |
| 289 | + return tokens |
| 290 | + else: |
| 291 | + raise NotImplementedError( |
| 292 | + "The %s challenge is not supported" % response["ChallengeName"] |
| 293 | + ) |
0 commit comments