Skip to content

Commit 25c0032

Browse files
author
Dos Moonen
committed
Refactor SslConfigurationContext into platform specific classes
1 parent 884cc8b commit 25c0032

File tree

10 files changed

+154
-55
lines changed

10 files changed

+154
-55
lines changed

examples/getting_started/getting_started.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# type: ignore
22

33

4-
from rabbitmq_amqp_python_client import ( # SSlConfigurationContext,; SslConfigurationContext,; ClientCert,
4+
from rabbitmq_amqp_python_client import (
5+
# PosixSSlConfigurationContext,
6+
# PosixClientCert,
57
AddressHelper,
68
AMQPMessagingHandler,
79
Connection,
@@ -68,7 +70,7 @@ def create_connection(environment: Environment) -> Connection:
6870
# client_key = ".ci/certs/client_key.pem"
6971
# connection = Connection(
7072
# "amqps://guest:guest@localhost:5671/",
71-
# ssl_context=SslConfigurationContext(
73+
# ssl_context=PosixSslConfigurationContext(
7274
# ca_cert=ca_cert_file,
7375
# client_cert=ClientCert(client_cert=client_cert, client_key=client_key),
7476
# ),

examples/streams/example_with_streams.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# type: ignore
22

3-
from rabbitmq_amqp_python_client import ( # SSlConfigurationContext,; SslConfigurationContext,; ClientCert,
3+
from rabbitmq_amqp_python_client import (
4+
# PosixSSlConfigurationContext,
5+
# PosixClientCert,
46
AddressHelper,
57
AMQPMessagingHandler,
68
Connection,
@@ -72,9 +74,9 @@ def create_connection(environment: Environment) -> Connection:
7274
# client_key = ".ci/certs/client_key.pem"
7375
# connection = Connection(
7476
# "amqps://guest:guest@localhost:5671/",
75-
# ssl_context=SslConfigurationContext(
77+
# ssl_context=PosixSslConfigurationContext(
7678
# ca_cert=ca_cert_file,
77-
# client_cert=ClientCert(client_cert=client_cert, client_key=client_key),
79+
# client_cert=PosixClientCert(client_cert=client_cert, client_key=client_key),
7880
# ),
7981
# )
8082
connection.dial()

examples/tls/tls_example.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# type: ignore
22

33

4-
from rabbitmq_amqp_python_client import ( # SSlConfigurationContext,; SslConfigurationContext,; ClientCert,
4+
from rabbitmq_amqp_python_client import (
55
AddressHelper,
66
AMQPMessagingHandler,
7-
ClientCert,
7+
PosixClientCert,
88
Connection,
99
Environment,
1010
Event,
1111
ExchangeSpecification,
1212
ExchangeToQueueBindingSpecification,
1313
Message,
1414
QuorumQueueSpecification,
15-
SslConfigurationContext,
15+
PosixSslConfigurationContext,
1616
)
1717

1818
messages_to_publish = 100
@@ -80,9 +80,9 @@ def main() -> None:
8080

8181
environment = Environment(
8282
"amqps://guest:guest@localhost:5671/",
83-
ssl_context=SslConfigurationContext(
83+
ssl_context=PosixSslConfigurationContext(
8484
ca_cert=ca_cert_file,
85-
client_cert=ClientCert(client_cert=client_cert, client_key=client_key),
85+
client_cert=PosixClientCert(client_cert=client_cert, client_key=client_key),
8686
),
8787
)
8888

rabbitmq_amqp_python_client/__init__.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@
3232
StreamSpecification,
3333
)
3434
from .ssl_configuration import (
35-
ClientCert,
36-
SslConfigurationContext,
35+
PosixClientCert,
36+
PosixSslConfigurationContext,
37+
WinClientCert,
38+
WinSslConfigurationContext,
39+
LocalMachineStore,
40+
CurrentUserStore,
41+
PKCS12Store,
3742
)
3843

3944
try:
@@ -69,8 +74,13 @@
6974
"AMQPMessagingHandler",
7075
"ArgumentOutOfRangeException",
7176
"ValidationCodeException",
72-
"SslConfigurationContext",
73-
"ClientCert",
77+
"PosixSslConfigurationContext",
78+
"WinSslConfigurationContext",
79+
"PosixClientCert",
80+
"WinClientCert",
81+
"LocalMachineStore",
82+
"CurrentUserStore",
83+
"PKCS12Store",
7484
"ConnectionClosed",
7585
"StreamOptions",
7686
"OffsetSpecification",

rabbitmq_amqp_python_client/connection.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
2-
from typing import Annotated, Callable, Optional, TypeVar
2+
import typing_extensions
3+
from typing import Annotated, Callable, Optional, TypeVar, Union
34

45
from .address_helper import validate_address
56
from .consumer import Consumer
@@ -10,7 +11,8 @@
1011
from .qpid.proton._handlers import MessagingHandler
1112
from .qpid.proton._transport import SSLDomain
1213
from .qpid.proton.utils import BlockingConnection
13-
from .ssl_configuration import SslConfigurationContext
14+
from .ssl_configuration import PosixSslConfigurationContext, WinSslConfigurationContext, LocalMachineStore, \
15+
CurrentUserStore, PKCS12Store, Unambiguous, FriendlyName
1416

1517
logger = logging.getLogger(__name__)
1618

@@ -34,7 +36,7 @@ def __init__(
3436
uri: Optional[str] = None,
3537
# multi-node mode
3638
uris: Optional[list[str]] = None,
37-
ssl_context: Optional[SslConfigurationContext] = None,
39+
ssl_context: Union[PosixSslConfigurationContext, WinSslConfigurationContext, None] = None,
3840
on_disconnection_handler: Optional[CB] = None, # type: ignore
3941
):
4042
"""
@@ -60,7 +62,7 @@ def __init__(
6062
self._conn: BlockingConnection
6163
self._management: Management
6264
self._on_disconnection_handler = on_disconnection_handler
63-
self._conf_ssl_context: Optional[SslConfigurationContext] = ssl_context
65+
self._conf_ssl_context: Union[PosixSslConfigurationContext, WinSslConfigurationContext, None] = ssl_context
6466
self._ssl_domain = None
6567
self._connections = [] # type: ignore
6668
self._index: int = -1
@@ -80,17 +82,41 @@ def dial(self) -> None:
8082
logger.debug("Enabling SSL")
8183

8284
self._ssl_domain = SSLDomain(SSLDomain.MODE_CLIENT)
83-
if self._ssl_domain is not None:
84-
self._ssl_domain.set_trusted_ca_db(self._conf_ssl_context.ca_cert)
85+
86+
if isinstance(self._conf_ssl_context, PosixSslConfigurationContext):
87+
ca_cert = self._conf_ssl_context.ca_cert
88+
elif isinstance(self._conf_ssl_context, WinSslConfigurationContext):
89+
ca_cert = self._win_store_to_cert(self._conf_ssl_context.ca_store)
90+
else:
91+
typing_extensions.assert_never(self._conf_ssl_context)
92+
self._ssl_domain.set_trusted_ca_db(ca_cert)
93+
8594
# for mutual authentication
8695
if self._conf_ssl_context.client_cert is not None:
8796
logger.debug("Enabling mutual authentication as well")
88-
if self._ssl_domain is not None:
89-
self._ssl_domain.set_credentials(
90-
self._conf_ssl_context.client_cert.client_cert,
91-
self._conf_ssl_context.client_cert.client_key,
92-
self._conf_ssl_context.client_cert.password,
93-
)
97+
98+
if isinstance(self._conf_ssl_context, PosixSslConfigurationContext):
99+
client_cert = self._conf_ssl_context.client_cert
100+
client_key = self._conf_ssl_context.client_cert.client_key
101+
password = self._conf_ssl_context.client_cert.password
102+
elif isinstance(self._conf_ssl_context, WinSslConfigurationContext):
103+
client_cert = self._win_store_to_cert(self._conf_ssl_context.client_cert.store)
104+
disambiguation_method = self._conf_ssl_context.client_cert.disambiguation_method
105+
if isinstance(disambiguation_method, Unambiguous):
106+
client_key = None
107+
elif isinstance(disambiguation_method, FriendlyName):
108+
client_key = disambiguation_method.name
109+
else:
110+
typing_extensions.assert_never(disambiguation_method)
111+
password = self._conf_ssl_context.client_cert.password
112+
else:
113+
typing_extensions.assert_never(self._conf_ssl_context)
114+
115+
self._ssl_domain.set_credentials(
116+
client_cert,
117+
client_key,
118+
password,
119+
)
94120
self._conn = BlockingConnection(
95121
url=self._addr,
96122
urls=self._addrs,
@@ -100,6 +126,17 @@ def dial(self) -> None:
100126
self._open()
101127
logger.debug("Connection to the server established")
102128

129+
def _win_store_to_cert(self, store: Union[LocalMachineStore, CurrentUserStore, PKCS12Store]) -> str:
130+
if isinstance(store, LocalMachineStore):
131+
ca_cert = f"lmss:{store.name}"
132+
elif isinstance(store, CurrentUserStore):
133+
ca_cert = f"ss:{store.name}"
134+
elif isinstance(store, PKCS12Store):
135+
ca_cert = store.path
136+
else:
137+
typing_extensions.assert_never(store)
138+
return ca_cert
139+
103140
def _open(self) -> None:
104141
self._management = Management(self._conn)
105142
self._management.open()

rabbitmq_amqp_python_client/environment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# For the moment this is just a Connection pooler to keep compatibility with other clients
22
import logging
3-
from typing import Annotated, Callable, Optional, TypeVar
3+
from typing import Annotated, Callable, Optional, TypeVar, Union
44

55
from .connection import Connection
6-
from .ssl_configuration import SslConfigurationContext
6+
from .ssl_configuration import PosixSslConfigurationContext, WinSslConfigurationContext
77

88
logger = logging.getLogger(__name__)
99

@@ -28,7 +28,7 @@ def __init__(
2828
uri: Optional[str] = None,
2929
# multi-node mode
3030
uris: Optional[list[str]] = None,
31-
ssl_context: Optional[SslConfigurationContext] = None,
31+
ssl_context: Union[PosixSslConfigurationContext, WinSslConfigurationContext, None] = None,
3232
on_disconnection_handler: Optional[CB] = None, # type: ignore
3333
):
3434
"""

rabbitmq_amqp_python_client/qpid/proton/_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ def __init__(self, mode: int) -> None:
826826
def _check(self, err: int) -> int:
827827
if err < 0:
828828
exc = EXCEPTIONS.get(err, SSLException)
829-
raise exc("SSL failure.")
829+
raise exc("SSL failure.", err)
830830
else:
831831
return err
832832

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,53 @@
11
from dataclasses import dataclass
2-
from typing import Optional
2+
from typing import Optional, Union
33

44

55
@dataclass
6-
class ClientCert:
6+
class PosixClientCert:
77
client_cert: str
88
client_key: str
99
password: Optional[str] = None
1010

11+
@dataclass
12+
class Unambiguous:
13+
"""Use the only certificate in the store."""
14+
...
15+
16+
@dataclass
17+
class FriendlyName:
18+
"""Use the first certificate with a matching friendly name."""
19+
name: str
20+
21+
22+
@dataclass
23+
class LocalMachineStore:
24+
name: str
25+
26+
27+
@dataclass
28+
class CurrentUserStore:
29+
name: str
30+
1131

1232
@dataclass
13-
class SslConfigurationContext:
33+
class PKCS12Store:
34+
path: str
35+
36+
37+
@dataclass
38+
class WinClientCert:
39+
store: Union[LocalMachineStore, CurrentUserStore, PKCS12Store]
40+
disambiguation_method: Union[Unambiguous, FriendlyName]
41+
password: Optional[str] = None
42+
43+
44+
@dataclass
45+
class PosixSslConfigurationContext:
1446
ca_cert: str
15-
client_cert: Optional[ClientCert] = None
47+
client_cert: Union[PosixClientCert, WinClientCert, None] = None
48+
49+
50+
@dataclass
51+
class WinSslConfigurationContext:
52+
ca_store: Union[LocalMachineStore, CurrentUserStore, PKCS12Store]
53+
client_cert: Union[PosixClientCert, WinClientCert, None] = None

tests/conftest.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
import os
2+
import sys
13
from typing import Optional
24

35
import pytest
46

57
from rabbitmq_amqp_python_client import (
68
AddressHelper,
79
AMQPMessagingHandler,
8-
ClientCert,
10+
PosixClientCert,
11+
WinClientCert,
912
Environment,
1013
Event,
11-
SslConfigurationContext,
14+
PosixSslConfigurationContext,
15+
WinSslConfigurationContext,
1216
symbol,
17+
PKCS12Store,
1318
)
19+
from rabbitmq_amqp_python_client.ssl_configuration import FriendlyName
1420

21+
os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
1522

1623
@pytest.fixture()
1724
def environment(pytestconfig):
@@ -34,19 +41,31 @@ def connection(pytestconfig):
3441
finally:
3542
environment.close()
3643

44+
@pytest.fixture()
45+
def ssl_context(pytestconfig):
46+
if sys.platform == "win32":
47+
return WinSslConfigurationContext(
48+
ca_store=PKCS12Store(path=".ci/certs/ca.p12"),
49+
client_cert=WinClientCert(
50+
store=PKCS12Store(path=".ci/certs/client.p12"),
51+
disambiguation_method=FriendlyName(name="gsantomagg6LVDM.vmware.com"),
52+
),
53+
)
54+
else:
55+
return PosixSslConfigurationContext(
56+
ca_cert=".ci/certs/ca_certificate.pem",
57+
client_cert=PosixClientCert(
58+
client_cert=".ci/certs/client_certificate.pem",
59+
client_key=".ci/certs/client_key.pem",
60+
),
61+
)
3762

3863
@pytest.fixture()
39-
def connection_ssl(pytestconfig):
40-
ca_cert_file = ".ci/certs/ca_certificate.pem"
41-
client_cert = ".ci/certs/client_certificate.pem"
42-
client_key = ".ci/certs/client_key.pem"
64+
def connection_ssl(pytestconfig, ssl_context):
4365

4466
environment = Environment(
4567
"amqps://guest:guest@localhost:5671/",
46-
ssl_context=SslConfigurationContext(
47-
ca_cert=ca_cert_file,
48-
client_cert=ClientCert(client_cert=client_cert, client_key=client_key),
49-
),
68+
ssl_context=ssl_context,
5069
)
5170
connection = environment.connection()
5271
connection.dial()

tests/test_connection.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import time
22

33
from rabbitmq_amqp_python_client import (
4-
ClientCert,
54
ConnectionClosed,
65
Environment,
7-
SslConfigurationContext,
86
StreamSpecification,
97
)
108

@@ -31,17 +29,10 @@ def test_environment_context_manager() -> None:
3129
connection.dial()
3230

3331

34-
def test_connection_ssl() -> None:
35-
ca_cert_file = ".ci/certs/ca_certificate.pem"
36-
client_cert = ".ci/certs/client_certificate.pem"
37-
client_key = ".ci/certs/client_key.pem"
38-
32+
def test_connection_ssl(ssl_context) -> None:
3933
environment = Environment(
4034
"amqps://guest:guest@localhost:5671/",
41-
ssl_context=SslConfigurationContext(
42-
ca_cert=ca_cert_file,
43-
client_cert=ClientCert(client_cert=client_cert, client_key=client_key),
44-
),
35+
ssl_context=ssl_context,
4536
)
4637

4738
connection = environment.connection()

0 commit comments

Comments
 (0)