| 
3 | 3 | import unittest  | 
4 | 4 | from enum import Enum  | 
5 | 5 | from http import HTTPStatus  | 
 | 6 | +from types import SimpleNamespace  | 
6 | 7 | from unittest import mock  | 
7 | 8 | from unittest.mock import patch  | 
8 | 9 | 
 
  | 
@@ -454,6 +455,148 @@ def test_exchange_access_key(self):  | 
454 | 455 |                 timeout=DEFAULT_TIMEOUT_SECONDS,  | 
455 | 456 |             )  | 
456 | 457 | 
 
  | 
 | 458 | +    def test_exchange_token_success_and_empty_code(self):  | 
 | 459 | +        auth = Auth(  | 
 | 460 | +            self.dummy_project_id,  | 
 | 461 | +            self.public_key_dict,  | 
 | 462 | +            http_client=self.make_http_client(),  | 
 | 463 | +        )  | 
 | 464 | + | 
 | 465 | +        # Empty code -> error  | 
 | 466 | +        with self.assertRaises(AuthException):  | 
 | 467 | +            auth.exchange_token("/oauth/exchange", "")  | 
 | 468 | + | 
 | 469 | +        # Success path  | 
 | 470 | +        with patch("requests.post") as mock_post:  | 
 | 471 | +            net_resp = mock.Mock()  | 
 | 472 | +            net_resp.ok = True  | 
 | 473 | +            net_resp.cookies = {"DSR": "cookie_token"}  | 
 | 474 | +            # Make validator return claims  | 
 | 475 | +            auth._validate_token = lambda token, audience=None: {  | 
 | 476 | +                "iss": "https://issuer/PX",  | 
 | 477 | +                "sub": "user-x",  | 
 | 478 | +            }  | 
 | 479 | +            net_resp.json.return_value = {  | 
 | 480 | +                "sessionJwt": "s1",  | 
 | 481 | +                "refreshJwt": "r1",  | 
 | 482 | +                "user": {"id": "user-x"},  | 
 | 483 | +                "firstSeen": True,  | 
 | 484 | +            }  | 
 | 485 | +            mock_post.return_value = net_resp  | 
 | 486 | +            out = auth.exchange_token("/oauth/exchange", code="abc")  | 
 | 487 | +            self.assertEqual(out["projectId"], "PX")  | 
 | 488 | +            self.assertEqual(out["userId"], "user-x")  | 
 | 489 | + | 
 | 490 | +    def test_validate_session_success(self):  | 
 | 491 | +        auth = Auth(  | 
 | 492 | +            self.dummy_project_id,  | 
 | 493 | +            self.public_key_dict,  | 
 | 494 | +            http_client=self.make_http_client(),  | 
 | 495 | +        )  | 
 | 496 | +        # Stub validator to bypass network  | 
 | 497 | +        auth._validate_token = lambda token, audience=None: {  | 
 | 498 | +            "iss": "P123",  | 
 | 499 | +            "sub": "u123",  | 
 | 500 | +            "permissions": ["p1"],  | 
 | 501 | +            "roles": ["r1"],  | 
 | 502 | +            "tenants": {"t1": {}},  | 
 | 503 | +        }  | 
 | 504 | +        res = auth.validate_session("token-session")  | 
 | 505 | +        self.assertEqual(res["projectId"], "P123")  | 
 | 506 | +        self.assertEqual(res["userId"], "u123")  | 
 | 507 | +        self.assertEqual(res["permissions"], ["p1"])  | 
 | 508 | +        self.assertIn(SESSION_TOKEN_NAME, res)  | 
 | 509 | + | 
 | 510 | +    def test_select_tenant_success(self):  | 
 | 511 | +        auth = Auth(  | 
 | 512 | +            self.dummy_project_id,  | 
 | 513 | +            self.public_key_dict,  | 
 | 514 | +            http_client=self.make_http_client(),  | 
 | 515 | +        )  | 
 | 516 | +        # Missing refresh token  | 
 | 517 | +        with self.assertRaises(AuthException):  | 
 | 518 | +            auth.select_tenant("tenant1", "")  | 
 | 519 | + | 
 | 520 | +        # Success network path  | 
 | 521 | +        with patch("requests.post") as mock_post:  | 
 | 522 | +            net_resp = mock.Mock()  | 
 | 523 | +            net_resp.ok = True  | 
 | 524 | +            net_resp.cookies = {"DSR": "cookie_r"}  | 
 | 525 | +            # validator stub  | 
 | 526 | +            auth._validate_token = lambda token, audience=None: {  | 
 | 527 | +                "iss": "P77",  | 
 | 528 | +                "sub": "u77",  | 
 | 529 | +            }  | 
 | 530 | +            net_resp.json.return_value = {  | 
 | 531 | +                "sessionJwt": "s77",  | 
 | 532 | +                "refreshJwt": "r77",  | 
 | 533 | +            }  | 
 | 534 | +            mock_post.return_value = net_resp  | 
 | 535 | +            out = auth.select_tenant("tenant1", refresh_token="r0")  | 
 | 536 | +            self.assertEqual(out["projectId"], "P77")  | 
 | 537 | +            self.assertIn(SESSION_TOKEN_NAME, out)  | 
 | 538 | + | 
 | 539 | +    def test_compose_url_invalid_method(self):  | 
 | 540 | +        class Dummy(Enum):  | 
 | 541 | +            X = 1  | 
 | 542 | + | 
 | 543 | +        with self.assertRaises(AuthException):  | 
 | 544 | +            Auth.compose_url("/base", Dummy.X)  | 
 | 545 | + | 
 | 546 | +    def test_validate_token_header_errors(self):  | 
 | 547 | +        auth = Auth(  | 
 | 548 | +            self.dummy_project_id,  | 
 | 549 | +            self.public_key_dict,  | 
 | 550 | +            http_client=self.make_http_client(),  | 
 | 551 | +        )  | 
 | 552 | +        # Empty token  | 
 | 553 | +        with self.assertRaises(AuthException):  | 
 | 554 | +            auth._validate_token("")  | 
 | 555 | + | 
 | 556 | +        # Garbage token -> header parse error  | 
 | 557 | +        with self.assertRaises(AuthException):  | 
 | 558 | +            auth._validate_token("not-a-jwt")  | 
 | 559 | + | 
 | 560 | +        # Missing alg -> mock header dict without alg  | 
 | 561 | +        with patch("descope.auth.jwt.get_unverified_header") as mock_hdr:  | 
 | 562 | +            mock_hdr.return_value = {"kid": "kid1"}  | 
 | 563 | +            with self.assertRaises(AuthException) as cm:  | 
 | 564 | +                auth._validate_token("any.token.value")  | 
 | 565 | +            self.assertIn("missing property: alg", str(cm.exception).lower())  | 
 | 566 | + | 
 | 567 | +        # Missing kid -> mock header dict without kid  | 
 | 568 | +        with patch("descope.auth.jwt.get_unverified_header") as mock_hdr:  | 
 | 569 | +            mock_hdr.return_value = {"alg": "ES384"}  | 
 | 570 | +            with self.assertRaises(AuthException) as cm2:  | 
 | 571 | +                auth._validate_token("any.token.value")  | 
 | 572 | +            self.assertIn("missing property: kid", str(cm2.exception).lower())  | 
 | 573 | + | 
 | 574 | +        # Algorithm mismatch after fetching keys (kid found but alg different)  | 
 | 575 | +        with patch("descope.auth.jwt.get_unverified_header") as mock_hdr:  | 
 | 576 | +            mock_hdr.return_value = {  | 
 | 577 | +                "alg": "RS256",  | 
 | 578 | +                "kid": self.public_key_dict["kid"],  | 
 | 579 | +            }  | 
 | 580 | +            with self.assertRaises(AuthException) as cm3:  | 
 | 581 | +                auth._validate_token("any.token.value")  | 
 | 582 | +            self.assertIn("does not match", str(cm3.exception))  | 
 | 583 | + | 
 | 584 | +    def test_extract_masked_address_default(self):  | 
 | 585 | +        # Unknown method should return empty string  | 
 | 586 | +        class DummyMethod(Enum):  | 
 | 587 | +            OTHER = 999  | 
 | 588 | + | 
 | 589 | +        self.assertEqual(Auth.extract_masked_address({}, DummyMethod.OTHER), "")  | 
 | 590 | + | 
 | 591 | +    def test_extract_masked_address_known_methods(self):  | 
 | 592 | +        resp = {"maskedPhone": "+1-***-***-1234", "maskedEmail": "a***@b.com"}  | 
 | 593 | +        self.assertEqual(  | 
 | 594 | +            Auth.extract_masked_address(resp, DeliveryMethod.SMS), "+1-***-***-1234"  | 
 | 595 | +        )  | 
 | 596 | +        self.assertEqual(  | 
 | 597 | +            Auth.extract_masked_address(resp, DeliveryMethod.EMAIL), "a***@b.com"  | 
 | 598 | +        )  | 
 | 599 | + | 
