python36/switch-to-PROTOCOL_TLS_CLIENT.patch
Matěj Cepl 023803647f
Update with all development for OpenSSL 3.1.4.
Especially:
- Add bpo31429_define-TLS-cipher-suite.patch, which defines the TLS
  cipher suite at build time (adds the --with-ssl-default-suites option
  to ./configure; from gh#python/cpython!3532).
2024-01-27 18:58:24 +01:00

1201 lines
60 KiB
Diff

---
Lib/test/test_ssl.py | 510 +++++++++++++++++++++++++++------------------------
1 file changed, 275 insertions(+), 235 deletions(-)
Index: Python-3.6.15/Lib/test/test_ssl.py
===================================================================
--- Python-3.6.15.orig/Lib/test/test_ssl.py
+++ Python-3.6.15/Lib/test/test_ssl.py
@@ -22,6 +22,7 @@ import weakref
import platform
import re
import functools
+import warnings
import sysconfig
try:
import ctypes
@@ -73,6 +74,7 @@ CRLFILE = data_file("revocation.crl")
# Two keys and certs signed by the same CA (for SNI tests)
SIGNED_CERTFILE = data_file("keycert3.pem")
+SIGNED_CERTFILE_HOSTNAME = 'localhost'
SIGNED_CERTFILE2 = data_file("keycert4.pem")
# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
@@ -205,15 +207,17 @@ def skip_if_openssl_cnf_minprotocol_gt_t
return f
-
needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *,
cert_reqs=ssl.CERT_NONE, ca_certs=None,
ciphers=None, certfile=None, keyfile=None,
+ check_hostname=None,
**kwargs):
context = ssl.SSLContext(ssl_version)
+ if check_hostname is not None:
+ context.check_hostname = check_hostname
if cert_reqs is not None:
context.verify_mode = cert_reqs
if ca_certs is not None:
@@ -224,6 +228,30 @@ def test_wrap_socket(sock, ssl_version=s
context.set_ciphers(ciphers)
return context.wrap_socket(sock, **kwargs)
+def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True):
+ """Create a TLS client/server SSLContext pair and the hostname chosen for server_cert.
+
+ client_context, server_context, hostname = testing_context()
+ """
+ if server_cert == SIGNED_CERTFILE:
+ hostname = SIGNED_CERTFILE_HOSTNAME
+ elif server_cert == SIGNED_CERTFILE2:
+ hostname = SIGNED_CERTFILE2_HOSTNAME
+ elif server_cert == NOSANFILE:
+ hostname = NOSAN_HOSTNAME
+ else:
+ raise ValueError(server_cert)
+
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ client_context.load_verify_locations(SIGNING_CA)
+
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ server_context.load_cert_chain(server_cert)
+ if server_chain:
+ server_context.load_verify_locations(SIGNING_CA)
+
+ return client_context, server_context, hostname
+
class BasicSocketTests(unittest.TestCase):
def test_constants(self):
@@ -466,7 +494,7 @@ class BasicSocketTests(unittest.TestCase
self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
(s, t, hex(n)))
else:
- self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
+ self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, patch)),
(s, t, hex(n)))
@support.cpython_only
@@ -545,7 +573,8 @@ class BasicSocketTests(unittest.TestCase
with self.assertRaises(ssl.SSLError):
test_wrap_socket(sock,
certfile=certfile,
- ssl_version=ssl.PROTOCOL_TLSv1)
+ check_hostname=False,
+ ssl_version=ssl.PROTOCOL_TLS_CLIENT)
def test_empty_cert(self):
"""Wrapping with an empty cert file"""
@@ -974,7 +1003,7 @@ class ContextTests(unittest.TestCase):
self.assertEqual(ctx.protocol, proto)
def test_ciphers(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.set_ciphers("ALL")
ctx.set_ciphers("DEFAULT")
with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
@@ -995,7 +1024,7 @@ class ContextTests(unittest.TestCase):
@unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old')
def test_get_ciphers(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.set_ciphers('AESGCM')
names = set(d['name'] for d in ctx.get_ciphers())
self.assertIn('AES256-GCM-SHA384', names)
@@ -1027,6 +1056,7 @@ class ContextTests(unittest.TestCase):
def test_verify_mode(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx.check_hostname = False
# Default value
self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
ctx.verify_mode = ssl.CERT_OPTIONAL
@@ -1043,7 +1073,7 @@ class ContextTests(unittest.TestCase):
@unittest.skipUnless(have_verify_flags(),
"verify_flags need OpenSSL > 0.9.8")
def test_verify_flags(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# default value
tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
@@ -1061,7 +1091,7 @@ class ContextTests(unittest.TestCase):
ctx.verify_flags = None
def test_load_cert_chain(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# Combined key and cert in a single file
ctx.load_cert_chain(CERTFILE, keyfile=None)
ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
@@ -1074,7 +1104,7 @@ class ContextTests(unittest.TestCase):
with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
ctx.load_cert_chain(EMPTYCERT)
# Separate key and cert
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_cert_chain(ONLYCERT, ONLYKEY)
ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
@@ -1085,7 +1115,7 @@ class ContextTests(unittest.TestCase):
with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
# Mismatching key and cert
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
# Password protected key and cert
@@ -1144,7 +1174,7 @@ class ContextTests(unittest.TestCase):
ctx.load_cert_chain(CERTFILE, password=getpass_exception)
def test_load_verify_locations(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_verify_locations(CERTFILE)
ctx.load_verify_locations(cafile=CERTFILE, capath=None)
ctx.load_verify_locations(BYTES_CERTFILE)
@@ -1172,7 +1202,7 @@ class ContextTests(unittest.TestCase):
neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)
# test PEM
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
ctx.load_verify_locations(cadata=cacert_pem)
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
@@ -1183,20 +1213,20 @@ class ContextTests(unittest.TestCase):
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
# combined
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
combined = "\n".join((cacert_pem, neuronio_pem))
ctx.load_verify_locations(cadata=combined)
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
# with junk around the certs
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
combined = ["head", cacert_pem, "other", neuronio_pem, "again",
neuronio_pem, "tail"]
ctx.load_verify_locations(cadata="\n".join(combined))
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
# test DER
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_verify_locations(cadata=cacert_der)
ctx.load_verify_locations(cadata=neuronio_der)
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
@@ -1205,13 +1235,13 @@ class ContextTests(unittest.TestCase):
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
# combined
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
combined = b"".join((cacert_der, neuronio_der))
ctx.load_verify_locations(cadata=combined)
self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
# error cases
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)
with self.assertRaisesRegex(
@@ -1229,7 +1259,7 @@ class ContextTests(unittest.TestCase):
def test_load_dh_params(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_dh_params(DHFILE)
if os.name != 'nt':
ctx.load_dh_params(BYTES_DHFILE)
@@ -1262,12 +1292,12 @@ class ContextTests(unittest.TestCase):
def test_set_default_verify_paths(self):
# There's not much we can do to test that it acts as expected,
# so just check it doesn't crash or raise an exception.
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.set_default_verify_paths()
@unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
def test_set_ecdh_curve(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.set_ecdh_curve("prime256v1")
ctx.set_ecdh_curve(b"prime256v1")
self.assertRaises(TypeError, ctx.set_ecdh_curve)
@@ -1277,7 +1307,8 @@ class ContextTests(unittest.TestCase):
@needs_sni
def test_sni_callback(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ ctx.check_hostname = False
# set_servername_callback expects a callable, or None
self.assertRaises(TypeError, ctx.set_servername_callback)
@@ -1294,7 +1325,7 @@ class ContextTests(unittest.TestCase):
def test_sni_callback_refcycle(self):
# Reference cycles through the servername callback are detected
# and cleared.
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
def dummycallback(sock, servername, ctx, cycle=ctx):
pass
ctx.set_servername_callback(dummycallback)
@@ -1304,7 +1335,7 @@ class ContextTests(unittest.TestCase):
self.assertIs(wr(), None)
def test_cert_store_stats(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats(),
{'x509_ca': 0, 'crl': 0, 'x509': 0})
ctx.load_cert_chain(CERTFILE)
@@ -1318,7 +1349,7 @@ class ContextTests(unittest.TestCase):
{'x509_ca': 1, 'crl': 0, 'x509': 2})
def test_get_ca_certs(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.get_ca_certs(), [])
# CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
ctx.load_verify_locations(CERTFILE)
@@ -1346,24 +1377,24 @@ class ContextTests(unittest.TestCase):
self.assertEqual(ctx.get_ca_certs(True), [der])
def test_load_default_certs(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_default_certs()
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
ctx.load_default_certs()
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertRaises(TypeError, ctx.load_default_certs, None)
self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
@unittest.skipIf(sys.platform == "win32", "not-Windows specific")
@unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
def test_load_default_certs_env(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
with support.EnvironmentVarGuard() as env:
env["SSL_CERT_DIR"] = CAPATH
env["SSL_CERT_FILE"] = CERTFILE
@@ -1372,11 +1403,11 @@ class ContextTests(unittest.TestCase):
@unittest.skipUnless(sys.platform == "win32", "Windows specific")
def test_load_default_certs_env_windows(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_default_certs()
stats = ctx.cert_store_stats()
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
with support.EnvironmentVarGuard() as env:
env["SSL_CERT_DIR"] = CAPATH
env["SSL_CERT_FILE"] = CERTFILE
@@ -1423,20 +1454,20 @@ class ContextTests(unittest.TestCase):
def test__create_stdlib_context(self):
ctx = ssl._create_stdlib_context()
- self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
+ self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
self.assertFalse(ctx.check_hostname)
self._assert_context_options(ctx)
- ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
- self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
- self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
+ ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLS_CLIENT)
+ self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
+ self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
self._assert_context_options(ctx)
- ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
+ ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLS_CLIENT,
cert_reqs=ssl.CERT_REQUIRED,
check_hostname=True)
- self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
+ self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
self.assertTrue(ctx.check_hostname)
self._assert_context_options(ctx)
@@ -1447,7 +1478,7 @@ class ContextTests(unittest.TestCase):
self._assert_context_options(ctx)
def test_check_hostname(self):
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
self.assertFalse(ctx.check_hostname)
# Requires CERT_REQUIRED or CERT_OPTIONAL
@@ -1494,7 +1525,7 @@ class SSLErrorTests(unittest.TestCase):
def test_lib_reason(self):
# Test the library and reason attributes
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
with self.assertRaises(ssl.SSLError) as cm:
ctx.load_dh_params(CERTFILE)
self.assertEqual(cm.exception.library, 'PEM')
@@ -1505,7 +1536,9 @@ class SSLErrorTests(unittest.TestCase):
def test_subclass(self):
# Check that the appropriate SSLError subclass is raised
# (this only tests one of them)
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ctx.check_hostname = False
+ ctx.verify_mode = ssl.CERT_NONE
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
s.listen()
@@ -1657,7 +1690,9 @@ class SimpleBackgroundTests(unittest.Tes
def test_connect_with_context(self):
# Same as test_connect, but with a separately created context
- ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ctx.check_hostname = False
+ ctx.verify_mode = ssl.CERT_NONE
with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
s.connect(self.server_addr)
self.assertEqual({}, s.getpeercert())
@@ -1805,10 +1840,12 @@ class SimpleBackgroundTests(unittest.Tes
@unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), "needs TLS 1.2")
def test_context_setget(self):
# Check that the context of a connected socket can be replaced.
- ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
- ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ctx1.load_verify_locations(capath=CAPATH)
+ ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ctx2.load_verify_locations(capath=CAPATH)
s = socket.socket(socket.AF_INET)
- with ctx1.wrap_socket(s) as ss:
+ with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
ss.connect(self.server_addr)
self.assertIs(ss.context, ctx1)
self.assertIs(ss._sslobj.context, ctx1)
@@ -2136,24 +2173,26 @@ if _have_threads:
certreqs=None, cacerts=None,
chatty=True, connectionchatty=False, starttls_server=False,
npn_protocols=None, alpn_protocols=None,
- ciphers=None, context=None):
+ ciphers=None, context=None, check_hostname=None):
if context:
self.context = context
else:
self.context = ssl.SSLContext(ssl_version
if ssl_version is not None
else ssl.PROTOCOL_TLS)
+ if check_hostname is not None:
+ self.context.check_hostname = check_hostname
self.context.verify_mode = (certreqs if certreqs is not None
else ssl.CERT_NONE)
- if cacerts:
+ if cacerts is not None:
self.context.load_verify_locations(cacerts)
- if certificate:
+ if certificate is not None:
self.context.load_cert_chain(certificate)
- if npn_protocols:
+ if npn_protocols is not None:
self.context.set_npn_protocols(npn_protocols)
- if alpn_protocols:
+ if alpn_protocols is not None:
self.context.set_alpn_protocols(alpn_protocols)
- if ciphers:
+ if ciphers is not None:
self.context.set_ciphers(ciphers)
self.chatty = chatty
self.connectionchatty = connectionchatty
@@ -2332,7 +2371,7 @@ if _have_threads:
def server_params_test(client_context, server_context, indata=b"FOO\n",
chatty=True, connectionchatty=False, sni_name=None,
- session=None):
+ session=None, check_hostname=None):
"""
Launch a server, connect a client to it and try various reads
and writes.
@@ -2340,7 +2379,8 @@ if _have_threads:
stats = {}
server = ThreadedEchoServer(context=server_context,
chatty=chatty,
- connectionchatty=False)
+ connectionchatty=False,
+ check_hostname=check_hostname)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=sni_name, session=session) as s:
@@ -2402,18 +2442,24 @@ if _have_threads:
(ssl.get_protocol_name(client_protocol),
ssl.get_protocol_name(server_protocol),
certtype))
- client_context = ssl.SSLContext(client_protocol)
- client_context.options |= client_options
- server_context = ssl.SSLContext(server_protocol)
- server_context.options |= server_options
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ client_context = ssl.SSLContext(client_protocol)
+ client_context.options |= client_options
+ server_context = ssl.SSLContext(server_protocol)
+ server_context.options |= server_options
+ if certsreqs == ssl.CERT_NONE:
+ server_context.check_hostname = False
# NOTE: we must enable "ALL" ciphers on the client, otherwise an
# SSLv23 client will send an SSLv3 hello (rather than SSLv2)
# starting from OpenSSL 1.0.0 (see issue #8322).
- if client_context.protocol == ssl.PROTOCOL_SSLv23:
+ if client_context.protocol == ssl.PROTOCOL_TLS:
client_context.set_ciphers("ALL")
for ctx in (client_context, server_context):
+ ctx.check_hostname = False
ctx.verify_mode = certsreqs
ctx.load_cert_chain(CERTFILE)
ctx.load_verify_locations(CERTFILE)
@@ -2447,26 +2493,14 @@ if _have_threads:
"""Basic test of an SSL client connecting to a server"""
if support.verbose:
sys.stdout.write("\n")
- for protocol in PROTOCOLS:
- if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
- continue
- with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
- context = ssl.SSLContext(protocol)
- context.load_cert_chain(CERTFILE)
- server_params_test(context, context,
- chatty=True, connectionchatty=True)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
- client_context.load_verify_locations(SIGNING_CA)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- # server_context.load_verify_locations(SIGNING_CA)
- server_context.load_cert_chain(SIGNED_CERTFILE2)
+ client_context, server_context, hostname = testing_context()
with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
server_params_test(client_context=client_context,
server_context=server_context,
chatty=True, connectionchatty=True,
- sni_name='fakehostname')
+ sni_name=hostname)
client_context.check_hostname = False
with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
@@ -2474,9 +2508,8 @@ if _have_threads:
server_params_test(client_context=server_context,
server_context=client_context,
chatty=True, connectionchatty=True,
- sni_name='fakehostname')
- self.assertIn('called a function you should not call',
- str(e.exception))
+ sni_name=hostname)
+ self.assertIn('called a function you should not call', str(e.exception))
with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
with self.assertRaises(ssl.SSLError) as e:
@@ -2539,39 +2572,37 @@ if _have_threads:
if support.verbose:
sys.stdout.write("\n")
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
-
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(SIGNING_CA)
+ client_context, server_context, hostname = testing_context()
tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
- self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
+ self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
# VERIFY_DEFAULT should pass
server = ThreadedEchoServer(context=server_context, chatty=True)
with server:
- with context.wrap_socket(socket.socket()) as s:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
s.connect((HOST, server.port))
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
# VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
- context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
+ client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
server = ThreadedEchoServer(context=server_context, chatty=True)
with server:
- with context.wrap_socket(socket.socket()) as s:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
with self.assertRaisesRegex(ssl.SSLError,
"certificate verify failed"):
s.connect((HOST, server.port))
# now load a CRL file. The CRL file is signed by the CA.
- context.load_verify_locations(CRLFILE)
+ client_context.load_verify_locations(CRLFILE)
server = ThreadedEchoServer(context=server_context, chatty=True)
with server:
- with context.wrap_socket(socket.socket()) as s:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
s.connect((HOST, server.port))
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
@@ -2580,10 +2611,10 @@ if _have_threads:
if support.verbose:
sys.stdout.write("\n")
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(SIGNED_CERTFILE)
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = True
context.load_verify_locations(SIGNING_CA)
@@ -2619,17 +2650,19 @@ if _have_threads:
)
def test_hostname_checks_common_name(self):
client_context, server_context, hostname = testing_context()
- assert client_context.hostname_checks_common_name
+ if not hasattr(client_context, 'hostname_checks_common_name'):
+ raise unittest.SkipTest('test requires SSLContext having hostname_checks_common_name function.')
+
client_context.hostname_checks_common_name = False
- # default cert has a SAN
+ # default cert has a SAN
server = ThreadedEchoServer(context=server_context, chatty=True)
with server:
with client_context.wrap_socket(socket.socket(),
server_hostname=hostname) as s:
s.connect((HOST, server.port))
- client_context, server_context, hostname = testing_context(NOSANFILE)
+ client_context, server_context, hostname = testing_context(NOSANFILE)
client_context.hostname_checks_common_name = False
server = ThreadedEchoServer(context=server_context, chatty=True)
with server:
@@ -2717,29 +2750,31 @@ if _have_threads:
"OpenSSL is compiled without SSLv2 support")
def test_protocol_sslv2(self):
"""Connecting to an SSLv2 server with various client options"""
+ hostname='localhost'
if support.verbose:
sys.stdout.write("\n")
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False)
# SSLv23 client with specific SSL options
if no_sslv2_implies_sslv3_hello():
# No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False,
client_options=ssl.OP_NO_SSLv2)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False,
client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False,
client_options=ssl.OP_NO_TLSv1)
@skip_if_broken_ubuntu_ssl
@skip_if_openssl_cnf_minprotocol_gt_tls1
def test_protocol_sslv23(self):
"""Connecting to an SSLv23 server with various client options"""
+ hostname='localhost'
if support.verbose:
sys.stdout.write("\n")
if hasattr(ssl, 'PROTOCOL_SSLv2'):
@@ -2751,20 +2786,23 @@ if _have_threads:
sys.stdout.write(
" SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
% str(x))
+
+
+
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3')
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3', ssl.CERT_OPTIONAL)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3', ssl.CERT_REQUIRED)
# Server with specific SSL options
if hasattr(ssl, 'PROTOCOL_SSLv3'):
@@ -2774,7 +2812,7 @@ if _have_threads:
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
- server_options=ssl.OP_NO_TLSv1)
+ server_options=ssl.OP_NO_TLSv1_3)
@skip_if_broken_ubuntu_ssl
@@ -2791,13 +2829,15 @@ if _have_threads:
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS_CLIENT, False)
if no_sslv2_implies_sslv3_hello():
# No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
False, client_options=ssl.OP_NO_SSLv2)
@skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
+ "TLS version 1.1 not supported.")
def test_protocol_tlsv1(self):
"""Connecting to a TLSv1 server with various client options"""
if support.verbose:
@@ -2809,7 +2849,7 @@ if _have_threads:
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
client_options=ssl.OP_NO_TLSv1)
@skip_if_broken_ubuntu_ssl
@@ -2826,12 +2866,12 @@ if _have_threads:
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
client_options=ssl.OP_NO_TLSv1_1)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
@skip_if_broken_ubuntu_ssl
@@ -2849,24 +2889,27 @@ if _have_threads:
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
if hasattr(ssl, 'PROTOCOL_SSLv3'):
try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
client_options=ssl.OP_NO_TLSv1_2)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
+ if hasattr(ssl, "PROTOCOL_TLSv1"):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
+ if hasattr(ssl, "PROTOCOL_TLSv1_1"):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
def test_starttls(self):
"""Switching from clear text to encrypted and back again."""
msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
server = ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_TLSv1,
+ ssl_version=ssl.PROTOCOL_TLS_SERVER,
starttls_server=True,
chatty=True,
- connectionchatty=True)
+ connectionchatty=True,
+ check_hostname=False)
wrapped = False
with server:
s = socket.socket()
@@ -2891,7 +2934,8 @@ if _have_threads:
sys.stdout.write(
" client: read %r from server, starting TLS...\n"
% msg)
- conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
+ conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLS_CLIENT,
+ check_hostname=False)
wrapped = True
elif indata == b"ENDTLS" and msg.startswith(b"ok"):
# ENDTLS ok, switch back to clear text
@@ -2918,17 +2962,17 @@ if _have_threads:
def test_socketserver(self):
"""Using socketserver to create and manage SSL connections."""
- server = make_https_server(self, certfile=CERTFILE)
+ server = make_https_server(self, certfile=SIGNED_CERTFILE)
# try to connect
if support.verbose:
sys.stdout.write('\n')
- with open(CERTFILE, 'rb') as f:
+ with open(__file__, 'rb') as f:
d1 = f.read()
d2 = ''
# now fetch the same data from the HTTPS server
url = 'https://localhost:%d/%s' % (
- server.port, os.path.split(CERTFILE)[1])
- context = ssl.create_default_context(cafile=CERTFILE)
+ server.port, os.path.split(__file__)[1])
+ context = ssl.create_default_context(cafile=SIGNING_CA)
f = urllib.request.urlopen(url, context=context)
try:
dlen = f.info().get("content-length")
@@ -2978,7 +3022,7 @@ if _have_threads:
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
+ ssl_version=ssl.PROTOCOL_TLS_SERVER,
cacerts=CERTFILE,
chatty=True,
connectionchatty=False)
@@ -2987,8 +3031,7 @@ if _have_threads:
server_side=False,
certfile=CERTFILE,
ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
+ cert_reqs=ssl.CERT_NONE)
s.connect((HOST, server.port))
# helper methods for standardising recv* method signatures
def _recv_into():
@@ -3130,32 +3173,30 @@ if _have_threads:
def test_nonblocking_send(self):
server = ThreadedEchoServer(CERTFILE,
certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
+ ssl_version=ssl.PROTOCOL_TLS_SERVER,
cacerts=CERTFILE,
chatty=True,
connectionchatty=False)
with server:
- s = test_wrap_socket(socket.socket(),
+ with test_wrap_socket(socket.socket(),
server_side=False,
certfile=CERTFILE,
ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- s.setblocking(False)
+ cert_reqs=ssl.CERT_NONE) as s:
+ s.connect((HOST, server.port))
+ s.setblocking(False)
- # If we keep sending data, at some point the buffers
- # will be full and the call will block
- buf = bytearray(8192)
- def fill_buffer():
- while True:
- s.send(buf)
- self.assertRaises((ssl.SSLWantWriteError,
- ssl.SSLWantReadError), fill_buffer)
+ # If we keep sending data, at some point the buffers
+ # will be full and the call will block
+ buf = bytearray(8192)
+ def fill_buffer():
+ while True:
+ s.send(buf)
+ self.assertRaises((ssl.SSLWantWriteError,
+ ssl.SSLWantReadError), fill_buffer)
- # Now read all the output and discard it
- s.setblocking(True)
- s.close()
+ # Now read all the output and discard it
+ s.setblocking(True)
def test_handshake_timeout(self):
# Issue #5103: SSL handshake must respect the socket timeout
@@ -3286,14 +3327,17 @@ if _have_threads:
Basic tests for SSLSocket.version().
More tests are done in the test_protocol_*() methods.
"""
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ context.check_hostname = False
+ context.verify_mode = ssl.CERT_NONE
with ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_TLSv1,
+ ssl_version=ssl.PROTOCOL_TLS_SERVER,
chatty=False) as server:
with context.wrap_socket(socket.socket()) as s:
self.assertIs(s.version(), None)
s.connect((HOST, server.port))
- self.assertEqual(s.version(), 'TLSv1')
+ self.assertEqual(s.version(), 'TLSv1.3' if ssl.HAS_TLSv1_3 else 'TLSv1.2') # highest version the build can negotiate
+ self.assertIs(s._sslobj, None)
self.assertIs(s.version(), None)
@unittest.skipUnless(ssl.HAS_TLSv1_3,
@@ -3306,7 +3350,7 @@ if _have_threads:
ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
)
with ThreadedEchoServer(context=context) as server:
- with context.wrap_socket(socket.socket()) as s:
+ with context.wrap_socket(socket.socket(), server_hostname='localhost') as s:
s.connect((HOST, server.port))
self.assertIn(s.cipher()[0], [
'TLS_AES_256_GCM_SHA384',
@@ -3342,64 +3386,59 @@ if _have_threads:
if support.verbose:
sys.stdout.write("\n")
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
+ client_context, server_context, hostname = testing_context()
+
+ server = ThreadedEchoServer(context=server_context,
chatty=True,
connectionchatty=False)
with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- # get the data
- cb_data = s.get_channel_binding("tls-unique")
- if support.verbose:
- sys.stdout.write(" got channel binding data: {0!r}\n"
- .format(cb_data))
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ # get the data
+ cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got channel binding data: {0!r}\n"
+ .format(cb_data))
- # check if it is sane
- self.assertIsNotNone(cb_data)
- self.assertEqual(len(cb_data), 12) # True for TLSv1
-
- # and compare with the peers version
- s.write(b"CB tls-unique\n")
- peer_data_repr = s.read().strip()
- self.assertEqual(peer_data_repr,
- repr(cb_data).encode("us-ascii"))
- s.close()
+ # check if it is sane
+ self.assertIsNotNone(cb_data)
+ if s.version() == 'TLSv1.3':
+ self.assertEqual(len(cb_data), 48)
+ else:
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+
+ # and compare with the peers version
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(cb_data).encode("us-ascii"))
# now, again
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- new_cb_data = s.get_channel_binding("tls-unique")
- if support.verbose:
- sys.stdout.write(" got another channel binding data: {0!r}\n"
- .format(new_cb_data))
- # is it really unique
- self.assertNotEqual(cb_data, new_cb_data)
- self.assertIsNotNone(cb_data)
- self.assertEqual(len(cb_data), 12) # True for TLSv1
- s.write(b"CB tls-unique\n")
- peer_data_repr = s.read().strip()
- self.assertEqual(peer_data_repr,
- repr(new_cb_data).encode("us-ascii"))
- s.close()
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ new_cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got another channel binding data: {0!r}\n"
+ .format(new_cb_data))
+ # is it really unique
+ self.assertNotEqual(cb_data, new_cb_data)
+ self.assertIsNotNone(new_cb_data) # validate this connection's value, not the first one's
+ if s.version() == 'TLSv1.3':
+ self.assertEqual(len(new_cb_data), 48)
+ else:
+ self.assertEqual(len(new_cb_data), 12) # True for TLSv1
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(new_cb_data).encode("us-ascii"))
def test_compression(self):
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
+ client_context, server_context, hostname = testing_context()
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True,
+ sni_name=hostname, check_hostname=False)
if support.verbose:
sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
@@ -3407,44 +3446,45 @@ if _have_threads:
@unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
"ssl.OP_NO_COMPRESSION needed for this test")
def test_compression_disabled(self):
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- context.options |= ssl.OP_NO_COMPRESSION
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
+ client_context, server_context, hostname = testing_context()
+ client_context.options |= ssl.OP_NO_COMPRESSION
+ server_context.options |= ssl.OP_NO_COMPRESSION
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True,
+ sni_name=hostname, check_hostname=False)
self.assertIs(stats['compression'], None)
def test_dh_params(self):
# Check we can get a connection with ephemeral Diffie-Hellman
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- context.load_dh_params(DHFILE)
- context.set_ciphers("kEDH")
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
+ client_context, server_context, hostname = testing_context()
+ client_context.options |= ssl.OP_NO_TLSv1_3 # kEDH suites only exist for TLS <= 1.2
+ server_context.load_dh_params(DHFILE) # DH parameters are consumed by the server side
+ client_context.set_ciphers("kEDH") # client offers DHE only, forcing a DH key exchange
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True,
+ sni_name=hostname)
cipher = stats["cipher"][0]
parts = cipher.split("-")
if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
- self.fail("Non-DH cipher: " + cipher[0])
+ self.fail(f"Non-DH cipher: {cipher} (parts: {parts})")
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
def test_selected_alpn_protocol(self):
# selected_alpn_protocol() is None unless ALPN is used.
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
+ client_context, server_context, hostname = testing_context()
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True,
+ sni_name=hostname)
self.assertIs(stats['client_alpn_protocol'], None)
@unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
def test_selected_alpn_protocol_if_server_uses_alpn(self):
# selected_alpn_protocol() is None unless ALPN is used by the client.
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.load_verify_locations(CERTFILE)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(CERTFILE)
+ client_context, server_context, hostname = testing_context()
server_context.set_alpn_protocols(['foo', 'bar'])
stats = server_params_test(client_context, server_context,
- chatty=True, connectionchatty=True)
+ chatty=True, connectionchatty=True,
+ sni_name=hostname)
self.assertIs(stats['client_alpn_protocol'], None)
@unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
@@ -3458,10 +3498,10 @@ if _have_threads:
(['http/3.0', 'http/4.0'], None)
]
for client_protocols, expected in protocol_tests:
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(CERTFILE)
server_context.set_alpn_protocols(server_protocols)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.load_cert_chain(CERTFILE)
client_context.set_alpn_protocols(client_protocols)
@@ -3490,9 +3530,10 @@ if _have_threads:
self.assertEqual(server_result, expected,
msg % (server_result, "server"))
+ @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
def test_selected_npn_protocol(self):
# selected_npn_protocol() is None unless NPN is used
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
+ client_context, server_context, hostname = testing_context()
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True,
+ sni_name=hostname)
@@ -3508,10 +3549,10 @@ if _have_threads:
(['abc', 'def'], 'abc')
]
for client_protocols, expected in protocol_tests:
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(CERTFILE)
server_context.set_npn_protocols(server_protocols)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.load_cert_chain(CERTFILE)
client_context.set_npn_protocols(client_protocols)
stats = server_params_test(client_context, server_context,
@@ -3528,12 +3569,11 @@ if _have_threads:
self.assertEqual(server_result, expected, msg % (server_result, "server"))
def sni_contexts(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(SIGNED_CERTFILE)
- other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
other_context.load_cert_chain(SIGNED_CERTFILE2)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.load_verify_locations(SIGNING_CA)
return server_context, other_context, client_context
@@ -3567,7 +3607,7 @@ if _have_threads:
chatty=True,
sni_name=None)
self.assertEqual(calls, [(None, server_context)])
- self.check_common_name(stats, 'localhost')
+ self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
# Check disabling the callback
calls = []
@@ -3577,7 +3617,7 @@ if _have_threads:
chatty=True,
sni_name='notfunny')
# Certificate didn't change
- self.check_common_name(stats, 'localhost')
+ self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
self.assertEqual(calls, [])
@skip_if_OpenSSL30
@@ -3635,9 +3675,9 @@ if _have_threads:
@unittest.skipIf(IS_OPENSSL_1_1_1, "bpo-36576: fail on OpenSSL 1.1.1")
def test_shared_ciphers(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(SIGNED_CERTFILE)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.verify_mode = ssl.CERT_REQUIRED
client_context.load_verify_locations(SIGNING_CA)
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
@@ -3697,14 +3737,11 @@ if _have_threads:
self.assertEqual(s.recv(1024), TEST_DATA)
def test_session(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
+ client_context, server_context, hostname = testing_context()
# first connection without session
- stats = server_params_test(client_context, server_context)
+ stats = server_params_test(client_context, server_context,
+ sni_name=hostname)
session = stats['session']
self.assertTrue(session.id)
self.assertGreater(session.time, 0)
@@ -3718,7 +3755,8 @@ if _have_threads:
self.assertEqual(sess_stat['hits'], 0)
# reuse session
- stats = server_params_test(client_context, server_context, session=session)
+ stats = server_params_test(client_context, server_context,
+ session=session, sni_name=hostname)
sess_stat = server_context.session_stats()
self.assertEqual(sess_stat['accept'], 2)
self.assertEqual(sess_stat['hits'], 1)
@@ -3731,7 +3769,8 @@ if _have_threads:
self.assertGreaterEqual(session2.timeout, session.timeout)
# another one without session
- stats = server_params_test(client_context, server_context)
+ stats = server_params_test(client_context, server_context,
+ sni_name=hostname)
self.assertFalse(stats['session_reused'])
session3 = stats['session']
self.assertNotEqual(session3.id, session.id)
@@ -3741,7 +3780,8 @@ if _have_threads:
self.assertEqual(sess_stat['hits'], 1)
# reuse session again
- stats = server_params_test(client_context, server_context, session=session)
+ stats = server_params_test(client_context, server_context, session=session,
+ sni_name=hostname)
self.assertTrue(stats['session_reused'])
session4 = stats['session']
self.assertEqual(session4.id, session.id)