Skip to content
This repository was archived by the owner on Jan 13, 2021. It is now read-only.

Commit a4f185e

Browse files
committed
Merge branch 'development' of github.com:Lukasa/hyper into development
2 parents 9766ad0 + 806aff7 commit a4f185e

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

Diff for: hyper/common/exceptions.py

+7
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,10 @@ def __init__(self, negotiated, sock):
6464
super(HTTPUpgrade, self).__init__()
6565
self.negotiated = negotiated
6666
self.sock = sock
67+
68+
69+
class MissingCertFile(Exception):
70+
"""
71+
The certificate file could not be found.
72+
"""
73+
pass

Diff for: hyper/tls.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
Contains the TLS/SSL logic for use in hyper.
77
"""
88
import os.path as path
9-
9+
from .common.exceptions import MissingCertFile
1010
from .compat import ignore_missing, ssl
1111

1212

@@ -29,14 +29,17 @@ def wrap_socket(sock, server_hostname, ssl_context=None, force_proto=None):
2929
A vastly simplified SSL wrapping function. We'll probably extend this to
3030
do more things later.
3131
"""
32-
global _context
3332

34-
# create the singleton SSLContext we use
35-
if _context is None: # pragma: no cover
36-
_context = init_context()
33+
global _context
3734

38-
# if an SSLContext is provided then use it instead of default context
39-
_ssl_context = ssl_context or _context
35+
if ssl_context:
36+
# if an SSLContext is provided then use it instead of default context
37+
_ssl_context = ssl_context
38+
else:
39+
# create the singleton SSLContext we use
40+
if _context is None: # pragma: no cover
41+
_context = init_context()
42+
_ssl_context = _context
4043

4144
# the spec requires SNI support
4245
ssl_sock = _ssl_context.wrap_socket(sock, server_hostname=server_hostname)
@@ -94,9 +97,17 @@ def init_context(cert_path=None, cert=None, cert_password=None):
9497
encrypted and no password is needed.
9598
:returns: An ``SSLContext`` correctly set up for HTTP/2.
9699
"""
100+
cafile = cert_path or cert_loc
101+
if not cafile or not path.exists(cafile):
102+
err_msg = ("No certificate found at " + str(cafile) + ". Either " +
103+
"ensure the default cert.pem file is included in the " +
104+
"distribution or provide a custom certificate when " +
105+
"creating the connection.")
106+
raise MissingCertFile(err_msg)
107+
97108
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
98109
context.set_default_verify_paths()
99-
context.load_verify_locations(cafile=cert_path or cert_loc)
110+
context.load_verify_locations(cafile=cafile)
100111
context.verify_mode = ssl.CERT_REQUIRED
101112
context.check_hostname = True
102113

Diff for: test/test_SSLContext.py

+15
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CLIENT_CERT_FILE = os.path.join(TEST_CERTS_DIR, 'client.crt')
1515
CLIENT_KEY_FILE = os.path.join(TEST_CERTS_DIR, 'client.key')
1616
CLIENT_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'nopassword.pem')
17+
MISSING_PEM_FILE = os.path.join(TEST_CERTS_DIR, 'missing.pem')
1718

1819

1920
class TestSSLContext(object):
@@ -60,3 +61,17 @@ def test_client_certificates(self):
6061
cert=(CLIENT_CERT_FILE, CLIENT_KEY_FILE),
6162
cert_password=b'abc123')
6263
hyper.tls.init_context(cert=CLIENT_PEM_FILE)
64+
65+
def test_missing_certs(self):
66+
succeeded = False
67+
threw_expected_exception = False
68+
try:
69+
hyper.tls.init_context(MISSING_PEM_FILE)
70+
succeeded = True
71+
except hyper.common.exceptions.MissingCertFile:
72+
threw_expected_exception = True
73+
except:
74+
pass
75+
76+
assert not succeeded
77+
assert threw_expected_exception

0 commit comments

Comments
 (0)