11import logging
2- from typing import Annotated , Callable , Optional , TypeVar
2+ import typing_extensions
3+ from typing import Annotated , Callable , Optional , TypeVar , Union
34
45from .address_helper import validate_address
56from .consumer import Consumer
1011from .qpid .proton ._handlers import MessagingHandler
1112from .qpid .proton ._transport import SSLDomain
1213from .qpid .proton .utils import BlockingConnection
13- from .ssl_configuration import SslConfigurationContext
14+ from .ssl_configuration import PosixSslConfigurationContext , WinSslConfigurationContext , LocalMachineStore , \
15+ CurrentUserStore , PKCS12Store , Unambiguous , FriendlyName
1416
1517logger = logging .getLogger (__name__ )
1618
@@ -34,7 +36,7 @@ def __init__(
3436 uri : Optional [str ] = None ,
3537 # multi-node mode
3638 uris : Optional [list [str ]] = None ,
37- ssl_context : Optional [ SslConfigurationContext ] = None ,
39+ ssl_context : Union [ PosixSslConfigurationContext , WinSslConfigurationContext , None ] = None ,
3840 on_disconnection_handler : Optional [CB ] = None , # type: ignore
3941 ):
4042 """
@@ -60,7 +62,7 @@ def __init__(
6062 self ._conn : BlockingConnection
6163 self ._management : Management
6264 self ._on_disconnection_handler = on_disconnection_handler
63- self ._conf_ssl_context : Optional [ SslConfigurationContext ] = ssl_context
65+ self ._conf_ssl_context : Union [ PosixSslConfigurationContext , WinSslConfigurationContext , None ] = ssl_context
6466 self ._ssl_domain = None
6567 self ._connections = [] # type: ignore
6668 self ._index : int = - 1
@@ -80,17 +82,41 @@ def dial(self) -> None:
8082 logger .debug ("Enabling SSL" )
8183
8284 self ._ssl_domain = SSLDomain (SSLDomain .MODE_CLIENT )
83- if self ._ssl_domain is not None :
84- self ._ssl_domain .set_trusted_ca_db (self ._conf_ssl_context .ca_cert )
85+
86+ if isinstance (self ._conf_ssl_context , PosixSslConfigurationContext ):
87+ ca_cert = self ._conf_ssl_context .ca_cert
88+ elif isinstance (self ._conf_ssl_context , WinSslConfigurationContext ):
89+ ca_cert = self ._win_store_to_cert (self ._conf_ssl_context .ca_store )
90+ else :
91+ typing_extensions .assert_never (self ._conf_ssl_context )
92+ self ._ssl_domain .set_trusted_ca_db (ca_cert )
93+
8594 # for mutual authentication
8695 if self ._conf_ssl_context .client_cert is not None :
8796 logger .debug ("Enabling mutual authentication as well" )
88- if self ._ssl_domain is not None :
89- self ._ssl_domain .set_credentials (
90- self ._conf_ssl_context .client_cert .client_cert ,
91- self ._conf_ssl_context .client_cert .client_key ,
92- self ._conf_ssl_context .client_cert .password ,
93- )
97+
98+ if isinstance (self ._conf_ssl_context , PosixSslConfigurationContext ):
99+ client_cert = self ._conf_ssl_context .client_cert
100+ client_key = self ._conf_ssl_context .client_cert .client_key
101+ password = self ._conf_ssl_context .client_cert .password
102+ elif isinstance (self ._conf_ssl_context , WinSslConfigurationContext ):
103+ client_cert = self ._win_store_to_cert (self ._conf_ssl_context .client_cert .store )
104+ disambiguation_method = self ._conf_ssl_context .client_cert .disambiguation_method
105+ if isinstance (disambiguation_method , Unambiguous ):
106+ client_key = None
107+ elif isinstance (disambiguation_method , FriendlyName ):
108+ client_key = disambiguation_method .name
109+ else :
110+ typing_extensions .assert_never (disambiguation_method )
111+ password = self ._conf_ssl_context .client_cert .password
112+ else :
113+ typing_extensions .assert_never (self ._conf_ssl_context )
114+
115+ self ._ssl_domain .set_credentials (
116+ client_cert ,
117+ client_key ,
118+ password ,
119+ )
94120 self ._conn = BlockingConnection (
95121 url = self ._addr ,
96122 urls = self ._addrs ,
@@ -100,6 +126,17 @@ def dial(self) -> None:
100126 self ._open ()
101127 logger .debug ("Connection to the server established" )
102128
129+ def _win_store_to_cert (self , store : Union [LocalMachineStore , CurrentUserStore , PKCS12Store ]) -> str :
130+ if isinstance (store , LocalMachineStore ):
131+ ca_cert = f"lmss:{ store .name } "
132+ elif isinstance (store , CurrentUserStore ):
133+ ca_cert = f"ss:{ store .name } "
134+ elif isinstance (store , PKCS12Store ):
135+ ca_cert = store .path
136+ else :
137+ typing_extensions .assert_never (store )
138+ return ca_cert
139+
103140 def _open (self ) -> None :
104141 self ._management = Management (self ._conn )
105142 self ._management .open ()
0 commit comments