11import logging
2- from typing import Annotated , Callable , Optional , TypeVar
2+ from typing import (
3+ Annotated ,
4+ Callable ,
5+ Optional ,
6+ TypeVar ,
7+ Union ,
8+ )
9+
10+ import typing_extensions
311
412from .address_helper import validate_address
513from .consumer import Consumer
1018from .qpid .proton ._handlers import MessagingHandler
1119from .qpid .proton ._transport import SSLDomain
1220from .qpid .proton .utils import BlockingConnection
13- from .ssl_configuration import SslConfigurationContext
21+ from .ssl_configuration import (
22+ CurrentUserStore ,
23+ FriendlyName ,
24+ LocalMachineStore ,
25+ PKCS12Store ,
26+ PosixSslConfigurationContext ,
27+ Unambiguous ,
28+ WinSslConfigurationContext ,
29+ )
1430
1531logger = logging .getLogger (__name__ )
1632
@@ -34,7 +50,9 @@ def __init__(
3450 uri : Optional [str ] = None ,
3551 # multi-node mode
3652 uris : Optional [list [str ]] = None ,
37- ssl_context : Optional [SslConfigurationContext ] = None ,
53+ ssl_context : Union [
54+ PosixSslConfigurationContext , WinSslConfigurationContext , None
55+ ] = None ,
3856 on_disconnection_handler : Optional [CB ] = None , # type: ignore
3957 ):
4058 """
@@ -60,7 +78,9 @@ def __init__(
6078 self ._conn : BlockingConnection
6179 self ._management : Management
6280 self ._on_disconnection_handler = on_disconnection_handler
63- self ._conf_ssl_context : Optional [SslConfigurationContext ] = ssl_context
81+ self ._conf_ssl_context : Union [
82+ PosixSslConfigurationContext , WinSslConfigurationContext , None
83+ ] = ssl_context
6484 self ._ssl_domain = None
6585 self ._connections = [] # type: ignore
6686 self ._index : int = - 1
@@ -80,17 +100,45 @@ def dial(self) -> None:
80100 logger .debug ("Enabling SSL" )
81101
82102 self ._ssl_domain = SSLDomain (SSLDomain .MODE_CLIENT )
83- if self ._ssl_domain is not None :
84- self ._ssl_domain .set_trusted_ca_db (self ._conf_ssl_context .ca_cert )
103+
104+ if isinstance (self ._conf_ssl_context , PosixSslConfigurationContext ):
105+ ca_cert = self ._conf_ssl_context .ca_cert
106+ elif isinstance (self ._conf_ssl_context , WinSslConfigurationContext ):
107+ ca_cert = self ._win_store_to_cert (self ._conf_ssl_context .ca_store )
108+ else :
109+ typing_extensions .assert_never (self ._conf_ssl_context )
110+ self ._ssl_domain .set_trusted_ca_db (ca_cert )
111+
85112 # for mutual authentication
86113 if self ._conf_ssl_context .client_cert is not None :
87114 logger .debug ("Enabling mutual authentication as well" )
88- if self ._ssl_domain is not None :
89- self ._ssl_domain .set_credentials (
90- self ._conf_ssl_context .client_cert .client_cert ,
91- self ._conf_ssl_context .client_cert .client_key ,
92- self ._conf_ssl_context .client_cert .password ,
115+
116+ if isinstance (self ._conf_ssl_context , PosixSslConfigurationContext ):
117+ client_cert = self ._conf_ssl_context .client_cert
118+ client_key = self ._conf_ssl_context .client_cert .client_key
119+ password = self ._conf_ssl_context .client_cert .password
120+ elif isinstance (self ._conf_ssl_context , WinSslConfigurationContext ):
121+ client_cert = self ._win_store_to_cert (
122+ self ._conf_ssl_context .client_cert .store
93123 )
124+ disambiguation_method = (
125+ self ._conf_ssl_context .client_cert .disambiguation_method
126+ )
127+ if isinstance (disambiguation_method , Unambiguous ):
128+ client_key = None
129+ elif isinstance (disambiguation_method , FriendlyName ):
130+ client_key = disambiguation_method .name
131+ else :
132+ typing_extensions .assert_never (disambiguation_method )
133+ password = self ._conf_ssl_context .client_cert .password
134+ else :
135+ typing_extensions .assert_never (self ._conf_ssl_context )
136+
137+ self ._ssl_domain .set_credentials (
138+ client_cert ,
139+ client_key ,
140+ password ,
141+ )
94142 self ._conn = BlockingConnection (
95143 url = self ._addr ,
96144 urls = self ._addrs ,
@@ -100,6 +148,19 @@ def dial(self) -> None:
100148 self ._open ()
101149 logger .debug ("Connection to the server established" )
102150
151+ def _win_store_to_cert (
152+ self , store : Union [LocalMachineStore , CurrentUserStore , PKCS12Store ]
153+ ) -> str :
154+ if isinstance (store , LocalMachineStore ):
155+ ca_cert = f"lmss:{ store .name } "
156+ elif isinstance (store , CurrentUserStore ):
157+ ca_cert = f"ss:{ store .name } "
158+ elif isinstance (store , PKCS12Store ):
159+ ca_cert = store .path
160+ else :
161+ typing_extensions .assert_never (store )
162+ return ca_cert
163+
103164 def _open (self ) -> None :
104165 self ._management = Management (self ._conn )
105166 self ._management .open ()
0 commit comments