--- Lib/test/test_ssl.py | 510 +++++++++++++++++++++++++++------------------------ 1 file changed, 275 insertions(+), 235 deletions(-) Index: Python-3.6.15/Lib/test/test_ssl.py =================================================================== --- Python-3.6.15.orig/Lib/test/test_ssl.py +++ Python-3.6.15/Lib/test/test_ssl.py @@ -22,6 +22,7 @@ import weakref import platform import re import functools +import warnings import sysconfig try: import ctypes @@ -73,6 +74,7 @@ CRLFILE = data_file("revocation.crl") # Two keys and certs signed by the same CA (for SNI tests) SIGNED_CERTFILE = data_file("keycert3.pem") +SIGNED_CERTFILE_HOSTNAME = 'localhost' SIGNED_CERTFILE2 = data_file("keycert4.pem") # Same certificate as pycacert.pem, but without extra text in file SIGNING_CA = data_file("capath", "ceff1710.0") @@ -205,15 +207,17 @@ def skip_if_openssl_cnf_minprotocol_gt_t return f - needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test") def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *, cert_reqs=ssl.CERT_NONE, ca_certs=None, ciphers=None, certfile=None, keyfile=None, + check_hostname=None, **kwargs): context = ssl.SSLContext(ssl_version) + if check_hostname is not None: + context.check_hostname = check_hostname if cert_reqs is not None: context.verify_mode = cert_reqs if ca_certs is not None: @@ -224,6 +228,30 @@ def test_wrap_socket(sock, ssl_version=s context.set_ciphers(ciphers) return context.wrap_socket(sock, **kwargs) +def testing_context(server_cert=SIGNED_CERTFILE, *, server_chain=True): + """Create context + + client_context, server_context, hostname = testing_context() + """ + if server_cert == SIGNED_CERTFILE: + hostname = SIGNED_CERTFILE_HOSTNAME + elif server_cert == SIGNED_CERTFILE2: + hostname = SIGNED_CERTFILE2_HOSTNAME + elif server_cert == NOSANFILE: + hostname = NOSAN_HOSTNAME + else: + raise ValueError(server_cert) + + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + 
client_context.load_verify_locations(SIGNING_CA) + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_context.load_cert_chain(server_cert) + if server_chain: + server_context.load_verify_locations(SIGNING_CA) + + return client_context, server_context, hostname + class BasicSocketTests(unittest.TestCase): def test_constants(self): @@ -466,7 +494,7 @@ class BasicSocketTests(unittest.TestCase self.assertTrue(s.startswith("LibreSSL {:d}".format(major)), (s, t, hex(n))) else: - self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), + self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, patch)), (s, t, hex(n))) @support.cpython_only @@ -545,7 +573,8 @@ class BasicSocketTests(unittest.TestCase with self.assertRaises(ssl.SSLError): test_wrap_socket(sock, certfile=certfile, - ssl_version=ssl.PROTOCOL_TLSv1) + check_hostname=False, + ssl_version=ssl.PROTOCOL_TLS_CLIENT) def test_empty_cert(self): """Wrapping with an empty cert file""" @@ -974,7 +1003,7 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.protocol, proto) def test_ciphers(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_ciphers("ALL") ctx.set_ciphers("DEFAULT") with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"): @@ -995,7 +1024,7 @@ class ContextTests(unittest.TestCase): @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old') def test_get_ciphers(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_ciphers('AESGCM') names = set(d['name'] for d in ctx.get_ciphers()) self.assertIn('AES256-GCM-SHA384', names) @@ -1027,6 +1056,7 @@ class ContextTests(unittest.TestCase): def test_verify_mode(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.check_hostname = False # Default value self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) ctx.verify_mode = ssl.CERT_OPTIONAL @@ -1043,7 +1073,7 @@ 
class ContextTests(unittest.TestCase): @unittest.skipUnless(have_verify_flags(), "verify_flags need OpenSSL > 0.9.8") def test_verify_flags(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # default value tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf) @@ -1061,7 +1091,7 @@ class ContextTests(unittest.TestCase): ctx.verify_flags = None def test_load_cert_chain(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # Combined key and cert in a single file ctx.load_cert_chain(CERTFILE, keyfile=None) ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE) @@ -1074,7 +1104,7 @@ class ContextTests(unittest.TestCase): with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): ctx.load_cert_chain(EMPTYCERT) # Separate key and cert - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_cert_chain(ONLYCERT, ONLYKEY) ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY) ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY) @@ -1085,7 +1115,7 @@ class ContextTests(unittest.TestCase): with self.assertRaisesRegex(ssl.SSLError, "PEM lib"): ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT) # Mismatching key and cert - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"): ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY) # Password protected key and cert @@ -1144,7 +1174,7 @@ class ContextTests(unittest.TestCase): ctx.load_cert_chain(CERTFILE, password=getpass_exception) def test_load_verify_locations(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(CERTFILE) ctx.load_verify_locations(cafile=CERTFILE, capath=None) ctx.load_verify_locations(BYTES_CERTFILE) @@ -1172,7 +1202,7 @@ class 
ContextTests(unittest.TestCase): neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem) # test PEM - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0) ctx.load_verify_locations(cadata=cacert_pem) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1) @@ -1183,20 +1213,20 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # combined - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) combined = "\n".join((cacert_pem, neuronio_pem)) ctx.load_verify_locations(cadata=combined) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # with junk around the certs - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) combined = ["head", cacert_pem, "other", neuronio_pem, "again", neuronio_pem, "tail"] ctx.load_verify_locations(cadata="\n".join(combined)) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # test DER - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_verify_locations(cadata=cacert_der) ctx.load_verify_locations(cadata=neuronio_der) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) @@ -1205,13 +1235,13 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # combined - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) combined = b"".join((cacert_der, neuronio_der)) ctx.load_verify_locations(cadata=combined) self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2) # error cases - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object) with self.assertRaisesRegex( @@ -1229,7 +1259,7 @@ class ContextTests(unittest.TestCase): def test_load_dh_params(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = 
ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_dh_params(DHFILE) if os.name != 'nt': ctx.load_dh_params(BYTES_DHFILE) @@ -1262,12 +1292,12 @@ class ContextTests(unittest.TestCase): def test_set_default_verify_paths(self): # There's not much we can do to test that it acts as expected, # so just check it doesn't crash or raise an exception. - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_default_verify_paths() @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build") def test_set_ecdh_curve(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.set_ecdh_curve("prime256v1") ctx.set_ecdh_curve(b"prime256v1") self.assertRaises(TypeError, ctx.set_ecdh_curve) @@ -1277,7 +1307,8 @@ class ContextTests(unittest.TestCase): @needs_sni def test_sni_callback(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ctx.check_hostname = False # set_servername_callback expects a callable, or None self.assertRaises(TypeError, ctx.set_servername_callback) @@ -1294,7 +1325,7 @@ class ContextTests(unittest.TestCase): def test_sni_callback_refcycle(self): # Reference cycles through the servername callback are detected # and cleared. 
- ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) def dummycallback(sock, servername, ctx, cycle=ctx): pass ctx.set_servername_callback(dummycallback) @@ -1304,7 +1335,7 @@ class ContextTests(unittest.TestCase): self.assertIs(wr(), None) def test_cert_store_stats(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.cert_store_stats(), {'x509_ca': 0, 'crl': 0, 'x509': 0}) ctx.load_cert_chain(CERTFILE) @@ -1318,7 +1349,7 @@ class ContextTests(unittest.TestCase): {'x509_ca': 1, 'crl': 0, 'x509': 2}) def test_get_ca_certs(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.get_ca_certs(), []) # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE ctx.load_verify_locations(CERTFILE) @@ -1346,24 +1377,24 @@ class ContextTests(unittest.TestCase): self.assertEqual(ctx.get_ca_certs(True), [der]) def test_load_default_certs(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs(ssl.Purpose.SERVER_AUTH) ctx.load_default_certs() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertRaises(TypeError, ctx.load_default_certs, None) self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH') @unittest.skipIf(sys.platform == "win32", "not-Windows specific") @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars") def test_load_default_certs_env(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with support.EnvironmentVarGuard() as env: env["SSL_CERT_DIR"] = CAPATH 
env["SSL_CERT_FILE"] = CERTFILE @@ -1372,11 +1403,11 @@ class ContextTests(unittest.TestCase): @unittest.skipUnless(sys.platform == "win32", "Windows specific") def test_load_default_certs_env_windows(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.load_default_certs() stats = ctx.cert_store_stats() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with support.EnvironmentVarGuard() as env: env["SSL_CERT_DIR"] = CAPATH env["SSL_CERT_FILE"] = CERTFILE @@ -1423,20 +1454,20 @@ class ContextTests(unittest.TestCase): def test__create_stdlib_context(self): ctx = ssl._create_stdlib_context() - self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS) self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) self.assertFalse(ctx.check_hostname) self._assert_context_options(ctx) - ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1) - self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) - self.assertEqual(ctx.verify_mode, ssl.CERT_NONE) + ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) self._assert_context_options(ctx) - ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1, + ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLS_CLIENT, cert_reqs=ssl.CERT_REQUIRED, check_hostname=True) - self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1) + self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) self.assertTrue(ctx.check_hostname) self._assert_context_options(ctx) @@ -1447,7 +1478,7 @@ class ContextTests(unittest.TestCase): self._assert_context_options(ctx) def test_check_hostname(self): - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS) self.assertFalse(ctx.check_hostname) # Requires CERT_REQUIRED or CERT_OPTIONAL @@ -1494,7 +1525,7 @@ class 
SSLErrorTests(unittest.TestCase): def test_lib_reason(self): # Test the library and reason attributes - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) with self.assertRaises(ssl.SSLError) as cm: ctx.load_dh_params(CERTFILE) self.assertEqual(cm.exception.library, 'PEM') @@ -1505,7 +1536,9 @@ class SSLErrorTests(unittest.TestCase): def test_subclass(self): # Check that the appropriate SSLError subclass is raised # (this only tests one of them) - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE with socket.socket() as s: s.bind(("127.0.0.1", 0)) s.listen() @@ -1657,7 +1690,9 @@ class SimpleBackgroundTests(unittest.Tes def test_connect_with_context(self): # Same as test_connect, but with a separately created context - ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s: s.connect(self.server_addr) self.assertEqual({}, s.getpeercert()) @@ -1805,10 +1840,12 @@ class SimpleBackgroundTests(unittest.Tes @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), "needs TLS 1.2") def test_context_setget(self): # Check that the context of a connected socket can be replaced. 
- ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx1.load_verify_locations(capath=CAPATH) + ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx2.load_verify_locations(capath=CAPATH) s = socket.socket(socket.AF_INET) - with ctx1.wrap_socket(s) as ss: + with ctx1.wrap_socket(s, server_hostname='localhost') as ss: ss.connect(self.server_addr) self.assertIs(ss.context, ctx1) self.assertIs(ss._sslobj.context, ctx1) @@ -2136,24 +2173,26 @@ if _have_threads: certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, npn_protocols=None, alpn_protocols=None, - ciphers=None, context=None): + ciphers=None, context=None, check_hostname=None): if context: self.context = context else: self.context = ssl.SSLContext(ssl_version if ssl_version is not None else ssl.PROTOCOL_TLS) + if check_hostname is not None: + self.context.check_hostname = check_hostname self.context.verify_mode = (certreqs if certreqs is not None else ssl.CERT_NONE) - if cacerts: + if cacerts is not None: self.context.load_verify_locations(cacerts) - if certificate: + if certificate is not None: self.context.load_cert_chain(certificate) - if npn_protocols: + if npn_protocols is not None: self.context.set_npn_protocols(npn_protocols) - if alpn_protocols: + if alpn_protocols is not None: self.context.set_alpn_protocols(alpn_protocols) - if ciphers: + if ciphers is not None: self.context.set_ciphers(ciphers) self.chatty = chatty self.connectionchatty = connectionchatty @@ -2332,7 +2371,7 @@ if _have_threads: def server_params_test(client_context, server_context, indata=b"FOO\n", chatty=True, connectionchatty=False, sni_name=None, - session=None): + session=None, check_hostname=None): """ Launch a server, connect a client to it and try various reads and writes. 
@@ -2340,7 +2379,8 @@ if _have_threads: stats = {} server = ThreadedEchoServer(context=server_context, chatty=chatty, - connectionchatty=False) + connectionchatty=False, + check_hostname=check_hostname) with server: with client_context.wrap_socket(socket.socket(), server_hostname=sni_name, session=session) as s: @@ -2402,18 +2442,24 @@ if _have_threads: (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol), certtype)) - client_context = ssl.SSLContext(client_protocol) - client_context.options |= client_options - server_context = ssl.SSLContext(server_protocol) - server_context.options |= server_options + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + client_context = ssl.SSLContext(client_protocol) + client_context.options |= client_options + server_context = ssl.SSLContext(server_protocol) + server_context.options |= server_options + if certsreqs == ssl.CERT_NONE: + server_context.check_hostname = False # NOTE: we must enable "ALL" ciphers on the client, otherwise an # SSLv23 client will send an SSLv3 hello (rather than SSLv2) # starting from OpenSSL 1.0.0 (see issue #8322). 
- if client_context.protocol == ssl.PROTOCOL_SSLv23: + if client_context.protocol == ssl.PROTOCOL_TLS: client_context.set_ciphers("ALL") for ctx in (client_context, server_context): + ctx.check_hostname = False ctx.verify_mode = certsreqs ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) @@ -2447,26 +2493,14 @@ if _have_threads: """Basic test of an SSL client connecting to a server""" if support.verbose: sys.stdout.write("\n") - for protocol in PROTOCOLS: - if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: - continue - with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): - context = ssl.SSLContext(protocol) - context.load_cert_chain(CERTFILE) - server_params_test(context, context, - chatty=True, connectionchatty=True) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - client_context.load_verify_locations(SIGNING_CA) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - # server_context.load_verify_locations(SIGNING_CA) - server_context.load_cert_chain(SIGNED_CERTFILE2) + client_context, server_context, hostname = testing_context() with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): server_params_test(client_context=client_context, server_context=server_context, chatty=True, connectionchatty=True, - sni_name='fakehostname') + sni_name=hostname) client_context.check_hostname = False with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): @@ -2474,9 +2508,8 @@ if _have_threads: server_params_test(client_context=server_context, server_context=client_context, chatty=True, connectionchatty=True, - sni_name='fakehostname') - self.assertIn('called a function you should not call', - str(e.exception)) + sni_name=hostname) + self.assertIn('called a function you should not call', str(e.exception)) with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER): with self.assertRaises(ssl.SSLError) as e: @@ -2539,39 +2572,37 @@ if _have_threads: if 
support.verbose: sys.stdout.write("\n") - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(SIGNING_CA) + client_context, server_context, hostname = testing_context() tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) - self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) + self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf) # VERIFY_DEFAULT should pass server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails - context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: with self.assertRaisesRegex(ssl.SSLError, "certificate verify failed"): s.connect((HOST, server.port)) # now load a CRL file. The CRL file is signed by the CA. 
- context.load_verify_locations(CRLFILE) + client_context.load_verify_locations(CRLFILE) server = ThreadedEchoServer(context=server_context, chatty=True) with server: - with context.wrap_socket(socket.socket()) as s: + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: s.connect((HOST, server.port)) cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") @@ -2580,10 +2611,10 @@ if _have_threads: if support.verbose: sys.stdout.write("\n") - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(SIGNED_CERTFILE) - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.verify_mode = ssl.CERT_REQUIRED context.check_hostname = True context.load_verify_locations(SIGNING_CA) @@ -2619,17 +2650,19 @@ if _have_threads: ) def test_hostname_checks_common_name(self): client_context, server_context, hostname = testing_context() - assert client_context.hostname_checks_common_name + if not hasattr(client_context, 'hostname_checks_common_name'): + raise unittest.SkipTest('test requires SSLContext having hostname_checks_common_name function.') + client_context.hostname_checks_common_name = False - # default cert has a SAN + # default cert has a SAN server = ThreadedEchoServer(context=server_context, chatty=True) with server: with client_context.wrap_socket(socket.socket(), server_hostname=hostname) as s: s.connect((HOST, server.port)) - client_context, server_context, hostname = testing_context(NOSANFILE) + client_context, server_context, hostname = testing_context(NOSANFILE) client_context.hostname_checks_common_name = False server = ThreadedEchoServer(context=server_context, chatty=True) with server: @@ -2717,29 +2750,31 @@ if _have_threads: "OpenSSL is compiled without SSLv2 support") def test_protocol_sslv2(self): """Connecting to an SSLv2 server with various client options""" + 
hostname='localhost' if support.verbose: sys.stdout.write("\n") try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False) # SSLv23 client with specific SSL options if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False, client_options=ssl.OP_NO_SSLv2) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False, client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS_CLIENT, False, client_options=ssl.OP_NO_TLSv1) @skip_if_broken_ubuntu_ssl @skip_if_openssl_cnf_minprotocol_gt_tls1 def test_protocol_sslv23(self): """Connecting to an SSLv23 server with various client options""" + hostname='localhost' if support.verbose: sys.stdout.write("\n") if hasattr(ssl, 'PROTOCOL_SSLv2'): @@ -2751,20 +2786,23 @@ if _have_threads: sys.stdout.write( " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) + + + if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') + 
try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3') if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3', ssl.CERT_OPTIONAL) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLS_CLIENT, 'TLSv1.3', ssl.CERT_REQUIRED) # Server with specific SSL options if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2774,7 +2812,7 @@ if _have_threads: try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, - server_options=ssl.OP_NO_TLSv1) + server_options=ssl.OP_NO_TLSv1_3) @skip_if_broken_ubuntu_ssl @@ -2791,13 +2829,15 @@ if _have_threads: try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS_CLIENT, False) if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_SSLv2) @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), + "TLS version 1.1 not supported.") def test_protocol_tlsv1(self): """Connecting to a TLSv1 server with various 
client options""" if support.verbose: @@ -2809,7 +2849,7 @@ if _have_threads: try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1) @skip_if_broken_ubuntu_ssl @@ -2826,12 +2866,12 @@ if _have_threads: try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1_1) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @skip_if_broken_ubuntu_ssl @@ -2849,24 +2889,27 @@ if _have_threads: try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False, client_options=ssl.OP_NO_TLSv1_2) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, 
ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') + if hasattr(ssl, "PROTOCOL_TLSv1"): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) + if hasattr(ssl, "PROTOCOL_TLSv1_1"): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) def test_starttls(self): """Switching from clear text to encrypted and back again.""" msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") server = ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, starttls_server=True, chatty=True, - connectionchatty=True) + connectionchatty=True, + check_hostname=False) wrapped = False with server: s = socket.socket() @@ -2891,7 +2934,8 @@ if _have_threads: sys.stdout.write( " client: read %r from server, starting TLS...\n" % msg) - conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) + conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLS_CLIENT, + check_hostname=False) wrapped = True elif indata == b"ENDTLS" and msg.startswith(b"ok"): # ENDTLS ok, switch back to clear text @@ -2918,17 +2962,17 @@ if _have_threads: def test_socketserver(self): """Using socketserver to create and manage SSL connections.""" - server = make_https_server(self, certfile=CERTFILE) + server = make_https_server(self, certfile=SIGNED_CERTFILE) # try to connect if support.verbose: sys.stdout.write('\n') - with open(CERTFILE, 'rb') as f: + with open(__file__, 'rb') as f: d1 = f.read() d2 = '' # now fetch the same data from the HTTPS server url = 'https://localhost:%d/%s' % ( - server.port, os.path.split(CERTFILE)[1]) - context = ssl.create_default_context(cafile=CERTFILE) + server.port, os.path.split(__file__)[1]) + context = ssl.create_default_context(cafile=SIGNING_CA) f = urllib.request.urlopen(url, 
context=context) try: dlen = f.info().get("content-length") @@ -2978,7 +3022,7 @@ if _have_threads: server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, cacerts=CERTFILE, chatty=True, connectionchatty=False) @@ -2987,8 +3031,7 @@ if _have_threads: server_side=False, certfile=CERTFILE, ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + cert_reqs=ssl.CERT_NONE) s.connect((HOST, server.port)) # helper methods for standardising recv* method signatures def _recv_into(): @@ -3130,32 +3173,30 @@ if _have_threads: def test_nonblocking_send(self): server = ThreadedEchoServer(CERTFILE, certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, cacerts=CERTFILE, chatty=True, connectionchatty=False) with server: - s = test_wrap_socket(socket.socket(), + with test_wrap_socket(socket.socket(), server_side=False, certfile=CERTFILE, ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - s.setblocking(False) + cert_reqs=ssl.CERT_NONE) as s: + s.connect((HOST, server.port)) + s.setblocking(False) - # If we keep sending data, at some point the buffers - # will be full and the call will block - buf = bytearray(8192) - def fill_buffer(): - while True: - s.send(buf) - self.assertRaises((ssl.SSLWantWriteError, - ssl.SSLWantReadError), fill_buffer) + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) - # Now read all the output and discard it - s.setblocking(True) - s.close() + # Now read all the output and discard it + s.setblocking(True) def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout @@ -3286,14 +3327,17 @@ if _have_threads: Basic tests 
for SSLSocket.version(). More tests are done in the test_protocol_*() methods. """ - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, + ssl_version=ssl.PROTOCOL_TLS_SERVER, chatty=False) as server: with context.wrap_socket(socket.socket()) as s: self.assertIs(s.version(), None) s.connect((HOST, server.port)) - self.assertEqual(s.version(), 'TLSv1') + self.assertEqual(s.version(), 'TLSv1.3') + self.assertIs(s._sslobj, None) self.assertIs(s.version(), None) @unittest.skipUnless(ssl.HAS_TLSv1_3, @@ -3306,7 +3350,7 @@ if _have_threads: ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2 ) with ThreadedEchoServer(context=context) as server: - with context.wrap_socket(socket.socket()) as s: + with context.wrap_socket(socket.socket(), server_hostname='localhost') as s: s.connect((HOST, server.port)) self.assertIn(s.cipher()[0], [ 'TLS_AES_256_GCM_SHA384', @@ -3342,64 +3386,59 @@ if _have_threads: if support.verbose: sys.stdout.write("\n") - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, + client_context, server_context, hostname = testing_context() + + server = ThreadedEchoServer(context=server_context, chatty=True, connectionchatty=False) with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - # get the data - cb_data = s.get_channel_binding("tls-unique") - if support.verbose: - sys.stdout.write(" got channel binding data: {0!r}\n" - .format(cb_data)) + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + # get the data + cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + 
sys.stdout.write(" got channel binding data: {0!r}\n" + .format(cb_data)) - # check if it is sane - self.assertIsNotNone(cb_data) - self.assertEqual(len(cb_data), 12) # True for TLSv1 - - # and compare with the peers version - s.write(b"CB tls-unique\n") - peer_data_repr = s.read().strip() - self.assertEqual(peer_data_repr, - repr(cb_data).encode("us-ascii")) - s.close() + # check if it is sane + self.assertIsNotNone(cb_data) + if s.version() == 'TLSv1.3': + self.assertEqual(len(cb_data), 48) + else: + self.assertEqual(len(cb_data), 12) # True for TLSv1 + + # and compare with the peers version + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(cb_data).encode("us-ascii")) # now, again - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - new_cb_data = s.get_channel_binding("tls-unique") - if support.verbose: - sys.stdout.write(" got another channel binding data: {0!r}\n" - .format(new_cb_data)) - # is it really unique - self.assertNotEqual(cb_data, new_cb_data) - self.assertIsNotNone(cb_data) - self.assertEqual(len(cb_data), 12) # True for TLSv1 - s.write(b"CB tls-unique\n") - peer_data_repr = s.read().strip() - self.assertEqual(peer_data_repr, - repr(new_cb_data).encode("us-ascii")) - s.close() + with client_context.wrap_socket(socket.socket(), + server_hostname=hostname) as s: + s.connect((HOST, server.port)) + new_cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got another channel binding data: {0!r}\n" + .format(new_cb_data)) + # is it really unique + self.assertNotEqual(cb_data, new_cb_data) + self.assertIsNotNone(new_cb_data) + if s.version() == 'TLSv1.3': + self.assertEqual(len(new_cb_data), 48) + else: + self.assertEqual(len(new_cb_data), 12) # True for TLSv1 + s.write(b"CB tls-unique\n") + peer_data_repr = 
s.read().strip() + 
self.assertEqual(peer_data_repr, + repr(new_cb_data).encode("us-ascii")) def test_compression(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname, check_hostname=False) if support.verbose: sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) @@ -3407,44 +3446,45 @@ if _have_threads: @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), "ssl.OP_NO_COMPRESSION needed for this test") def test_compression_disabled(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.options |= ssl.OP_NO_COMPRESSION - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + client_context.options |= ssl.OP_NO_COMPRESSION + server_context.options |= ssl.OP_NO_COMPRESSION + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname, check_hostname=False) self.assertIs(stats['compression'], None) def test_dh_params(self): # Check we can get a connection with ephemeral Diffie-Hellman - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.load_dh_params(DHFILE) - context.set_ciphers("kEDH") - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + client_context.options |= ssl.OP_NO_TLSv1_3 + server_context.load_dh_params(DHFILE) + server_context.set_ciphers("kEDH") + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) cipher 
= stats["cipher"][0] parts = cipher.split("-") if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: - self.fail("Non-DH cipher: " + cipher[0]) + self.fail(f"Non-DH cipher: {cipher} (parts: {parts})") + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") def test_selected_alpn_protocol(self): # selected_alpn_protocol() is None unless ALPN is used. - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) + client_context, server_context, hostname = testing_context() + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True, + sni_name=hostname) self.assertIs(stats['client_alpn_protocol'], None) @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") def test_selected_alpn_protocol_if_server_uses_alpn(self): # selected_alpn_protocol() is None unless ALPN is used by the client. - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_verify_locations(CERTFILE) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(CERTFILE) + client_context, server_context, hostname = testing_context() server_context.set_alpn_protocols(['foo', 'bar']) stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) + chatty=True, connectionchatty=True, + sni_name=hostname) self.assertIs(stats['client_alpn_protocol'], None) @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") @@ -3458,10 +3498,10 @@ if _have_threads: (['http/3.0', 'http/4.0'], None) ] for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(CERTFILE) server_context.set_alpn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + client_context = 
ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) client_context.load_cert_chain(CERTFILE) client_context.set_alpn_protocols(client_protocols) @@ -3490,9 +3530,10 @@ if _have_threads: self.assertEqual(server_result, expected, msg % (server_result, "server")) + @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") def test_selected_npn_protocol(self): # selected_npn_protocol() is None unless NPN is used - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.load_cert_chain(CERTFILE) stats = server_params_test(context, context, chatty=True, connectionchatty=True) @@ -3508,10 +3549,10 @@ if _have_threads: (['abc', 'def'], 'abc') ] for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(CERTFILE) server_context.set_npn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) client_context.load_cert_chain(CERTFILE) client_context.set_npn_protocols(client_protocols) stats = server_params_test(client_context, server_context, @@ -3528,12 +3569,11 @@ if _have_threads: self.assertEqual(server_result, expected, msg % (server_result, "server")) def sni_contexts(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(SIGNED_CERTFILE) - other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) other_context.load_cert_chain(SIGNED_CERTFILE2) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) client_context.load_verify_locations(SIGNING_CA) return server_context, other_context, client_context @@ -3567,7 +3607,7 @@ if _have_threads: 
chatty=True, 
sni_name=None) self.assertEqual(calls, [(None, server_context)]) - self.check_common_name(stats, 'localhost') + self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME) # Check disabling the callback calls = [] @@ -3577,7 +3617,7 @@ if _have_threads: chatty=True, sni_name='notfunny') # Certificate didn't change - self.check_common_name(stats, 'localhost') + self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME) self.assertEqual(calls, []) @skip_if_OpenSSL30 @@ -3635,9 +3675,9 @@ if _have_threads: @unittest.skipIf(IS_OPENSSL_1_1_1, "bpo-36576: fail on OpenSSL 1.1.1") def test_shared_ciphers(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) client_context.verify_mode = ssl.CERT_REQUIRED client_context.load_verify_locations(SIGNING_CA) if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): @@ -3697,14 +3737,11 @@ if _have_threads: self.assertEqual(s.recv(1024), TEST_DATA) def test_session(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) + client_context, server_context, hostname = testing_context() # first connection without session - stats = server_params_test(client_context, server_context) + stats = server_params_test(client_context, server_context, + sni_name=hostname) session = stats['session'] self.assertTrue(session.id) self.assertGreater(session.time, 0) @@ -3718,7 +3755,8 @@ if _have_threads: self.assertEqual(sess_stat['hits'], 0) # reuse session - stats = server_params_test(client_context, server_context, session=session) + stats = server_params_test(client_context, server_context, + session=session, sni_name=hostname) sess_stat = 
server_context.session_stats() self.assertEqual(sess_stat['accept'], 2) self.assertEqual(sess_stat['hits'], 1) @@ -3731,7 +3769,8 @@ if _have_threads: self.assertGreaterEqual(session2.timeout, session.timeout) # another one without session - stats = server_params_test(client_context, server_context) + stats = server_params_test(client_context, server_context, + sni_name=hostname) self.assertFalse(stats['session_reused']) session3 = stats['session'] self.assertNotEqual(session3.id, session.id) @@ -3741,7 +3780,8 @@ if _have_threads: self.assertEqual(sess_stat['hits'], 1) # reuse session again - stats = server_params_test(client_context, server_context, session=session) + stats = server_params_test(client_context, server_context, session=session, + sni_name=hostname) self.assertTrue(stats['session_reused']) session4 = stats['session'] self.assertEqual(session4.id, session.id)