diff --git a/python36.changes b/python36.changes index aed3415..52c1a07 100644 --- a/python36.changes +++ b/python36.changes @@ -26,6 +26,8 @@ Thu Jan 11 15:14:09 UTC 2024 - Matej Cepl (from gh#python/cpython!100373/files) to stopping SSLContext.load_verify_locations from accepting some cases of trailing data in DER. +- Add switch-to-PROTOCOL_TLS_CLIENT.patch switching to + PROTOCOL_TLS_CLIENT for testing. ------------------------------------------------------------------- Mon Sep 11 06:28:43 UTC 2023 - Daniel Garcia diff --git a/python36.spec b/python36.spec index d2aea2a..2c2c103 100644 --- a/python36.spec +++ b/python36.spec @@ -257,6 +257,9 @@ Patch63: bpo43920-fix-load_verify_locations-errmsgs.patch # PATCH-FIX-UPSTREAM gh100372-SSLContext_load_verify_locations-trailing-data.patch bsc#1217782 mcepl@suse.com # SSLContext.load_verify_locations stop accepting some cases of trailing data in DER (from gh#python/cpython!100373) Patch64: gh100372-SSLContext_load_verify_locations-trailing-data.patch +# PATCH-FIX-UPSTREAM switch-to-PROTOCOL_TLS_CLIENT.patch bsc#1217782 mcepl@suse.com +# switching to PROTOCOL_TLS settings for testing +Patch65: switch-to-PROTOCOL_TLS_CLIENT.patch BuildRequires: automake BuildRequires: fdupes BuildRequires: gmp-devel @@ -560,6 +563,7 @@ other applications. %patch -P 62 -p1 %patch -P 63 -p1 %patch -P 64 -p1 +%patch -P 65 -p1 # drop Autoconf version requirement sed -i 's/^AC_PREREQ/dnl AC_PREREQ/' configure.ac diff --git a/switch-to-PROTOCOL_TLS_CLIENT.patch b/switch-to-PROTOCOL_TLS_CLIENT.patch new file mode 100644 index 0000000..3dc1a03 --- /dev/null +++ b/switch-to-PROTOCOL_TLS_CLIENT.patch @@ -0,0 +1,1199 @@ +--- + Lib/test/test_ssl.py | 516 +++++++++++++++++++++++++++------------------------ + 1 file changed, 282 insertions(+), 234 deletions(-) + +Index: Python-3.6.15/Lib/test/test_ssl.py +=================================================================== +--- Python-3.6.15.orig/Lib/test/test_ssl.py ++++ Python-3.6.15/Lib/test/test_ssl.py +@@ -22,6 +22,7 @@ import weakref + import platform + import re + import functools ++import warnings + try: + import ctypes + except ImportError: +@@ -71,6 +72,7 @@ CRLFILE = data_file("revocation.crl") + + # Two keys and certs signed by the same CA (for SNI tests) + SIGNED_CERTFILE = data_file("keycert3.pem") ++SIGNED_CERTFILE_HOSTNAME = 'localhost' + SIGNED_CERTFILE2 = data_file("keycert4.pem") + # Same certificate as pycacert.pem, but without extra text in file + SIGNING_CA = data_file("capath", "ceff1710.0") +@@ -203,15 +205,17 @@ def skip_if_openssl_cnf_minprotocol_gt_t + return f + + +- + needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test") + + + def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *, + cert_reqs=ssl.CERT_NONE, ca_certs=None, + ciphers=None, certfile=None, keyfile=None, ++ check_hostname=None, + **kwargs): + context = ssl.SSLContext(ssl_version) ++ if check_hostname is not None: ++ context.check_hostname = check_hostname + if cert_reqs is not None: + context.verify_mode = cert_reqs + if ca_certs is not None: +@@ -222,6 +226,30 @@ def test_wrap_socket(sock, ssl_version=s + context.set_ciphers(ciphers) + return context.wrap_socket(sock, **kwargs) + ++def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True): ++ """Create context ++ ++ client_context, server_context, hostname = testing_context() ++ """ ++ if server_cert == SIGNED_CERTFILE: ++ hostname = SIGNED_CERTFILE_HOSTNAME ++ elif server_cert == SIGNED_CERTFILE2: ++ hostname = SIGNED_CERTFILE2_HOSTNAME ++ elif server_cert == NOSANFILE: ++ hostname = NOSAN_HOSTNAME ++ else: ++ raise ValueError(server_cert) ++ ++ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ++ client_context.load_verify_locations(SIGNING_CA) ++ ++ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ++ server_context.load_cert_chain(server_cert) ++ if server_chain: ++ server_context.load_verify_locations(SIGNING_CA) ++ ++ return client_context, server_context, hostname ++ + class BasicSocketTests(unittest.TestCase): + + def test_constants(self): +@@ -543,7 +571,8 @@ class BasicSocketTests(unittest.TestCase + with self.assertRaises(ssl.SSLError): + test_wrap_socket(sock, + certfile=certfile, +- ssl_version=ssl.PROTOCOL_TLSv1) ++ check_hostname=False, ++ ssl_version=ssl.PROTOCOL_TLS_CLIENT) + + def test_empty_cert(self): + """Wrapping with an empty cert file""" +@@ -972,7 +1001,7 @@ class ContextTests(unittest.TestCase): + self.assertEqual(ctx.protocol, proto) + + def test_ciphers(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.set_ciphers("ALL") + ctx.set_ciphers("DEFAULT") + with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"): +@@ -980,7 +1009,7 @@ class ContextTests(unittest.TestCase): + + @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old') + def test_get_ciphers(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.set_ciphers('AESGCM') + names = set(d['name'] for d in ctx.get_ciphers()) + self.assertIn('AES256-GCM-SHA384', names) +@@ -1012,6 +1041,7 @@ class ContextTests(unittest.TestCase): + + def test_verify_mode(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx.check_hostname = False + # Default value + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + ctx.verify_mode = ssl.CERT_OPTIONAL +@@ -1028,7 +1058,7 @@ class ContextTests(unittest.TestCase): + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_verify_flags(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + # default value + tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) + self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf) +@@ -1046,7 +1076,7 @@ class ContextTests(unittest.TestCase): + ctx.verify_flags = None + + def test_load_cert_chain(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + # Combined key and cert in a single file + ctx.load_cert_chain(CERTFILE, keyfile=None) + ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) +@@ -1059,7 +1089,7 @@ class ContextTests(unittest.TestCase): + with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): + ctx.load_cert_chain(EMPTYCERT) + # Separate key and cert +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_cert_chain(ONLYCERT, ONLYKEY) + ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY) + ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY) +@@ -1070,7 +1100,7 @@ class ContextTests(unittest.TestCase): + with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): + ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT) + # Mismatching key and cert +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"): + ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY) + # Password protected key and cert +@@ -1129,7 +1159,7 @@ class ContextTests(unittest.TestCase): + ctx.load_cert_chain(CERTFILE, password=getpass_exception) + + def test_load_verify_locations(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(CERTFILE) + ctx.load_verify_locations(cafile=CERTFILE, capath=None) + ctx.load_verify_locations(BYTES_CERTFILE) +@@ -1157,7 +1187,7 @@ class ContextTests(unittest.TestCase): + neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem) + + # test PEM +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0) + ctx.load_verify_locations(cadata=cacert_pem) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1) +@@ -1168,20 +1198,20 @@ class ContextTests(unittest.TestCase): + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # combined +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + combined = "\n".join((cacert_pem, neuronio_pem)) + ctx.load_verify_locations(cadata=combined) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # with junk around the certs +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + combined = ["head", cacert_pem, "other", neuronio_pem, "again", + neuronio_pem, "tail"] + ctx.load_verify_locations(cadata="\n".join(combined)) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # test DER +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(cadata=cacert_der) + ctx.load_verify_locations(cadata=neuronio_der) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) +@@ -1190,13 +1220,13 @@ class ContextTests(unittest.TestCase): + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # combined +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + combined = b"".join((cacert_der, neuronio_der)) + ctx.load_verify_locations(cadata=combined) + self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) + + # error cases +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object) + + with self.assertRaisesRegex( +@@ -1214,7 +1244,7 @@ class ContextTests(unittest.TestCase): + + + def test_load_dh_params(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_dh_params(DHFILE) + if os.name != 'nt': + ctx.load_dh_params(BYTES_DHFILE) +@@ -1247,12 +1277,12 @@ class ContextTests(unittest.TestCase): + def test_set_default_verify_paths(self): + # There's not much we can do to test that it acts as expected, + # so just check it doesn't crash or raise an exception. +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.set_default_verify_paths() + + @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build") + def test_set_ecdh_curve(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.set_ecdh_curve("prime256v1") + ctx.set_ecdh_curve(b"prime256v1") + self.assertRaises(TypeError, ctx.set_ecdh_curve) +@@ -1262,7 +1292,8 @@ class ContextTests(unittest.TestCase): + + @needs_sni + def test_sni_callback(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ++ ctx.check_hostname = False + + # set_servername_callback expects a callable, or None + self.assertRaises(TypeError, ctx.set_servername_callback) +@@ -1279,7 +1310,7 @@ class ContextTests(unittest.TestCase): + def test_sni_callback_refcycle(self): + # Reference cycles through the servername callback are detected + # and cleared. +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + def dummycallback(sock, servername, ctx, cycle=ctx): + pass + ctx.set_servername_callback(dummycallback) +@@ -1289,7 +1320,7 @@ class ContextTests(unittest.TestCase): + self.assertIs(wr(), None) + + def test_cert_store_stats(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.cert_store_stats(), + {'x509_ca': 0, 'crl': 0, 'x509': 0}) + ctx.load_cert_chain(CERTFILE) +@@ -1303,7 +1334,7 @@ class ContextTests(unittest.TestCase): + {'x509_ca': 1, 'crl': 0, 'x509': 2}) + + def test_get_ca_certs(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.get_ca_certs(), []) + # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE + ctx.load_verify_locations(CERTFILE) +@@ -1331,24 +1362,24 @@ class ContextTests(unittest.TestCase): + self.assertEqual(ctx.get_ca_certs(True), [der]) + + def test_load_default_certs(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_default_certs() + +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) + ctx.load_default_certs() + +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH) + +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertRaises(TypeError, ctx.load_default_certs, None) + self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') + + @unittest.skipIf(sys.platform == "win32", "not-Windows specific") + @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars") + def test_load_default_certs_env(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + with support.EnvironmentVarGuard() as env: + env["SSL_CERT_DIR"] = CAPATH + env["SSL_CERT_FILE"] = CERTFILE +@@ -1357,11 +1388,11 @@ class ContextTests(unittest.TestCase): + + @unittest.skipUnless(sys.platform == "win32", "Windows specific") + def test_load_default_certs_env_windows(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_default_certs() + stats = ctx.cert_store_stats() + +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + with support.EnvironmentVarGuard() as env: + env["SSL_CERT_DIR"] = CAPATH + env["SSL_CERT_FILE"] = CERTFILE +@@ -1408,20 +1439,20 @@ class ContextTests(unittest.TestCase): + + def test__create_stdlib_context(self): + ctx = ssl._create_stdlib_context() +- self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) ++ self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) + self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + self.assertFalse(ctx.check_hostname) + self._assert_context_options(ctx) + +- ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1) +- self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) +- self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) ++ ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLS_CLIENT) ++ self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT) ++ self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self._assert_context_options(ctx) + +- ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1, ++ ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLS_CLIENT, + cert_reqs=ssl.CERT_REQUIRED, + check_hostname=True) +- self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) ++ self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertTrue(ctx.check_hostname) + self._assert_context_options(ctx) +@@ -1432,7 +1463,7 @@ class ContextTests(unittest.TestCase): + self._assert_context_options(ctx) + + def test_check_hostname(self): +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) + self.assertFalse(ctx.check_hostname) + + # Requires CERT_REQUIRED or CERT_OPTIONAL +@@ -1479,7 +1510,7 @@ class SSLErrorTests(unittest.TestCase): + + def test_lib_reason(self): + # Test the library and reason attributes +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + with self.assertRaises(ssl.SSLError) as cm: + ctx.load_dh_params(CERTFILE) + self.assertEqual(cm.exception.library, 'PEM') +@@ -1490,7 +1521,9 @@ class SSLErrorTests(unittest.TestCase): + def test_subclass(self): + # Check that the appropriate SSLError subclass is raised + # (this only tests one of them) +- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ++ ctx.check_hostname = False ++ ctx.verify_mode = ssl.CERT_NONE + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + s.listen() +@@ -1642,7 +1675,9 @@ class SimpleBackgroundTests(unittest.Tes + + def test_connect_with_context(self): + # Same as test_connect, but with a separately created context +- ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ++ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ++ ctx.check_hostname = False ++ ctx.verify_mode = ssl.CERT_NONE + with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: + s.connect(self.server_addr) + self.assertEqual({}, s.getpeercert()) +@@ -1790,10 +1825,12 @@ class SimpleBackgroundTests(unittest.Tes + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), "needs TLS 1.2") + def test_context_setget(self): + # Check that the context of a connected socket can be replaced. +- ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) +- ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ++ ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ++ ctx1.load_verify_locations(capath=CAPATH) ++ ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ++ ctx2.load_verify_locations(capath=CAPATH) + s = socket.socket(socket.AF_INET) +- with ctx1.wrap_socket(s) as ss: ++ with ctx1.wrap_socket(s, server_hostname='localhost') as ss: + ss.connect(self.server_addr) + self.assertIs(ss.context, ctx1) + self.assertIs(ss._sslobj.context, ctx1) +@@ -2121,24 +2158,26 @@ if _have_threads: + certreqs=None, cacerts=None, + chatty=True, connectionchatty=False, starttls_server=False, + npn_protocols=None, alpn_protocols=None, +- ciphers=None, context=None): ++ ciphers=None, context=None, check_hostname=None): + if context: + self.context = context + else: + self.context = ssl.SSLContext(ssl_version + if ssl_version is not None + else ssl.PROTOCOL_TLS) ++ if check_hostname is not None: ++ self.context.check_hostname = check_hostname + self.context.verify_mode = (certreqs if certreqs is not None + else ssl.CERT_NONE) +- if cacerts: ++ if cacerts is not None: + self.context.load_verify_locations(cacerts) +- if certificate: ++ if certificate is not None: + self.context.load_cert_chain(certificate) +- if npn_protocols: ++ if npn_protocols is not None: + self.context.set_npn_protocols(npn_protocols) +- if alpn_protocols: ++ if alpn_protocols is not None: + self.context.set_alpn_protocols(alpn_protocols) +- if ciphers: ++ if ciphers is not None: + self.context.set_ciphers(ciphers) + self.chatty = chatty + self.connectionchatty = connectionchatty +@@ -2317,7 +2356,7 @@ if _have_threads: + + def server_params_test(client_context, server_context, indata=b"FOO\n", + chatty=True, connectionchatty=False, sni_name=None, +- session=None): ++ session=None, check_hostname=None): + """ + Launch a server, connect a client to it and try various reads + and writes. +@@ -2325,7 +2364,8 @@ if _have_threads: + stats = {} + server = ThreadedEchoServer(context=server_context, + chatty=chatty, +- connectionchatty=False) ++ connectionchatty=False, ++ check_hostname=check_hostname) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=sni_name, session=session) as s: +@@ -2387,18 +2427,24 @@ if _have_threads: + (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol), + certtype)) +- client_context = ssl.SSLContext(client_protocol) +- client_context.options |= client_options +- server_context = ssl.SSLContext(server_protocol) +- server_context.options |= server_options ++ ++ with warnings.catch_warnings(): ++ warnings.simplefilter('ignore') ++ client_context = ssl.SSLContext(client_protocol) ++ client_context.options |= client_options ++ server_context = ssl.SSLContext(server_protocol) ++ server_context.options |= server_options ++ if certsreqs == ssl.CERT_NONE: ++ server_context.check_hostname = False + + # NOTE: we must enable "ALL" ciphers on the client, otherwise an + # SSLv23 client will send an SSLv3 hello (rather than SSLv2) + # starting from OpenSSL 1.0.0 (see issue #8322). +- if client_context.protocol == ssl.PROTOCOL_SSLv23: ++ if client_context.protocol == ssl.PROTOCOL_TLS: + client_context.set_ciphers("ALL") + + for ctx in (client_context, server_context): ++ ctx.check_hostname = False + ctx.verify_mode = certsreqs + ctx.load_cert_chain(CERTFILE) + ctx.load_verify_locations(CERTFILE) +@@ -2432,26 +2478,22 @@ if _have_threads: + """Basic test of an SSL client connecting to a server""" + if support.verbose: + sys.stdout.write("\n") +- for protocol in PROTOCOLS: +- if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: +- continue +- with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): +- context = ssl.SSLContext(protocol) +- context.load_cert_chain(CERTFILE) +- server_params_test(context, context, +- chatty=True, connectionchatty=True) ++ # for protocol in PROTOCOLS: ++ # if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: ++ # continue ++ # with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): ++ # context = ssl.SSLContext(protocol) ++ # context.load_cert_chain(CERTFILE) ++ # server_params_test(context, context, ++ # chatty=True, connectionchatty=True) + +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +- client_context.load_verify_locations(SIGNING_CA) +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) +- # server_context.load_verify_locations(SIGNING_CA) +- server_context.load_cert_chain(SIGNED_CERTFILE2) ++ client_context, server_context, hostname = testing_context() + + with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): + server_params_test(client_context=client_context, + server_context=server_context, + chatty=True, connectionchatty=True, +- sni_name='fakehostname') ++ sni_name=hostname) + + client_context.check_hostname = False + with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): +@@ -2459,9 +2501,8 @@ if _have_threads: + server_params_test(client_context=server_context, + server_context=client_context, + chatty=True, connectionchatty=True, +- sni_name='fakehostname') +- self.assertIn('called a function you should not call', +- str(e.exception)) ++ sni_name=hostname) ++ self.assertIn('called a function you should not call', str(e.exception)) + + with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER): + with self.assertRaises(ssl.SSLError) as e: +@@ -2524,39 +2565,37 @@ if _have_threads: + if support.verbose: + sys.stdout.write("\n") + +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- server_context.load_cert_chain(SIGNED_CERTFILE) +- +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- context.verify_mode = ssl.CERT_REQUIRED +- context.load_verify_locations(SIGNING_CA) ++ client_context, server_context, hostname = testing_context() + tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) +- self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) ++ self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf) + + # VERIFY_DEFAULT should pass + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: +- with context.wrap_socket(socket.socket()) as s: ++ with client_context.wrap_socket(socket.socket(), ++ server_hostname=hostname) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails +- context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF ++ client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: +- with context.wrap_socket(socket.socket()) as s: ++ with client_context.wrap_socket(socket.socket(), ++ server_hostname=hostname) as s: + with self.assertRaisesRegex(ssl.SSLError, + "certificate verify failed"): + s.connect((HOST, server.port)) + + # now load a CRL file. The CRL file is signed by the CA. +- context.load_verify_locations(CRLFILE) ++ client_context.load_verify_locations(CRLFILE) + + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: +- with context.wrap_socket(socket.socket()) as s: ++ with client_context.wrap_socket(socket.socket(), ++ server_hostname=hostname) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") +@@ -2565,10 +2604,10 @@ if _have_threads: + if support.verbose: + sys.stdout.write("\n") + +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(SIGNED_CERTFILE) + +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + context.load_verify_locations(SIGNING_CA) +@@ -2604,17 +2643,19 @@ if _have_threads: + ) + def test_hostname_checks_common_name(self): + client_context, server_context, hostname = testing_context() +- assert client_context.hostname_checks_common_name ++ if not hasattr(client_context, 'hostname_checks_common_name'): ++ raise unittest.SkipTest('test requires SSLContext having hostname_checks_common_name function.') ++ + client_context.hostname_checks_common_name = False + +- # default cert has a SAN ++ # default cert has a SAN + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + +- client_context, server_context, hostname = testing_context(NOSANFILE) ++ client_context, server_context, hostname = testing_context(NOSANFILE) + client_context.hostname_checks_common_name = False + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: +@@ -2702,29 +2743,31 @@ if _have_threads: + "OpenSSL is compiled without SSLv2 support") + def test_protocol_sslv2(self): + """Connecting to an SSLv2 server with various client options""" ++ hostname='localhost' + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) +- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) ++ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) +- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) ++ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False) + # SSLv23 client with specific SSL options + if no_sslv2_implies_sslv3_hello(): + # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs +- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, ++ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False, + client_options=ssl.OP_NO_SSLv2) +- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, ++ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False, + client_options=ssl.OP_NO_SSLv3) +- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, ++ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False, + client_options=ssl.OP_NO_TLSv1) + + @skip_if_broken_ubuntu_ssl + @skip_if_openssl_cnf_minprotocol_gt_tls1 + def test_protocol_sslv23(self): + """Connecting to an SSLv23 server with various client options""" ++ hostname='localhost' + if support.verbose: + sys.stdout.write("\n") + if hasattr(ssl, 'PROTOCOL_SSLv2'): +@@ -2736,20 +2779,23 @@ if _have_threads: + sys.stdout.write( + " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" + % str(x)) ++ ++ ++ + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) +- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') ++ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3') + + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) +- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) ++ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3', ssl.CERT_OPTIONAL) + + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) +- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) ++ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3', ssl.CERT_REQUIRED) + + # Server with specific SSL options + if hasattr(ssl, 'PROTOCOL_SSLv3'): +@@ -2759,7 +2805,7 @@ if _have_threads: + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, + server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, +- server_options=ssl.OP_NO_TLSv1) ++ server_options=ssl.OP_NO_TLSv1_3) + + + @skip_if_broken_ubuntu_ssl +@@ -2776,13 +2822,15 @@ if _have_threads: + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv3) +- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) ++ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS_CLIENT, False) + if no_sslv2_implies_sslv3_hello(): + # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, + False, client_options=ssl.OP_NO_SSLv2) + + @skip_if_broken_ubuntu_ssl ++ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), ++ "TLS version 1.1 not supported.") + def test_protocol_tlsv1(self): + """Connecting to a TLSv1 server with various client options""" + if support.verbose: +@@ -2794,7 +2842,7 @@ if _have_threads: + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, ++ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False, + client_options=ssl.OP_NO_TLSv1) + + @skip_if_broken_ubuntu_ssl +@@ -2811,12 +2859,12 @@ if _have_threads: + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False, + client_options=ssl.OP_NO_TLSv1_1) + +- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') +- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) ++ try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + + + @skip_if_broken_ubuntu_ssl +@@ -2834,24 +2882,27 @@ if _have_threads: + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False, + client_options=ssl.OP_NO_TLSv1_2) + +- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') +- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) +- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) ++ try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') ++ if hasattr(ssl, "PROTOCOL_TLSv1"): ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) ++ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) ++ if hasattr(ssl, "PROTOCOL_TLSv1_1"): ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) ++ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") + + server = ThreadedEchoServer(CERTFILE, +- ssl_version=ssl.PROTOCOL_TLSv1, ++ ssl_version=ssl.PROTOCOL_TLS_SERVER, + starttls_server=True, + chatty=True, +- connectionchatty=True) ++ connectionchatty=True, ++ check_hostname=False) + wrapped = False + with server: + s = socket.socket() +@@ -2876,7 +2927,8 @@ if _have_threads: + sys.stdout.write( + " client: read %r from server, starting TLS...\n" + % msg) +- conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) ++ conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLS_CLIENT, ++ check_hostname=False) + wrapped = True + elif indata == b"ENDTLS" and msg.startswith(b"ok"): + # ENDTLS ok, switch back to clear text +@@ -2903,17 +2955,17 @@ if _have_threads: + + def test_socketserver(self): + """Using socketserver to create and manage SSL connections.""" +- server = make_https_server(self, certfile=CERTFILE) ++ server = make_https_server(self, certfile=SIGNED_CERTFILE) + # try to connect + if support.verbose: + sys.stdout.write('\n') +- with open(CERTFILE, 'rb') as f: ++ with open(__file__, 'rb') as f: + d1 = f.read() + d2 = '' + # now fetch the same data from the HTTPS server + url = 'https://localhost:%d/%s' % ( +- server.port, os.path.split(CERTFILE)[1]) +- context = ssl.create_default_context(cafile=CERTFILE) ++ server.port, os.path.split(__file__)[1]) ++ context = ssl.create_default_context(cafile=SIGNING_CA) + f = urllib.request.urlopen(url, context=context) + try: + dlen = f.info().get("content-length") +@@ -2963,7 +3015,7 @@ if _have_threads: + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1, ++ ssl_version=ssl.PROTOCOL_TLS_SERVER, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) +@@ -2972,8 +3024,7 @@ if _have_threads: + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, +- cert_reqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1) ++ cert_reqs=ssl.CERT_NONE) + s.connect((HOST, server.port)) + # helper methods for standardising recv* method signatures + def _recv_into(): +@@ -3115,32 +3166,30 @@ if _have_threads: + def test_nonblocking_send(self): + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1, ++ ssl_version=ssl.PROTOCOL_TLS_SERVER, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: +- s = test_wrap_socket(socket.socket(), ++ with test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, +- cert_reqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1) +- s.connect((HOST, server.port)) +- s.setblocking(False) ++ cert_reqs=ssl.CERT_NONE) as s: ++ s.connect((HOST, server.port)) ++ s.setblocking(False) + +- # If we keep sending data, at some point the buffers +- # will be full and the call will block +- buf = bytearray(8192) +- def fill_buffer(): +- while True: +- s.send(buf) +- self.assertRaises((ssl.SSLWantWriteError, +- ssl.SSLWantReadError), fill_buffer) ++ # If we keep sending data, at some point the buffers ++ # will be full and the call will block ++ buf = bytearray(8192) ++ def fill_buffer(): ++ while True: ++ s.send(buf) ++ self.assertRaises((ssl.SSLWantWriteError, ++ ssl.SSLWantReadError), fill_buffer) + +- # Now read all the output and discard it +- s.setblocking(True) +- s.close() ++ # Now read all the output and discard it ++ s.setblocking(True) + + def test_handshake_timeout(self): + # Issue #5103: SSL handshake must respect the socket timeout +@@ -3271,14 +3320,17 @@ if _have_threads: + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ++ context.check_hostname = False ++ context.verify_mode = ssl.CERT_NONE + with ThreadedEchoServer(CERTFILE, +- ssl_version=ssl.PROTOCOL_TLSv1, ++ ssl_version=ssl.PROTOCOL_TLS_SERVER, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) + s.connect((HOST, server.port)) +- self.assertEqual(s.version(), 'TLSv1') ++ self.assertEqual(s.version(), 'TLSv1.3') ++ self.assertIs(s._sslobj, None) + self.assertIs(s.version(), None) + + @unittest.skipUnless(ssl.HAS_TLSv1_3, +@@ -3291,7 +3343,7 @@ if _have_threads: + ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2 + ) + with ThreadedEchoServer(context=context) as server: +- with context.wrap_socket(socket.socket()) as s: ++ with context.wrap_socket(socket.socket(), server_hostname='localhost') as s: + s.connect((HOST, server.port)) + self.assertIn(s.cipher()[0], [ + 'TLS_AES_256_GCM_SHA384', +@@ -3327,64 +3379,59 @@ if _have_threads: + if support.verbose: + sys.stdout.write("\n") + +- server = ThreadedEchoServer(CERTFILE, +- certreqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1, +- cacerts=CERTFILE, ++ client_context, server_context, hostname = testing_context() ++ ++ server = ThreadedEchoServer(context=server_context, + chatty=True, + connectionchatty=False) + with server: +- s = test_wrap_socket(socket.socket(), +- server_side=False, +- certfile=CERTFILE, +- ca_certs=CERTFILE, +- cert_reqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1) +- s.connect((HOST, server.port)) +- # get the data +- cb_data = s.get_channel_binding("tls-unique") +- if support.verbose: +- sys.stdout.write(" got channel binding data: {0!r}\n" +- .format(cb_data)) ++ with client_context.wrap_socket(socket.socket(), ++ server_hostname=hostname) as s: ++ s.connect((HOST, server.port)) ++ # get the data ++ cb_data = s.get_channel_binding("tls-unique") ++ if support.verbose: ++ sys.stdout.write(" got channel binding data: {0!r}\n" ++ .format(cb_data)) + +- # check if it is sane +- self.assertIsNotNone(cb_data) +- self.assertEqual(len(cb_data), 12) # True for TLSv1 +- +- # and compare with the peers version +- s.write(b"CB tls-unique\n") +- peer_data_repr = s.read().strip() +- self.assertEqual(peer_data_repr, +- repr(cb_data).encode("us-ascii")) +- s.close() ++ # check if it is sane ++ self.assertIsNotNone(cb_data) ++ if s.version() == 'TLSv1.3': ++ self.assertEqual(len(cb_data), 48) ++ else: ++ self.assertEqual(len(cb_data), 12) # True for TLSv1 ++ ++ # and compare with the peers version ++ s.write(b"CB tls-unique\n") ++ peer_data_repr = s.read().strip() ++ self.assertEqual(peer_data_repr, ++ repr(cb_data).encode("us-ascii")) + + # now, again +- s = test_wrap_socket(socket.socket(), +- server_side=False, +- certfile=CERTFILE, +- ca_certs=CERTFILE, +- cert_reqs=ssl.CERT_NONE, +- ssl_version=ssl.PROTOCOL_TLSv1) +- s.connect((HOST, server.port)) +- new_cb_data = s.get_channel_binding("tls-unique") +- if support.verbose: +- sys.stdout.write(" got another channel binding data: {0!r}\n" +- .format(new_cb_data)) +- # is it really unique +- self.assertNotEqual(cb_data, new_cb_data) +- self.assertIsNotNone(cb_data) +- self.assertEqual(len(cb_data), 12) # True for TLSv1 +- s.write(b"CB tls-unique\n") +- peer_data_repr = s.read().strip() +- self.assertEqual(peer_data_repr, +- repr(new_cb_data).encode("us-ascii")) +- s.close() ++ with client_context.wrap_socket(socket.socket(), ++ server_hostname=hostname) as s: ++ s.connect((HOST, server.port)) ++ new_cb_data = s.get_channel_binding("tls-unique") ++ if support.verbose: ++ sys.stdout.write(" got another channel binding data: {0!r}\n" ++ .format(new_cb_data)) ++ # is it really unique ++ self.assertNotEqual(cb_data, new_cb_data) ++ self.assertIsNotNone(cb_data) ++ if s.version() == 'TLSv1.3': ++ self.assertEqual(len(cb_data), 48) ++ else: ++ self.assertEqual(len(cb_data), 12) # True for TLSv1 ++ s.write(b"CB tls-unique\n") ++ peer_data_repr = s.read().strip() ++ self.assertEqual(peer_data_repr, ++ repr(new_cb_data).encode("us-ascii")) + + def test_compression(self): +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- context.load_cert_chain(CERTFILE) +- stats = server_params_test(context, context, +- chatty=True, connectionchatty=True) ++ client_context, server_context, hostname = testing_context() ++ stats = server_params_test(client_context, server_context, ++ chatty=True, connectionchatty=True, ++ sni_name=hostname, check_hostname=False) + if support.verbose: + sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) + self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) +@@ -3392,44 +3439,45 @@ if _have_threads: + @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), + "ssl.OP_NO_COMPRESSION needed for this test") + def test_compression_disabled(self): +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- context.load_cert_chain(CERTFILE) +- context.options |= ssl.OP_NO_COMPRESSION +- stats = server_params_test(context, context, +- chatty=True, connectionchatty=True) ++ client_context, server_context, hostname = testing_context() ++ client_context.options |= ssl.OP_NO_COMPRESSION ++ server_context.options |= ssl.OP_NO_COMPRESSION ++ stats = server_params_test(client_context, server_context, ++ chatty=True, connectionchatty=True, ++ sni_name=hostname, check_hostname=False) + self.assertIs(stats['compression'], None) + + def test_dh_params(self): + # Check we can get a connection with ephemeral Diffie-Hellman +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- context.load_cert_chain(CERTFILE) +- context.load_dh_params(DHFILE) +- context.set_ciphers("kEDH") +- stats = server_params_test(context, context, +- chatty=True, connectionchatty=True) ++ client_context, server_context, hostname = testing_context() ++ client_context.options = ssl.PROTOCOL_TLS & ssl.OP_NO_TLSv1_3 ++ client_context.load_dh_params(DHFILE) ++ client_context.set_ciphers("kEDH") ++ stats = server_params_test(client_context, server_context, ++ chatty=True, connectionchatty=True, ++ sni_name=hostname) + cipher = stats["cipher"][0] + parts = cipher.split("-") + if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: +- self.fail("Non-DH cipher: " + cipher[0]) ++ self.fail(f"Non-DH cipher: {cipher[0]} (parts: {parts})") + ++ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") + def test_selected_alpn_protocol(self): + # selected_alpn_protocol() is None unless ALPN is used. +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- context.load_cert_chain(CERTFILE) +- stats = server_params_test(context, context, +- chatty=True, connectionchatty=True) ++ client_context, server_context, hostname = testing_context() ++ stats = server_params_test(client_context, server_context, ++ chatty=True, connectionchatty=True, ++ sni_name=hostname) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") + def test_selected_alpn_protocol_if_server_uses_alpn(self): + # selected_alpn_protocol() is None unless ALPN is used by the client. +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- client_context.load_verify_locations(CERTFILE) +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- server_context.load_cert_chain(CERTFILE) ++ client_context, server_context, hostname = testing_context() + server_context.set_alpn_protocols(['foo', 'bar']) + stats = server_params_test(client_context, server_context, +- chatty=True, connectionchatty=True) ++ chatty=True, connectionchatty=True, ++ sni_name=hostname) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") +@@ -3443,10 +3491,10 @@ if _have_threads: + (['http/3.0', 'http/4.0'], None) + ] + for client_protocols, expected in protocol_tests: +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) ++ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(server_protocols) +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) ++ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.load_cert_chain(CERTFILE) + client_context.set_alpn_protocols(client_protocols) + +@@ -3475,9 +3523,10 @@ if _have_threads: + self.assertEqual(server_result, expected, + msg % (server_result, "server")) + ++ @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") + def test_selected_npn_protocol(self): + # selected_npn_protocol() is None unless NPN is used +- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) +@@ -3493,10 +3542,10 @@ if _have_threads: + (['abc', 'def'], 'abc') + ] + for client_protocols, expected in protocol_tests: +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(CERTFILE) + server_context.set_npn_protocols(server_protocols) +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.load_cert_chain(CERTFILE) + client_context.set_npn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, +@@ -3513,12 +3562,11 @@ if _have_threads: + self.assertEqual(server_result, expected, msg % (server_result, "server")) + + def sni_contexts(self): +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(SIGNED_CERTFILE) +- other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + other_context.load_cert_chain(SIGNED_CERTFILE2) +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- client_context.verify_mode = ssl.CERT_REQUIRED ++ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.load_verify_locations(SIGNING_CA) + return server_context, other_context, client_context + +@@ -3552,7 +3600,7 @@ if _have_threads: + chatty=True, + sni_name=None) + self.assertEqual(calls, [(None, server_context)]) +- self.check_common_name(stats, 'localhost') ++ self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME) + + # Check disabling the callback + calls = [] +@@ -3562,7 +3610,7 @@ if _have_threads: + chatty=True, + sni_name='notfunny') + # Certificate didn't change +- self.check_common_name(stats, 'localhost') ++ self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME) + self.assertEqual(calls, []) + + @skip_if_OpenSSL30 +@@ -3620,9 +3668,9 @@ if _have_threads: + + @unittest.skipIf(IS_OPENSSL_1_1_1, "bpo-36576: fail on OpenSSL 1.1.1") + def test_shared_ciphers(self): +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(SIGNED_CERTFILE) +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) ++ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): +@@ -3682,14 +3730,11 @@ if _have_threads: + self.assertEqual(s.recv(1024), TEST_DATA) + + def test_session(self): +- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- server_context.load_cert_chain(SIGNED_CERTFILE) +- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +- client_context.verify_mode = ssl.CERT_REQUIRED +- client_context.load_verify_locations(SIGNING_CA) ++ client_context, server_context, hostname = testing_context() + + # first connection without session +- stats = server_params_test(client_context, server_context) ++ stats = server_params_test(client_context, server_context, ++ sni_name=hostname) + session = stats['session'] + self.assertTrue(session.id) + self.assertGreater(session.time, 0) +@@ -3703,7 +3748,8 @@ if _have_threads: + self.assertEqual(sess_stat['hits'], 0) + + # reuse session +- stats = server_params_test(client_context, server_context, session=session) ++ stats = server_params_test(client_context, server_context, ++ session=session, sni_name=hostname) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 2) + self.assertEqual(sess_stat['hits'], 1) +@@ -3716,7 +3762,8 @@ if _have_threads: + self.assertGreaterEqual(session2.timeout, session.timeout) + + # another one without session +- stats = server_params_test(client_context, server_context) ++ stats = server_params_test(client_context, server_context, ++ sni_name=hostname) + self.assertFalse(stats['session_reused']) + session3 = stats['session'] + self.assertNotEqual(session3.id, session.id) +@@ -3726,7 +3773,8 @@ if _have_threads: + self.assertEqual(sess_stat['hits'], 1) + + # reuse session again +- stats = server_params_test(client_context, server_context, session=session) ++ stats = server_params_test(client_context, server_context, session=session, ++ sni_name=hostname) + self.assertTrue(stats['session_reused']) + session4 = stats['session'] + self.assertEqual(session4.id, session.id)