457 | 600 |     def test_adjust_properties(self):  | 
458 | 601 |         self.assertEqual(  | 
459 | 602 |             Auth.adjust_properties(self, jwt_response={}, user_jwt={}),  | 
@@ -805,6 +948,193 @@ def test_raise_from_response(self):  | 
805 | 948 |                 """{"errorCode":"E062108","errorDescription":"User not found","errorMessage":"Cannot find user"}""",  | 
806 | 949 |             )  | 
807 | 950 | 
 
  | 
 | 951 | +    def test_http_client_authorization_header_variants(self):  | 
 | 952 | +        # Base client without management key  | 
 | 953 | +        client = self.make_http_client()  | 
 | 954 | +        headers = client.get_default_headers()  | 
 | 955 | +        self.assertEqual(headers["Authorization"], f"Bearer {self.dummy_project_id}")  | 
 | 956 | + | 
 | 957 | +        # With password/pswd only  | 
 | 958 | +        headers = client.get_default_headers(pswd="sekret")  | 
 | 959 | +        self.assertEqual(  | 
 | 960 | +            headers["Authorization"], f"Bearer {self.dummy_project_id}:sekret"  | 
 | 961 | +        )  | 
 | 962 | + | 
 | 963 | +        # With management key only  | 
 | 964 | +        client2 = self.make_http_client(management_key="mkey")  | 
 | 965 | +        headers2 = client2.get_default_headers()  | 
 | 966 | +        self.assertEqual(  | 
 | 967 | +            headers2["Authorization"], f"Bearer {self.dummy_project_id}:mkey"  | 
 | 968 | +        )  | 
 | 969 | + | 
 | 970 | +        # With both pswd and management key  | 
 | 971 | +        headers3 = client2.get_default_headers(pswd="sekret")  | 
 | 972 | +        self.assertEqual(  | 
 | 973 | +            headers3["Authorization"],  | 
 | 974 | +            f"Bearer {self.dummy_project_id}:sekret:mkey",  | 
 | 975 | +        )  | 
 | 976 | + | 
 | 977 | +    def test_compose_url_success(self):  | 
 | 978 | +        base = "/otp/send"  | 
 | 979 | +        self.assertEqual(Auth.compose_url(base, DeliveryMethod.EMAIL), f"{base}/email")  | 
 | 980 | +        self.assertEqual(Auth.compose_url(base, DeliveryMethod.SMS), f"{base}/sms")  | 
 | 981 | +        self.assertEqual(Auth.compose_url(base, DeliveryMethod.VOICE), f"{base}/voice")  | 
 | 982 | +        self.assertEqual(  | 
 | 983 | +            Auth.compose_url(base, DeliveryMethod.WHATSAPP), f"{base}/whatsapp"  | 
 | 984 | +        )  | 
 | 985 | + | 
 | 986 | +    def test_internal_rate_limit_helpers(self):  | 
 | 987 | +        auth = Auth(  | 
 | 988 | +            self.dummy_project_id,  | 
 | 989 | +            self.public_key_dict,  | 
 | 990 | +            http_client=self.make_http_client(),  | 
 | 991 | +        )  | 
 | 992 | + | 
 | 993 | +        class Resp:  | 
 | 994 | +            def __init__(self, ok, status_code, body, headers):  | 
 | 995 | +                self.ok = ok  | 
 | 996 | +                self.status_code = status_code  | 
 | 997 | +                self._body = body  | 
 | 998 | +                self.headers = headers  | 
 | 999 | +                self.text = "txt"  | 
 | 1000 | + | 
 | 1001 | +            def json(self):  | 
 | 1002 | +                return self._body  | 
 | 1003 | + | 
 | 1004 | +        # _parse_retry_after  | 
 | 1005 | +        self.assertEqual(  | 
 | 1006 | +            auth._parse_retry_after({API_RATE_LIMIT_RETRY_AFTER_HEADER: "7"}), 7  | 
 | 1007 | +        )  | 
 | 1008 | +        self.assertEqual(  | 
 | 1009 | +            auth._parse_retry_after({API_RATE_LIMIT_RETRY_AFTER_HEADER: "x"}), 0  | 
 | 1010 | +        )  | 
 | 1011 | + | 
 | 1012 | +        # _raise_rate_limit_exception with valid JSON  | 
 | 1013 | +        r1 = Resp(  | 
 | 1014 | +            ok=False,  | 
 | 1015 | +            status_code=429,  | 
 | 1016 | +            body={  | 
 | 1017 | +                "errorCode": "E130429",  | 
 | 1018 | +                "errorDescription": "https://docs",  | 
 | 1019 | +                "errorMessage": "rate",  | 
 | 1020 | +            },  | 
 | 1021 | +            headers={API_RATE_LIMIT_RETRY_AFTER_HEADER: "3"},  | 
 | 1022 | +        )  | 
 | 1023 | +        with self.assertRaises(RateLimitException) as cm:  | 
 | 1024 | +            auth._raise_rate_limit_exception(r1)  | 
 | 1025 | +        ex = cm.exception  | 
 | 1026 | +        self.assertEqual(ex.status_code, "E130429")  | 
 | 1027 | +        self.assertEqual(ex.error_type, ERROR_TYPE_API_RATE_LIMIT)  | 
 | 1028 | +        self.assertEqual(ex.error_description, "https://docs")  | 
 | 1029 | +        self.assertEqual(ex.error_message, "rate")  | 
 | 1030 | +        self.assertEqual(  | 
 | 1031 | +            ex.rate_limit_parameters, {API_RATE_LIMIT_RETRY_AFTER_HEADER: 3}  | 
 | 1032 | +        )  | 
 | 1033 | + | 
 | 1034 | +        # _raise_rate_limit_exception with invalid JSON  | 
 | 1035 | +        r2 = Resp(False, 429, "not-a-dict", {API_RATE_LIMIT_RETRY_AFTER_HEADER: "x"})  | 
 | 1036 | +        with self.assertRaises(RateLimitException) as cm2:  | 
 | 1037 | +            auth._raise_rate_limit_exception(r2)  | 
 | 1038 | +        ex2 = cm2.exception  | 
 | 1039 | +        self.assertEqual(ex2.status_code, HTTPStatus.TOO_MANY_REQUESTS)  | 
 | 1040 | +        self.assertEqual(ex2.error_type, ERROR_TYPE_API_RATE_LIMIT)  | 
 | 1041 | +        self.assertEqual(ex2.error_description, ERROR_TYPE_API_RATE_LIMIT)  | 
 | 1042 | +        self.assertEqual(ex2.error_message, ERROR_TYPE_API_RATE_LIMIT)  | 
 | 1043 | + | 
 | 1044 | +        # _raise_from_response with non-429  | 
 | 1045 | +        r3 = Resp(False, 400, {}, {})  | 
 | 1046 | +        with self.assertRaises(AuthException):  | 
 | 1047 | +            auth._raise_from_response(r3)  | 
 | 1048 | + | 
 | 1049 | +        # _raise_from_response with 429 invokes rate-limit handler  | 
 | 1050 | +        r4 = Resp(  | 
 | 1051 | +            False,  | 
 | 1052 | +            429,  | 
 | 1053 | +            {"errorCode": "E130", "errorDescription": "d", "errorMessage": "m"},  | 
 | 1054 | +            {API_RATE_LIMIT_RETRY_AFTER_HEADER: "2"},  | 
 | 1055 | +        )  | 
 | 1056 | +        with self.assertRaises(RateLimitException):  | 
 | 1057 | +            auth._raise_from_response(r4)  | 
 | 1058 | + | 
 | 1059 | +    def test_validate_and_refresh_session_refresh_path(self):  | 
 | 1060 | +        auth = Auth(  | 
 | 1061 | +            self.dummy_project_id,  | 
 | 1062 | +            self.public_key_dict,  | 
 | 1063 | +            http_client=self.make_http_client(),  | 
 | 1064 | +        )  | 
 | 1065 | +        # Force validate_session to fail  | 
 | 1066 | +        with patch.object(  | 
 | 1067 | +            Auth,  | 
 | 1068 | +            "validate_session",  | 
 | 1069 | +            side_effect=AuthException(400, ERROR_TYPE_SERVER_ERROR, "e"),  | 
 | 1070 | +        ):  | 
 | 1071 | +            # Stub refresh network  | 
 | 1072 | +            with patch("requests.post") as mock_post:  | 
 | 1073 | +                net_resp = mock.Mock()  | 
 | 1074 | +                net_resp.ok = True  | 
 | 1075 | +                net_resp.cookies = {"DSR": "cookie"}  | 
 | 1076 | +                auth._validate_token = lambda token, audience=None: {  | 
 | 1077 | +                    "iss": "P1",  | 
 | 1078 | +                    "sub": "u1",  | 
 | 1079 | +                }  | 
 | 1080 | +                net_resp.json.return_value = {"sessionJwt": "s", "refreshJwt": "r"}  | 
 | 1081 | +                mock_post.return_value = net_resp  | 
 | 1082 | +                out = auth.validate_and_refresh_session("bad", refresh_token="r0")  | 
 | 1083 | +                self.assertEqual(out["projectId"], "P1")  | 
 | 1084 | + | 
 | 1085 | +    def test_validate_token_public_key_not_found(self):  | 
 | 1086 | +        auth = Auth(  | 
 | 1087 | +            self.dummy_project_id,  | 
 | 1088 | +            None,  | 
 | 1089 | +            http_client=self.make_http_client(),  | 
 | 1090 | +        )  | 
 | 1091 | +        # ensure public keys empty and fetching sets nothing  | 
 | 1092 | +        auth.public_keys = {}  | 
 | 1093 | +        with patch.object(  | 
 | 1094 | +            Auth,  | 
 | 1095 | +            "_fetch_public_keys",  | 
 | 1096 | +            side_effect=lambda self=auth: setattr(auth, "public_keys", {}),  | 
 | 1097 | +        ):  | 
 | 1098 | +            with patch("descope.auth.jwt.get_unverified_header") as mock_hdr:  | 
 | 1099 | +                mock_hdr.return_value = {"alg": "ES384", "kid": "unknown"}  | 
 | 1100 | +                with self.assertRaises(AuthException) as cm:  | 
 | 1101 | +                    auth._validate_token("any")  | 
 | 1102 | +                self.assertIn("public key not found", str(cm.exception).lower())  | 
 | 1103 | + | 
 | 1104 | +    def test_validate_token_decode_time_errors(self):  | 
 | 1105 | +        auth = Auth(  | 
 | 1106 | +            self.dummy_project_id,  | 
 | 1107 | +            None,  | 
 | 1108 | +            http_client=self.make_http_client(),  | 
 | 1109 | +        )  | 
 | 1110 | +        # Prepare a fake key entry and matching header  | 
 | 1111 | +        auth.public_keys = {"kid": (SimpleNamespace(key="k"), "ES384")}  | 
 | 1112 | +        with patch("descope.auth.jwt.get_unverified_header") as mock_hdr, patch(  | 
 | 1113 | +            "descope.auth.jwt.decode"  | 
 | 1114 | +        ) as mock_dec:  | 
 | 1115 | +            mock_hdr.return_value = {"alg": "ES384", "kid": "kid"}  | 
 | 1116 | +            from jwt import ImmatureSignatureError  | 
 | 1117 | + | 
 | 1118 | +            mock_dec.side_effect = ImmatureSignatureError("early")  | 
 | 1119 | +            with self.assertRaises(AuthException) as cm:  | 
 | 1120 | +                auth._validate_token("tok")  | 
 | 1121 | +            self.assertEqual(cm.exception.status_code, 400)  | 
 | 1122 | + | 
 | 1123 | +    def test_validate_token_success(self):  | 
 | 1124 | +        auth = Auth(  | 
 | 1125 | +            self.dummy_project_id,  | 
 | 1126 | +            None,  | 
 | 1127 | +            http_client=self.make_http_client(),  | 
 | 1128 | +        )  | 
 | 1129 | +        auth.public_keys = {"kid": (SimpleNamespace(key="k"), "ES384")}  | 
 | 1130 | +        with patch("descope.auth.jwt.get_unverified_header") as mock_hdr, patch(  | 
 | 1131 | +            "descope.auth.jwt.decode"  | 
 | 1132 | +        ) as mock_dec:  | 
 | 1133 | +            mock_hdr.return_value = {"alg": "ES384", "kid": "kid"}  | 
 | 1134 | +            mock_dec.return_value = {"sub": "u"}  | 
 | 1135 | +            out = auth._validate_token("tok")  | 
 | 1136 | +            self.assertEqual(out["jwt"], "tok")  | 
 | 1137 | + | 
808 | 1138 | 
 
  | 
809 | 1139 | if __name__ == "__main__":  | 
810 | 1140 |     unittest.main()  | 
0 commit comments