From 21b753e3ef9742ae2190dcb4fcebcb9ae4e82512 Mon Sep 17 00:00:00 2001 From: julianz- <6255571+julianz-@users.noreply.github.com> Date: Mon, 17 Nov 2025 23:11:39 -0800 Subject: [PATCH] Introduce TLSSocket abstraction for uniform handling Introduces a new `TLSSocket` class to act as a unified wrapper for SSL/TLS connections, regardless of the underlying adapter (`builtin`, `pyOpenSSL`). This refactoring aims to: 1. Simplify adapter logic by centralizing common TLS socket properties and methods (e.g., cipher details, certificate paths). 2. Improve consistency when populating WSGI environment variables. 3. Centralize error handling in the adapters. --- .flake8 | 17 +- .mypy.ini | 2 +- cheroot/connections.py | 81 +-- cheroot/makefile.py | 33 +- cheroot/makefile.pyi | 6 +- cheroot/server.py | 11 +- cheroot/server.pyi | 2 +- cheroot/ssl/__init__.py | 229 ++++++- cheroot/ssl/__init__.pyi | 28 +- cheroot/ssl/builtin.py | 667 ++++++++++--------- cheroot/ssl/builtin.pyi | 5 +- cheroot/ssl/pyopenssl.py | 510 +++++++-------- cheroot/ssl/pyopenssl.pyi | 21 +- cheroot/ssl/tls_socket.py | 492 ++++++++++++++ cheroot/test/conftest.py | 100 ++- cheroot/test/ssl/test_ssl_builtin.py | 443 +++++++++++++ cheroot/test/ssl/test_ssl_pyopenssl.py | 703 +++++++++++++++++++++ cheroot/test/test_makefile.py | 54 +- cheroot/test/test_ssl.py | 66 +- cheroot/wsgi.py | 2 + docs/changelog-fragments.d/799.contrib.rst | 3 + docs/spelling_wordlist.txt | 4 + stubtest_allowlist.txt | 37 -- 23 files changed, 2705 insertions(+), 811 deletions(-) create mode 100644 cheroot/ssl/tls_socket.py create mode 100644 cheroot/test/ssl/test_ssl_builtin.py create mode 100644 cheroot/test/ssl/test_ssl_pyopenssl.py create mode 100644 docs/changelog-fragments.d/799.contrib.rst diff --git a/.flake8 b/.flake8 index aca9b2fffe..aa2bf47fe3 100644 --- a/.flake8 +++ b/.flake8 @@ -124,16 +124,17 @@ per-file-ignores = cheroot/__main__.py: WPS130 cheroot/_compat.py: DAR101, DAR201, DAR301, DAR401, I003, RST304, WPS100, WPS111, WPS123, WPS226, WPS229, WPS420, WPS422, WPS432, WPS504, WPS505 cheroot/cli.py: DAR101, DAR201, DAR401, I001, I004, I005, WPS100, WPS110, WPS120, WPS130, WPS202, WPS226, WPS229, WPS338, WPS420, WPS421 - cheroot/connections.py: DAR101, DAR201, DAR301, DAR401, I001, I003, I004, I005, RST304, S104, WPS100, WPS110, WPS111, WPS121, WPS122, WPS130, WPS201, WPS204, WPS210, WPS212, WPS214, WPS220, WPS229, WPS231, WPS301, WPS324, WPS338, WPS420, WPS421, WPS422, WPS432, WPS501, WPS504, WPS505 + cheroot/connections.py: DAR101, DAR201, DAR301, DAR401, I001, I003, I004, I005, RST304, S104, WPS100, WPS110, WPS111, WPS121, WPS122, WPS130, WPS201, WPS204, WPS210, WPS212, WPS214, WPS220, WPS229, WPS231, WPS237, WPS301, WPS324, WPS338, WPS420, WPS421, WPS422, WPS432, WPS501, WPS504, WPS505 cheroot/errors.py: DAR101, DAR201, I003, RST304, WPS111, WPS121, WPS422 - cheroot/makefile.py: DAR101, DAR201, DAR401, E800, I003, I004, N801, N802, S101, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS122, WPS123, WPS130, WPS204, WPS210, WPS212, WPS213, WPS220, WPS229, WPS231, WPS232, WPS338, WPS420, WPS422, WPS429, WPS431, WPS504, WPS604, WPS606 + cheroot/makefile.py: DAR101, DAR201, DAR401, E800, I003, I004, N801, N802, S101, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS122, WPS123, WPS130, WPS204, WPS210, WPS212, WPS213, WPS226, WPS220, WPS229, WPS231, WPS232, WPS338, WPS420, WPS422, WPS429, WPS431, WPS504, WPS604, WPS606 cheroot/server.py: DAR003, DAR101, DAR201, DAR202, DAR301, DAR401, E800, I001, I003, I004, I005, N806, RST201, RST301, RST303, RST304, WPS100, WPS110, WPS111, WPS115, WPS120, WPS121, WPS122, WPS130, WPS132, WPS201, WPS202, WPS204, WPS210, WPS211, WPS212, WPS213, WPS214, WPS220, WPS221, WPS225, WPS226, WPS229, WPS230, WPS231, WPS236, WPS237, WPS238, WPS301, WPS338, WPS342, WPS410, WPS420, WPS421, WPS422, WPS429, WPS432, WPS504, WPS505, WPS601, WPS602, WPS608, WPS617 - cheroot/ssl/builtin.py: DAR101, DAR201, DAR401, I001, I003, N806, RST304, WPS110, WPS111, WPS115, WPS117, WPS120, WPS121, WPS122, WPS130, WPS201, WPS210, WPS214, WPS229, WPS231, WPS338, WPS422, WPS501, WPS505, WPS529, WPS608, WPS612 - cheroot/ssl/pyopenssl.py: C815, DAR101, DAR201, DAR401, I001, I003, I005, N801, N804, RST304, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS130, WPS210, WPS220, WPS221, WPS225, WPS229, WPS231, WPS238, WPS301, WPS335, WPS338, WPS420, WPS422, WPS430, WPS432, WPS501, WPS504, WPS505, WPS601, WPS608, WPS615 - cheroot/test/conftest.py: DAR101, DAR201, DAR301, I001, I003, I005, WPS100, WPS130, WPS325, WPS354, WPS420, WPS422, WPS430, WPS457 + cheroot/ssl/builtin.py: DAR101, DAR201, DAR401, I001, I003, N806, RST304, WPS110, WPS111, WPS115, WPS117, WPS120, WPS121, WPS122, WPS130, WPS201, WPS204, WPS210, WPS220, WPS214, WPS226, WPS229, WPS231, WPS338, WPS421, WPS422, WPS501, WPS505, WPS529, WPS608, WPS612 + cheroot/ssl/pyopenssl.py: C815, DAR101, DAR201, DAR401, I001, I003, I005, N801, N804, RST304, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS122, WPS130, WPS210, WPS220, WPS221, WPS225, WPS229, WPS231, WPS238, WPS301, WPS335, WPS338, WPS420, WPS422, WPS430, WPS432, WPS501, WPS504, WPS505, WPS601, WPS608, WPS615 + cheroot/ssl/tls_socket.py: DAR101, DAR201, DAR401, WPS110, WPS122, WPS210, WPS212, WPS214, WPS220, WPS225, WPS226, WPS229, WPS231, WPS238, WPS338, WPS362, WPS407 + cheroot/test/conftest.py: DAR101, DAR201, DAR301, I001, I003, I005, WPS100, WPS130, WPS202, WPS325, WPS354, WPS420, WPS422, WPS430, WPS457 cheroot/test/helper.py: DAR101, DAR201, DAR401, I001, I003, I004, N802, WPS110, WPS111, WPS121, WPS201, WPS220, WPS231, WPS301, WPS414, WPS421, WPS422, WPS505 cheroot/test/test_cli.py: DAR101, DAR201, I001, I005, N802, S101, S108, WPS110, WPS421, WPS431, WPS473 - cheroot/test/test_makefile.py: DAR101, DAR201, I004, RST304, S101, WPS110, WPS122 + cheroot/test/test_makefile.py: DAR101, DAR201, I004, RST304, S101, WPS110, WPS122, WPS362 cheroot/test/test_wsgi.py: DAR101, DAR301, I001, I004, S101, WPS110, WPS111, WPS117, WPS118, WPS121, WPS210, WPS421, WPS430, WPS432, WPS441, WPS509 cheroot/test/test_core.py: C815, DAR101, DAR201, DAR401, I003, I004, N805, N806, S101, WPS110, WPS111, WPS114, WPS121, WPS202, WPS204, WPS226, WPS229, WPS324, WPS421, WPS422, WPS432, WPS602 cheroot/test/test_dispatch.py: DAR101, DAR201, S101, WPS111, WPS121, WPS422, WPS430 @@ -144,7 +145,9 @@ per-file-ignores = cheroot/testing.py: C815, DAR101, DAR201, DAR301, I001, I003, S104, WPS100, WPS202, WPS211, WPS229, WPS301, WPS414, WPS420, WPS422, WPS430 cheroot/workers/threadpool.py: DAR101, DAR201, E800, I001, I003, I004, RST201, RST203, RST301, WPS100, WPS110, WPS111, WPS121, WPS122, WPS210, WPS211, WPS214, WPS220, WPS229, WPS230, WPS231, WPS335, WPS338, WPS362, WPS363, WPS410, WPS414, WPS420, WPS422, WPS432, WPS501, WPS505, WPS601, WPS602, WPS617 cheroot/wsgi.py: DAR101, DAR201, DAR401, I001, I003, I005, N801, RST201, RST301, WPS100, WPS110, WPS111, WPS114, WPS121, WPS122, WPS130, WPS210, WPS211, WPS226, WPS229, WPS231, WPS338, WPS420, WPS421, WPS422, WPS430, WPS501, WPS504, WPS602, WPS608 - cheroot/ssl/__init__.py: DAR101, DAR201, I003, WPS412, WPS422 + cheroot/ssl/__init__.py: DAR101, DAR201, I003, WPS210, WPS412, WPS422 + cheroot/test/ssl/test_ssl_builtin.py: DAR101, DAR201, I003, WPS118, WPS201, WPS202, WPS210, WPS213, WPS218, WPS211, WPS226, WPS229, WPS231, WPS243, WPS412, WPS420, WPS422, WPS430, WPS505 + cheroot/test/ssl/test_ssl_pyopenssl.py: DAR101, DAR201, I003, WPS118, WPS201, WPS202, WPS204, WPS210, WPS220, WPS213, WPS218, WPS211, WPS226, WPS229, WPS231, WPS243, WPS412, WPS420, WPS422, WPS430, WPS432, WPS435, WPS505 cheroot/test/_pytest_plugin.py: DAR101, I003, I004, WPS422 cheroot/test/test__compat.py: DAR101, I001, I003, I005, WPS116, WPS226, WPS422, S101 cheroot/test/test_errors.py: DAR101, WPS509, S101 diff --git a/.mypy.ini b/.mypy.ini index fa1389a44b..8e9badbc25 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.8 +python_version = 3.9 color_output = true error_summary = true files = diff --git a/cheroot/connections.py b/cheroot/connections.py index 04df0d42a3..8a66f799f0 100644 --- a/cheroot/connections.py +++ b/cheroot/connections.py @@ -1,6 +1,5 @@ """Utilities to manage open connections.""" -import io import os import selectors import socket @@ -10,7 +9,6 @@ from . import errors from ._compat import IS_WINDOWS -from .makefile import MakeFile try: @@ -293,50 +291,22 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME if hasattr(s, 'settimeout'): s.settimeout(self.server.timeout) - mf = MakeFile ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server if self.server.ssl_adapter is not None: try: s, ssl_env = self.server.ssl_adapter.wrap(s) - except errors.FatalSSLAlert as tls_connection_drop_error: - self.server.error_log( - f'Client {addr!s} lost — peer dropped the TLS ' - 'connection suddenly, during handshake: ' - f'{tls_connection_drop_error!s}', - ) - return None - except errors.NoSSLError as http_over_https_err: + except errors.FatalSSLAlert as tls_connection_error: self.server.error_log( - f'Client {addr!s} attempted to speak plain HTTP into ' - 'a TCP connection configured for TLS-only traffic — ' - 'trying to send back a plain HTTP error response: ' - f'{http_over_https_err!s}', + f'Failed to establish SSL connection with {addr!s}: ' + f'{tls_connection_error!s}', ) - msg = ( - 'The client sent a plain HTTP request, but ' - 'this server only speaks HTTPS on this port.' - ) - buf = [ - '%s 400 Bad Request\r\n' % self.server.protocol, - 'Content-Length: %s\r\n' % len(msg), - 'Content-Type: text/plain\r\n\r\n', - msg, - ] - - wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE) - try: - wfile.write(''.join(buf).encode('ISO-8859-1')) - except OSError as ex: - if ex.args[0] not in errors.socket_errors_to_ignore: - raise return None - mf = self.server.ssl_adapter.makefile - # Re-apply our timeout since we may have a new socket object - if hasattr(s, 'settimeout'): - s.settimeout(self.server.timeout) + except errors.NoSSLError: + return self._send_bad_request_plain_http_error(s, addr) - conn = self.server.ConnectionClass(self.server, s, mf) + conn = self.server.ConnectionClass(self.server, s) if not isinstance(self.server.bind_addr, (str, bytes)): # optional values @@ -381,6 +351,43 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME return None raise + def _send_bad_request_plain_http_error(self, sock, addr): + """Send Bad Request 400 response, and close the socket.""" + self.server.error_log( + f'Client {addr!s} attempted to speak plain HTTP into ' + 'a TCP connection configured for TLS-only traffic — ' + 'Sending 400 Bad Request.', + ) + + msg = ( + 'The client sent a plain HTTP request, but this server ' + 'only speaks HTTPS on this port.' + ) + + response_parts = [ + f'{self.server.protocol} 400 Bad Request\r\n', + 'Content-Type: text/plain\r\n', + f'Content-Length: {len(msg)}\r\n', + 'Connection: close\r\n', + '\r\n', + msg, + ] + response_bytes = ''.join(response_parts).encode('ISO-8859-1') + + try: + # Handle both raw sockets and SSL connections + if hasattr(sock, 'sendall'): + sock.sendall(response_bytes) + else: + # Fallback for older PyOpenSSL or SSL objects + sock.send(response_bytes) + sock.shutdown(socket.SHUT_WR) + except OSError as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + + sock.close() + def close(self): """Close all monitored connections.""" for _, conn in self._selector.connections: diff --git a/cheroot/makefile.py b/cheroot/makefile.py index f5780a1ede..b1981d34d5 100644 --- a/cheroot/makefile.py +++ b/cheroot/makefile.py @@ -2,6 +2,7 @@ # prefer slower Python-based io module import _pyio as io +import io as stdlib_io import socket @@ -38,9 +39,16 @@ def _flush_unlocked(self): class StreamReader(io.BufferedReader): """Socket stream reader.""" - def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): - """Initialize socket stream reader.""" - super().__init__(socket.SocketIO(sock, mode), bufsize) + def __init__(self, sock, bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize with socket or raw IO object.""" + # If already a RawIOBase (like TLSSocket), use directly + if isinstance(sock, (io.RawIOBase, stdlib_io.RawIOBase)): + raw_io = sock + else: + # Wrap raw socket with SocketIO + raw_io = socket.SocketIO(sock, 'rb') + + super().__init__(raw_io, bufsize) self.bytes_read = 0 def read(self, *args, **kwargs): @@ -57,9 +65,16 @@ def has_data(self): class StreamWriter(BufferedWriter): """Socket stream writer.""" - def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE): - """Initialize socket stream writer.""" - super().__init__(socket.SocketIO(sock, mode), bufsize) + def __init__(self, sock, bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize with socket or raw IO object.""" + # If already a RawIOBase (like TLSSocket), use directly + if isinstance(sock, (io.RawIOBase, stdlib_io.RawIOBase)): + raw_io = sock + else: + # Wrap raw socket with SocketIO + raw_io = socket.SocketIO(sock, 'wb') + + super().__init__(raw_io, bufsize) self.bytes_written = 0 def write(self, val, *args, **kwargs): @@ -67,9 +82,3 @@ def write(self, val, *args, **kwargs): res = super().write(val, *args, **kwargs) self.bytes_written += len(val) return res - - -def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): - """File object attached to a socket object.""" - cls = StreamReader if 'r' in mode else StreamWriter - return cls(sock, mode, bufsize) diff --git a/cheroot/makefile.pyi b/cheroot/makefile.pyi index 3f5ea2756b..592ef2bf91 100644 --- a/cheroot/makefile.pyi +++ b/cheroot/makefile.pyi @@ -7,13 +7,11 @@ class BufferedWriter(io.BufferedWriter): class StreamReader(io.BufferedReader): bytes_read: int - def __init__(self, sock, mode: str = ..., bufsize=...) -> None: ... + def __init__(self, sock, bufsize=...) -> None: ... def read(self, *args, **kwargs): ... def has_data(self): ... class StreamWriter(BufferedWriter): bytes_written: int - def __init__(self, sock, mode: str = ..., bufsize=...) -> None: ... + def __init__(self, sock, bufsize=...) -> None: ... def write(self, val, *args, **kwargs): ... - -def MakeFile(sock, mode: str = ..., bufsize=...): ... diff --git a/cheroot/server.py b/cheroot/server.py index 9a288539c7..cb79316554 100644 --- a/cheroot/server.py +++ b/cheroot/server.py @@ -83,7 +83,7 @@ from . import __version__, connections, errors from ._compat import IS_PPC, bton -from .makefile import MakeFile, StreamWriter +from .makefile import StreamReader, StreamWriter from .workers import threadpool @@ -1275,19 +1275,18 @@ class HTTPConnection: # Fields set by ConnectionManager. last_used = None - def __init__(self, server, sock, makefile=MakeFile): + def __init__(self, server, sock): """Initialize HTTPConnection instance. Args: server (HTTPServer): web server object receiving this request sock (socket._socketobject): the raw socket object (usually TCP) for this connection - makefile (file): a fileobject class for reading from the socket """ self.server = server self.socket = sock - self.rfile = makefile(sock, 'rb', self.rbufsize) - self.wfile = makefile(sock, 'wb', self.wbufsize) + self.rfile = StreamReader(sock, self.rbufsize) + self.wfile = StreamWriter(sock, self.wbufsize) self.requests_seen = 0 self.peercreds_enabled = self.server.peercreds_enabled @@ -1363,7 +1362,7 @@ def _handle_no_ssl(self, req): except AttributeError: # self.socket is of OpenSSL.SSL.Connection type resp_sock = self.socket._socket - self.wfile = StreamWriter(resp_sock, 'wb', self.wbufsize) + self.wfile = StreamWriter(resp_sock, self.wbufsize) msg = ( 'The client sent a plain HTTP request, but ' 'this server only speaks HTTPS on this port.' diff --git a/cheroot/server.pyi b/cheroot/server.pyi index c5c0f517f6..569ec94225 100644 --- a/cheroot/server.pyi +++ b/cheroot/server.pyi @@ -112,7 +112,7 @@ class HTTPConnection: rfile: Any wfile: Any requests_seen: int - def __init__(self, server, sock, makefile=...) -> None: ... + def __init__(self, server, sock) -> None: ... def communicate(self): ... linger: bool def close(self) -> None: ... diff --git a/cheroot/ssl/__init__.py b/cheroot/ssl/__init__.py index c0072c0557..ad4ff40f5e 100644 --- a/cheroot/ssl/__init__.py +++ b/cheroot/ssl/__init__.py @@ -1,16 +1,162 @@ -"""Implementation of the SSL adapter base interface.""" +"""Implementation of the SSL adapter base interface. + +.. spelling:: + dn +""" from abc import ABC, abstractmethod +from contextlib import suppress + + +DN_SEPARATOR = '/' # Distinguished Name separator +UTF8_ENCODING = 'utf-8' + + +def _parse_dn_components(components, key_prefix, dn_type): + """ + Parse Distinguished Name components into environ dict. + + Args: + components: Iterable of (key, value) tuples + key_prefix: 'SSL_CLIENT' or 'SSL_SERVER' + dn_type: 'S' for subject or 'I' for issuer + + Returns: + dict: ``DN`` and ``CN`` environment variables + """ + env = {} + dn_parts = [] + + for key, attr_value in components: + dn_parts.append(f'{key}={attr_value}') + if key in {'CN', 'commonName'}: + env[f'{key_prefix}_{dn_type}_DN_CN'] = attr_value + + if dn_parts: + dn_string = DN_SEPARATOR.join(dn_parts) + env[f'{key_prefix}_{dn_type}_DN'] = f'{DN_SEPARATOR}{dn_string}' + + return env + + +def parse_pyopenssl_cert_to_environ(cert, key_prefix): + """Parse a pyOpenSSL X509 certificate into WSGI environ dict.""" + env = {} + if not cert: + return env + + # Subject + subject = cert.get_subject() + if subject: + components = [ + (key.decode(UTF8_ENCODING), attr_value.decode(UTF8_ENCODING)) + for key, attr_value in subject.get_components() + ] + env.update(_parse_dn_components(components, key_prefix, 'S')) + + # Issuer + issuer = cert.get_issuer() + if issuer: + components = [ + (key.decode(UTF8_ENCODING), attr_value.decode(UTF8_ENCODING)) + for key, attr_value in issuer.get_components() + ] + env.update(_parse_dn_components(components, key_prefix, 'I')) + + # Version and Serial + env[f'{key_prefix}_M_VERSION'] = str(cert.get_version()) + env[f'{key_prefix}_M_SERIAL'] = str(cert.get_serial_number()) + + return env + + +def parse_x509_cert_to_environ(cert, key_prefix): + """Parse a cryptography x509 certificate into environ dict.""" + env = {} + + # Subject + with suppress(Exception): + subject = cert.subject + components = [(attr.oid._name, attr.value) for attr in subject] + env.update(_parse_dn_components(components, key_prefix, 'S')) + + # Issuer + with suppress(Exception): + issuer = cert.issuer + components = [(attr.oid._name, attr.value) for attr in issuer] + env.update(_parse_dn_components(components, key_prefix, 'I')) + + # Version and Serial + with suppress(Exception): + env[f'{key_prefix}_M_VERSION'] = str(cert.version.value) + env[f'{key_prefix}_M_SERIAL'] = str(cert.serial_number) + + return env + + +class SSLEnvironMixin: + """ + Mixin class providing methods for generating WSGI environment variables. + + This mixin handles GENERIC SSL environment variable generation that works + across all SSL implementations. Adapter-specific logic (like certificate + parsing) is delegated to subclass implementations. + """ + + def _get_core_tls_environ(self, conn): + """ + Add core TLS version and cipher info to the environment. + + This is generic and works for all SSL adapters since TLSSocket + provides a uniform get_cipher_info() interface. + """ + cipher_info = conn.get_cipher_info() + + # Early exit if no cipher info (not a secure connection) + if cipher_info is None: + return {'wsgi.url_scheme': 'http'} + + cipher_name, protocol, cipher_keysize = cipher_info + + return { + 'wsgi.url_scheme': 'https', + 'HTTPS': 'on', + 'SSL_PROTOCOL': protocol, + 'SSL_CIPHER': cipher_name, + 'SSL_CIPHER_EXPORT': '', + 'SSL_CIPHER_USEKEYSIZE': cipher_keysize, + 'SSL_CLIENT_VERIFY': 'NONE', + } + + def _get_server_cert_environ(self): + """ + Get server certificate info from the connection. + + MUST be overridden by subclasses to provide adapter-specific parsing. + Returns dict of SSL_SERVER_* environ variables. + + Default implementation returns empty dict. + """ + return {} + + def _get_client_cert_environ(self, conn, ssl_environ): + """ + Add client certificate details to the environment. + + SHOULD be overridden by subclasses for adapter-specific handling. + Default implementation does nothing. + """ + return ssl_environ -class Adapter(ABC): +class Adapter(SSLEnvironMixin, ABC): """Base class for SSL driver library adapters. Required methods: * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` - * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> - socket file object`` + * ``_get_library_version_environ() -> dict`` + * ``_get_optional_environ(conn) -> dict`` """ @abstractmethod @@ -39,14 +185,75 @@ def bind(self, sock): @abstractmethod def wrap(self, sock): """Wrap and return the given socket, plus WSGI environ entries.""" - raise NotImplementedError # pragma: no cover + raise NotImplementedError + + def get_environ(self, conn): + """ + Return WSGI environ entries to be merged into each request. + + Unified implementation used by all subclasses. This orchestrates + the collection of SSL environment variables from various sources: + - Core TLS info (protocol, cipher) + - Library versions + - Optional fields (SNI, etc.) + - Session info + - Client certificate + - Server certificate + + Note: This returns only SSL-specific variables. General server + variables (``SERVER_NAME``, ``SERVER_PORT``, etc.) are added by + the Gateway when building the complete WSGI environ for each request. + """ + # 1. Handle basic TLS info + ssl_environ = self._get_core_tls_environ(conn) + if 'HTTPS' not in ssl_environ: + # Core TLS failed (returned 'http' env) + return ssl_environ + + # 2. Update with library-specific version strings + ssl_environ.update(self._get_library_version_environ()) + + # 3. Handle optional/platform-specific fields (SNI, compression) + ssl_environ.update(self._get_optional_environ(conn)) + + # 4. Handle Session ID + with suppress(AttributeError): + session = conn.get_session() + if session and hasattr(session, 'id'): + ssl_environ['SSL_SESSION_ID'] = session.id.hex() + + # 5. Handle Client certificate (adapter-specific) + ssl_environ = self._get_client_cert_environ(conn, ssl_environ) + + # 6. Server certificate (adapter-specific) + server_cert_info = self._get_server_cert_environ() + if server_cert_info: + ssl_environ.update(server_cert_info) + + return ssl_environ @abstractmethod - def get_environ(self): - """Return WSGI environ entries to be merged into each request.""" - raise NotImplementedError # pragma: no cover + def _get_library_version_environ(self): + """ + Get SSL library version information. + + Must be implemented by subclasses to provide adapter-specific + version strings. + + Returns: + dict: SSL_VERSION_INTERFACE and SSL_VERSION_LIBRARY + """ + raise NotImplementedError @abstractmethod - def makefile(self, sock, mode='r', bufsize=-1): - """Return socket file object.""" - raise NotImplementedError # pragma: no cover + def _get_optional_environ(self, conn): + """ + Get optional environment variables. + + Must be implemented by subclasses for adapter-specific handling + of optional fields like SNI, compression, etc. + + Returns: + dict: Optional SSL environment variables + """ + raise NotImplementedError diff --git a/cheroot/ssl/__init__.pyi b/cheroot/ssl/__init__.pyi index c595121546..2582632314 100644 --- a/cheroot/ssl/__init__.pyi +++ b/cheroot/ssl/__init__.pyi @@ -1,6 +1,30 @@ from abc import ABC, abstractmethod from typing import Any +DN_SEPARATOR: str +UTF8_ENCODING: str + +def parse_pyopenssl_cert_to_environ( + cert: Any, + key_prefix: str, +) -> dict[str, Any]: ... +def parse_x509_cert_to_environ( + cert: Any, + key_prefix: str, +) -> dict[str, Any]: ... + +class SSLEnvironMixin: + def _get_core_tls_environ( + self, + conn: Any, + ) -> dict[str, Any]: ... + def _get_server_cert_environ(self) -> dict[str, Any]: ... + def _get_client_cert_environ( + self, + conn: Any, + ssl_environ: dict[str, Any], + ) -> dict[str, Any]: ... + class Adapter(ABC): certificate: Any private_key: Any @@ -23,6 +47,4 @@ class Adapter(ABC): @abstractmethod def wrap(self, sock): ... @abstractmethod - def get_environ(self): ... - @abstractmethod - def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... + def get_environ(self, conn) -> dict[str, Any]: ... diff --git a/cheroot/ssl/builtin.py b/cheroot/ssl/builtin.py index ed747ab6e3..f0ef58e66c 100644 --- a/cheroot/ssl/builtin.py +++ b/cheroot/ssl/builtin.py @@ -9,9 +9,11 @@ import socket import sys -import threading from contextlib import suppress +from . import Adapter +from .tls_socket import TLSSocket + try: import ssl @@ -19,148 +21,24 @@ ssl = None try: - from _pyio import DEFAULT_BUFFER_SIZE + from cryptography import x509 + from cryptography.hazmat.backends import default_backend except ImportError: - try: - from io import DEFAULT_BUFFER_SIZE - except ImportError: - DEFAULT_BUFFER_SIZE = -1 - -from .. import errors -from ..makefile import StreamReader, StreamWriter -from ..server import HTTPServer -from . import Adapter + x509 = None + default_backend = None -def _assert_ssl_exc_contains(exc, *msgs): - """Check whether SSL exception contains either of messages provided.""" - if len(msgs) < 1: - raise TypeError( - '_assert_ssl_exc_contains() requires ' - 'at least one message to be passed.', - ) - err_msg_lower = str(exc).lower() - return any(m.lower() in err_msg_lower for m in msgs) - - -def _loopback_for_cert_thread(context, server): - """Wrap a socket in ssl and perform the server-side handshake.""" - # As we only care about parsing the certificate, the failure of - # which will cause an exception in ``_loopback_for_cert``, - # we can safely ignore connection and ssl related exceptions. Ref: - # https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030 - with suppress(ssl.SSLError, OSError): - with context.wrap_socket( - server, - do_handshake_on_connect=True, - server_side=True, - ) as ssl_sock: - # in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server - # sends the client session tickets that can be used to - # resume the TLS session on a new connection without - # performing the full handshake again. session tickets are - # sent as a post-handshake message at some _unspecified_ - # time and thus a successful connection may be closed - # without the client having received the tickets. - # Unfortunately, on Windows (Python 3.8+), this is treated - # as an incomplete handshake on the server side and a - # ``ConnectionAbortedError`` is raised. - # TLS 1.3 support is still incomplete in Python 3.8; - # there is no way for the client to wait for tickets. - # While not necessary for retrieving the parsed certificate, - # we send a tiny bit of data over the connection in an - # attempt to give the server a chance to send the session - # tickets and close the connection cleanly. - # Note that, as this is essentially a race condition, - # the error may still occur ocasionally. - ssl_sock.send(b'0000') - - -def _loopback_for_cert( - certificate, - private_key, - certificate_chain, - *, - private_key_password=None, -): - """Create a loopback connection to parse a cert with a private key.""" - context = ssl.create_default_context(cafile=certificate_chain) - context.load_cert_chain( - certificate, - private_key, - password=private_key_password, - ) - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - - # Python 3+ Unix, Python 3.5+ Windows - client, server = socket.socketpair() - try: - # `wrap_socket` will block until the ssl handshake is complete. - # it must be called on both ends at the same time -> thread - # openssl will cache the peer's cert during a successful handshake - # and return it via `getpeercert` even after the socket is closed. - # when `close` is called, the SSL shutdown notice will be sent - # and then python will wait to receive the corollary shutdown. - thread = threading.Thread( - target=_loopback_for_cert_thread, - args=(context, server), - ) - try: - thread.start() - with context.wrap_socket( - client, - do_handshake_on_connect=True, - server_side=False, - ) as ssl_sock: - ssl_sock.recv(4) - return ssl_sock.getpeercert() - finally: - thread.join() - finally: - client.close() - server.close() - - -def _parse_cert( - certificate, - private_key, - certificate_chain, - *, - private_key_password=None, -): - """Parse a certificate.""" - # loopback_for_cert uses socket.socketpair which was only - # introduced in Python 3.0 for *nix and 3.5 for Windows - # and requires OS support (AttributeError, OSError) - # it also requires a private key either in its own file - # or combined with the cert (SSLError) - with suppress(AttributeError, ssl.SSLError, OSError): - return _loopback_for_cert( - certificate, - private_key, - certificate_chain, - private_key_password=private_key_password, - ) - - # KLUDGE: using an undocumented, private, test method to parse a cert - # unfortunately, it is the only built-in way without a connection - # as a private, undocumented method, it may change at any time - # so be tolerant of *any* possible errors it may raise - with suppress(Exception): - return ssl._ssl._test_decode_cert(certificate) - - return {} - - -def _sni_callback(sock, sni, context): - """Handle the SNI callback to tag the socket with the SNI.""" - sock.sni = sni - # return None to allow the TLS negotiation to continue +from .. import errors +from . import parse_x509_cert_to_environ class BuiltinSSLAdapter(Adapter): - """Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot.""" + """ + Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot. + + This adapter uses TLSSocket internally to provide a consistent + interface for SSL/TLS connections. + """ certificate = None """The file name of the server SSL certificate.""" @@ -235,11 +113,13 @@ def __init__( *, private_key_password=None, ): - """Set up context in addition to base class properties if available.""" + """Initialize builtin SSL Adapter instance.""" if ssl is None: - raise ImportError('You must install the ssl module to use HTTPS.') + raise ImportError( + 'You must have ssl module available to use HTTPS.', + ) - super(BuiltinSSLAdapter, self).__init__( + super().__init__( certificate, private_key, certificate_chain, @@ -247,162 +127,378 @@ def __init__( private_key_password=private_key_password, ) - self.context = ssl.create_default_context( - purpose=ssl.Purpose.CLIENT_AUTH, - cafile=certificate_chain, - ) - self.context.load_cert_chain( - certificate, - private_key, - password=private_key_password, - ) - if self.ciphers is not None: - self.context.set_ciphers(ciphers) - - self._server_env = self._make_env_cert_dict( - 'SSL_SERVER', - _parse_cert( - certificate, - private_key, - self.certificate_chain, - private_key_password=private_key_password, - ), - ) - if not self._server_env: - return - cert = None - with open(certificate) as f: - cert = f.read() - - # strip off any keys by only taking the first certificate - cert_start = cert.find(ssl.PEM_HEADER) - if cert_start == -1: - return - cert_end = cert.find(ssl.PEM_FOOTER, cert_start) - if cert_end == -1: - return - cert_end += len(ssl.PEM_FOOTER) - self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end] + self._context = None + self._context = self._create_context() + + def bind(self, sock): + """Prepare the server socket.""" + return sock # Context already created + + def wrap(self, sock): + """ + Wrap client socket with SSL and return environ entries. + + Args: + sock: Raw socket to wrap with TLS + + Returns: + tuple: (TLSSocket, ssl_environ_dict) + """ + if self._check_for_plain_http(sock): + raise errors.NoSSLError + + tls_socket = self._wrap_with_builtin(sock) + ssl_environ = self.get_environ(tls_socket) + return tls_socket, ssl_environ + + def _wrap_with_builtin(self, raw_socket, server_side=True): + """ + Create a TLSSocket using Python's built-in ssl module. + + Args: + raw_socket: The raw socket to wrap + server_side: True if this is the server side + + Returns: + TLSSocket: Wrapped socket ready for secure I/O + """ + try: + wrapped_ssl_socket = self._create_ssl_socket( + raw_socket, + server_side, + ) + self._perform_handshake(wrapped_ssl_socket, raw_socket) + + underlying_socket = wrapped_ssl_socket + if hasattr(wrapped_ssl_socket, '_sock'): + underlying_socket = wrapped_ssl_socket._sock + + return TLSSocket( + ssl_socket=wrapped_ssl_socket, + raw_socket=underlying_socket, + context=self.context, + ) + except errors.NoSSLError: + # Plain HTTP detected, let it propagate for proper error response + raise + except TimeoutError: + # Handshake timeout, let it propagate + with suppress(Exception): + raw_socket.close() + raise + except (ssl.SSLError, OSError) as e: + # SSL handshake or socket error - clean up and raise + with suppress(Exception): + raw_socket.close() + raise errors.FatalSSLAlert(f'SSL wrapping failed: {e}') from e + + def _check_for_plain_http(self, raw_socket): + """Check if the client sent plain HTTP by peeking at first bytes. + + This is a best-effort check to provide a helpful error message when + clients accidentally use HTTP on an HTTPS port. If we can't detect + plain HTTP (timeout, no data yet, etc), we return False and let the + SSL handshake proceed, which will fail with its own error. + + Returns: + bool: True if plain HTTP is detected, False otherwise + """ + PEEK_BYTES = 16 + PEEK_TIMEOUT = 0.5 + + original_timeout = raw_socket.gettimeout() + try: + raw_socket.settimeout(PEEK_TIMEOUT) + first_bytes = raw_socket.recv(PEEK_BYTES, socket.MSG_PEEK) + + if not first_bytes: + return False + + http_methods = ( + b'GET ', + b'POST ', + b'PUT ', + b'DELETE ', + b'HEAD ', + b'OPTIONS ', + b'PATCH ', + b'CONNECT ', + b'TRACE ', + ) + return any( + first_bytes.startswith(method) for method in http_methods + ) + except (OSError, socket.timeout): + return False + finally: + raw_socket.settimeout(original_timeout) + + def _create_ssl_socket(self, raw_socket, server_side): + """Create SSL socket without handshake.""" + try: + # Manual handshake for error handling + return self.context.wrap_socket( + raw_socket, + do_handshake_on_connect=False, + server_side=server_side, + ) + except ssl.SSLError as e: + raise errors.FatalSSLAlert( + f'Error creating SSL socket: {e}', + ) from e + + def _perform_handshake(self, ssl_socket, raw_socket): + """Perform SSL handshake with error handling and retries.""" + HANDSHAKE_TIMEOUT = 5.0 + + # Set timeout on the SSL socket for the handshake + original_timeout = ssl_socket.gettimeout() + ssl_socket.settimeout(HANDSHAKE_TIMEOUT) + + try: + while True: + try: # noqa: WPS225 + ssl_socket.do_handshake() + return + except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: + direction = ( + 'read' + if isinstance(e, ssl.SSLWantReadError) + else 'write' + ) + self._wait_for_handshake_data(raw_socket, direction) + except socket.timeout as e: + raise errors.NoSSLError( + 'SSL handshake timeout.', + ) from e + except ssl.SSLEOFError as e: + raise errors.NoSSLError( + 'Peer closed connection during handshake.', + ) from e + except ssl.SSLError as e: + self._handle_ssl_error(e) + except OSError as e: + raise errors.FatalSSLAlert( + f'TCP error during handshake: {e}', + ) from e + finally: + # Restore original timeout + with suppress(Exception): + ssl_socket.settimeout(original_timeout) + + def _wait_for_handshake_data(self, raw_socket, direction): + """Wait for socket to be ready for read or write during handshake.""" + import select + + HANDSHAKE_TIMEOUT = 5.0 + fileno = raw_socket.fileno() + + if direction == 'read': + ready = select.select([fileno], [], [], HANDSHAKE_TIMEOUT)[0] + else: # write + ready = select.select([], [fileno], [], HANDSHAKE_TIMEOUT)[1] + + if not ready: + raise TimeoutError( + f'Handshake failed: Peer did not send expected data ({direction}).', + ) + + def _handle_ssl_error(self, error): + """Handle SSL errors during handshake.""" + err_str = str(error).lower() + + # Check for common patterns indicating plain HTTP + if any( + pattern in err_str + for pattern in ( + 'wrong version number', + 'http request', + 'unknown protocol', + ) + ): + raise errors.NoSSLError( + 'Client sent plain HTTP request', + ) from error + + raise errors.FatalSSLAlert( + f'Fatal SSL error during handshake: {error}', + ) from error @property def context(self): - """:py:class:`~ssl.SSLContext` that will be used to wrap sockets.""" + """Get the SSL context.""" return self._context @context.setter - def context(self, context): - """Set the ssl ``context`` to use.""" - self._context = context - # Python 3.7+ - # if a context is provided via `cherrypy.config.update` then - # `self.context` will be set after `__init__` - # use a property to intercept it to add an SNI callback - # but don't override the user's callback - # TODO: chain callbacks - with suppress(AttributeError): - if ssl.HAS_SNI and context.sni_callback is None: - context.sni_callback = _sni_callback + def context(self, value): + """Set the SSL context (for testing).""" + self._context = value + + def _create_context(self): + """Return an py:class:`ssl.SSLContext` from self attributes.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + + # Only attempt to load the key/cert chain if a certificate + # path is available. + if self.certificate: + try: + ctx.load_cert_chain( + self.certificate, + self.private_key, + self.private_key_password, + ) + except FileNotFoundError as file_err: + raise FileNotFoundError( + f'SSL certificate file not found: {file_err}', + ) from file_err - def bind(self, sock): - """Wrap and return the given socket.""" - return super(BuiltinSSLAdapter, self).bind(sock) + # Load CA Trust/Verification Chain + # (Optional, independent of server cert) + # This is needed for verifying client certificates + # or connecting to other services. + if self.certificate_chain: + ctx.load_verify_locations(cafile=self.certificate_chain) + + # Set Ciphers (Optional, independent of server cert) + if self.ciphers: + ctx.set_ciphers(self.ciphers) + + return ctx + + # ======================================================================== + # Adapter-specific environment variable methods + # ======================================================================== + + def _get_library_version_environ(self): + """ + Get SSL library version information. + + Overrides base class method to provide builtin ssl module version. + """ + python_version = sys.version.split()[0] + return { + 'SSL_VERSION_INTERFACE': f'Python/{python_version} {ssl.OPENSSL_VERSION}', + 'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION, + } + + def _get_optional_environ(self, conn): + """ + Get optional environment variables. + + Overrides base class method for builtin ssl-specific handling. + """ + environ = {} + + # Compression (note: most modern OpenSSL builds disable compression) + try: + compression = conn.compression() + if compression: + environ['SSL_COMPRESS_METHOD'] = compression + except AttributeError: + # TLSSocket might not have compression method + ... + + # SNI (Server Name Indication) if available + try: + server_hostname = conn.server_hostname + if server_hostname: + environ['SSL_TLS_SNI'] = server_hostname + except AttributeError: + ... + + return environ + + def _get_server_cert_environ(self): + """Get server certificate info using builtin ssl certificate parsing.""" + if not self.certificate: + return {} + + # Check if cryptography is available + if x509 is None: + return {} - def wrap(self, sock): - """Wrap and return the given socket, plus WSGI environ entries.""" try: - s = self.context.wrap_socket( - sock, - do_handshake_on_connect=True, - server_side=True, + with open(self.certificate, 'rb') as cert_file: + cert_data = cert_file.read() + + cert = x509.load_pem_x509_certificate( + cert_data, + default_backend(), ) - except ( - ssl.SSLEOFError, - ssl.SSLZeroReturnError, - ) as tls_connection_drop_error: - raise errors.FatalSSLAlert( - *tls_connection_drop_error.args, - ) from tls_connection_drop_error - except ssl.SSLError as generic_tls_error: - peer_speaks_plain_http_over_https = ( - generic_tls_error.errno == ssl.SSL_ERROR_SSL - and _assert_ssl_exc_contains(generic_tls_error, 'http request') + + return parse_x509_cert_to_environ(cert, 'SSL_SERVER') + + except Exception: + return {} + + def _get_client_cert_environ(self, conn, ssl_environ): + """Populate the WSGI environment with client certificate details.""" + # 1. Access the raw ssl.SSLSocket object + try: + # 'conn' is the TLSSocket wrapper; '_sock' is the raw ssl.SSLSocket + raw_ssl_socket = conn._sock + except AttributeError: + # If the socket is missing or already closed + ssl_environ['SSL_CLIENT_VERIFY'] = 'NONE' + return ssl_environ + + # 2. Get the peer certificate details + try: + # getpeercert() returns a dict if a cert was presented, + # None otherwise. + peer_cert_details = raw_ssl_socket.getpeercert(binary_form=False) + # Also get the binary (DER) form to convert to PEM + peer_cert_binary = raw_ssl_socket.getpeercert(binary_form=True) + except ssl.SSLError: + # This occurs if verification failed during the handshake + ssl_environ['SSL_CLIENT_VERIFY'] = 'FAILURE' + return ssl_environ + except Exception: + # Catch any other socket errors + ssl_environ['SSL_CLIENT_VERIFY'] = 'NONE' + return ssl_environ + + # --- Check Verification Status --- + if peer_cert_details: + ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' + else: + # No cert presented + ssl_environ['SSL_CLIENT_VERIFY'] = 'NONE' + return ssl_environ + + # --- Add the PEM-encoded certificate --- + if peer_cert_binary: + ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert( + peer_cert_binary, ) - if peer_speaks_plain_http_over_https: - reraised_connection_drop_exc_cls = errors.NoSSLError - else: - reraised_connection_drop_exc_cls = errors.FatalSSLAlert - raise reraised_connection_drop_exc_cls( - *generic_tls_error.args, - ) from generic_tls_error - except OSError as tcp_connection_drop_error: - raise errors.FatalSSLAlert( - *tcp_connection_drop_error.args, - ) from tcp_connection_drop_error - - return s, self.get_environ(s) - - def get_environ(self, sock): - """Create WSGI environ entries to be merged into each request.""" - cipher = sock.cipher() - ssl_environ = { - 'wsgi.url_scheme': 'https', - 'HTTPS': 'on', - 'SSL_PROTOCOL': cipher[1], - 'SSL_CIPHER': cipher[0], - 'SSL_CIPHER_EXPORT': '', - 'SSL_CIPHER_USEKEYSIZE': cipher[2], - 'SSL_VERSION_INTERFACE': '%s Python/%s' - % ( - HTTPServer.version, - sys.version, - ), - 'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION, - 'SSL_CLIENT_VERIFY': 'NONE', - # 'NONE' - client did not provide a cert (overriden below) - } + # --- Populate Metadata using existing utility methods --- - # Python 3.3+ - with suppress(AttributeError): - compression = sock.compression() - if compression is not None: - ssl_environ['SSL_COMPRESS_METHOD'] = compression - - # Python 3.6+ - with suppress(AttributeError): - ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex() - with suppress(AttributeError): - target_cipher = cipher[:2] - for cip in sock.context.get_ciphers(): - if target_cipher == (cip['name'], cip['protocol']): - ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits'] - break - - # Python 3.7+ sni_callback - with suppress(AttributeError): - ssl_environ['SSL_TLS_SNI'] = sock.sni - - if self.context and self.context.verify_mode != ssl.CERT_NONE: - client_cert = sock.getpeercert() - if client_cert: - # builtin ssl **ALWAYS** validates client certificates - # and terminates the connection on failure - ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' - ssl_environ.update( - self._make_env_cert_dict('SSL_CLIENT', client_cert), - ) - ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert( - sock.getpeercert(binary_form=True), - ).strip() + # 3. Populate Subject DN + subject_dn_nested = peer_cert_details.get('subject', []) + subject_env = self._make_env_dn_dict( + env_prefix='SSL_CLIENT_S_DN', + cert_value=subject_dn_nested, + ) + ssl_environ.update(subject_env) - ssl_environ.update(self._server_env) + # 4. Populate Issuer DN + issuer_dn_nested = peer_cert_details.get('issuer', []) + issuer_env = self._make_env_dn_dict( + env_prefix='SSL_CLIENT_I_DN', + cert_value=issuer_dn_nested, + ) + ssl_environ.update(issuer_env) - # not supplied by the Python standard library (as of 3.8) - # - SSL_SESSION_RESUMED - # - SSL_SECURE_RENEG - # - SSL_CLIENT_CERT_CHAIN_n - # - SRP_USER - # - SRP_USERINFO + # 5. Populate other cert details + ssl_environ['SSL_CLIENT_M_VERSION'] = str( + peer_cert_details.get('version', ''), + ) + ssl_environ['SSL_CLIENT_M_SERIAL'] = str( + peer_cert_details.get('serialNumber', ''), + ) return ssl_environ @@ -482,8 +578,10 @@ def _make_env_dn_dict(self, env_prefix, cert_value): dn_attrs.setdefault(attr_code, []) dn_attrs[attr_code].append(val) + dn_string = '/'.join(dn) + env = { - env_prefix: ','.join(dn), + env_prefix: '/%s' % (dn_string,), } for attr_code, values in dn_attrs.items(): env['%s_%s' % (env_prefix, attr_code)] = ','.join(values) @@ -492,8 +590,3 @@ def _make_env_dn_dict(self, env_prefix, cert_value): for i, val in enumerate(values): env['%s_%s_%i' % (env_prefix, attr_code, i)] = val return env - - def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): - """Return socket file object.""" - cls = StreamReader if 'r' in mode else StreamWriter - return cls(sock, mode, bufsize) diff --git a/cheroot/ssl/builtin.pyi b/cheroot/ssl/builtin.pyi index b05aaf5ad7..cb5f508d26 100644 --- a/cheroot/ssl/builtin.pyi +++ b/cheroot/ssl/builtin.pyi @@ -2,8 +2,6 @@ from typing import Any from . import Adapter -DEFAULT_BUFFER_SIZE: int - class BuiltinSSLAdapter(Adapter): CERT_KEY_TO_ENV: Any CERT_KEY_TO_LDAP_CODE: Any @@ -22,5 +20,4 @@ class BuiltinSSLAdapter(Adapter): def context(self, context) -> None: ... def bind(self, sock): ... def wrap(self, sock): ... - def get_environ(self, sock): ... - def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... + def get_environ(self, conn) -> dict: ... diff --git a/cheroot/ssl/pyopenssl.py b/cheroot/ssl/pyopenssl.py index 8a041a79b8..61cc1dd7c1 100644 --- a/cheroot/ssl/pyopenssl.py +++ b/cheroot/ssl/pyopenssl.py @@ -13,8 +13,7 @@ Method One ---------- - * :py:attr:`ssl_adapter.context - `: an instance of + * ``ssl_adapter.context``: an instance of :py:class:`SSL.Context `. If this is not None, it is assumed to be an :py:class:`SSL.Context @@ -40,22 +39,24 @@ `: the file name of the server's private key file. -Both are :py:data:`None` by default. If :py:attr:`ssl_adapter.context -` is :py:data:`None`, -but ``.private_key`` and ``.certificate`` are both given and valid, they -will be read, and the context will be automatically created from them. +Both are :py:data:`None` by default. If ``ssl_adapter.context`` +is :py:data:`None`, but ``.private_key`` and ``.certificate`` are both +given and valid, they will be read, and the context will be automatically +created from them. .. spelling:: pyopenssl """ -import socket +import select import sys -import threading -import time +from contextlib import suppress from warnings import warn as _warn +from . import Adapter, parse_pyopenssl_cert_to_environ +from .tls_socket import TLSSocket + try: import OpenSSL.version @@ -67,211 +68,21 @@ ssl_conn_type = SSL.ConnectionType except ImportError: SSL = None + crypto = None -import contextlib - -from .. import ( - errors, - server as cheroot_server, -) +from .. import errors from ..makefile import StreamReader, StreamWriter -from . import Adapter - - -class SSLFileobjectMixin: - """Base mixin for a TLS socket stream.""" - - ssl_timeout = 3 - ssl_retry = 0.01 - - # FIXME: - def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 - """Wrap the given call with TLS error-trapping. - - is_reader: if False EOF errors will be raised. If True, EOF errors - will return "" (to emulate normal sockets). - """ - start = time.time() - while True: - try: - return call(*args, **kwargs) - except SSL.WantReadError: - # Sleep and try again. This is dangerous, because it means - # the rest of the stack has no way of differentiating - # between a "new handshake" error and "client dropped". - # Note this isn't an endless loop: there's a timeout below. - # Ref: https://stackoverflow.com/a/5133568/595220 - time.sleep(self.ssl_retry) - except SSL.WantWriteError: - time.sleep(self.ssl_retry) - except SSL.SysCallError as e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return b'' - - errnum = e.args[0] - if is_reader and errnum in errors.socket_errors_to_ignore: - return b'' - raise socket.error(errnum) - except SSL.Error as e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return b'' - thirdarg = None - with contextlib.suppress(IndexError): - thirdarg = e.args[0][0][2] - - if thirdarg == 'http request': - # The client is talking HTTP to an HTTPS server. - raise errors.NoSSLError - - raise errors.FatalSSLAlert(*e.args) - - if time.time() - start > self.ssl_timeout: - raise socket.timeout('timed out') - - def recv(self, size): - """Receive message of a size from the socket.""" - return self._safe_call( - True, - super(SSLFileobjectMixin, self).recv, - size, - ) - - def readline(self, size=-1): - """Receive message of a size from the socket. - - Matches the following interface: - https://docs.python.org/3/library/io.html#io.IOBase.readline - """ - return self._safe_call( - True, - super(SSLFileobjectMixin, self).readline, - size, - ) - - def sendall(self, *args, **kwargs): - """Send whole message to the socket.""" - return self._safe_call( - False, - super(SSLFileobjectMixin, self).sendall, - *args, - **kwargs, - ) - def send(self, *args, **kwargs): - """Send some part of message to the socket.""" - return self._safe_call( - False, - super(SSLFileobjectMixin, self).send, - *args, - **kwargs, - ) - - -class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): +class SSLFileobjectStreamReader(StreamReader): """SSL file object attached to a socket object.""" -class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): +class SSLFileobjectStreamWriter(StreamWriter): """SSL file object attached to a socket object.""" -class SSLConnectionProxyMeta: - """Metaclass for generating a bunch of proxy methods.""" - - def __new__(mcl, name, bases, nmspc): - """Attach a list of proxy methods to a new class.""" - proxy_methods = ( - 'get_context', - 'pending', - 'send', - 'write', - 'recv', - 'read', - 'renegotiate', - 'bind', - 'listen', - 'connect', - 'accept', - 'setblocking', - 'fileno', - 'close', - 'get_cipher_list', - 'getpeername', - 'getsockname', - 'getsockopt', - 'setsockopt', - 'makefile', - 'get_app_data', - 'set_app_data', - 'state_string', - 'sock_shutdown', - 'get_peer_certificate', - 'want_read', - 'want_write', - 'set_connect_state', - 'set_accept_state', - 'connect_ex', - 'sendall', - 'settimeout', - 'gettimeout', - 'shutdown', - ) - proxy_methods_no_args = ('shutdown',) - - proxy_props = ('family',) - - def lock_decorator(method): - """Create a proxy method for a new class.""" - - def proxy_wrapper(self, *args): - self._lock.acquire() - try: - new_args = ( - args[:] if method not in proxy_methods_no_args else [] - ) - return getattr(self._ssl_conn, method)(*new_args) - finally: - self._lock.release() - - return proxy_wrapper - - for m in proxy_methods: - nmspc[m] = lock_decorator(m) - nmspc[m].__name__ = m - - def make_property(property_): - """Create a proxy method for a new class.""" - - def proxy_prop_wrapper(self): - return getattr(self._ssl_conn, property_) - - proxy_prop_wrapper.__name__ = property_ - return property(proxy_prop_wrapper) - - for p in proxy_props: - nmspc[p] = make_property(p) - - # Doesn't work via super() for some reason. - # Falling back to type() instead: - return type(name, bases, nmspc) - - -class SSLConnection(metaclass=SSLConnectionProxyMeta): - r"""A thread-safe wrapper for an ``SSL.Connection``. - - :param tuple args: the arguments to create the wrapped \ - :py:class:`SSL.Connection(*args) \ - ` - """ - - def __init__(self, *args): - """Initialize SSLConnection instance.""" - self._ssl_conn = SSL.Connection(*args) - self._lock = threading.RLock() - - -class pyOpenSSLAdapter(Adapter): +class pyOpenSSLAdapter(Adapter): # noqa: WPS214 """A wrapper for integrating :doc:`pyOpenSSL `.""" certificate = None @@ -286,11 +97,6 @@ class pyOpenSSLAdapter(Adapter): This is needed for cheaper "chained root" TLS certificates, and should be left as :py:data:`None` if not required.""" - context = None - """ - An instance of :py:class:`SSL.Context `. - """ - ciphers = None """The ciphers list of TLS.""" @@ -310,30 +116,150 @@ def __init__( if SSL is None: raise ImportError('You must install pyOpenSSL to use HTTPS.') - super(pyOpenSSLAdapter, self).__init__( + super().__init__( certificate, private_key, certificate_chain, ciphers, private_key_password=private_key_password, ) - self._environ = None + self.context = None def bind(self, sock): - """Wrap and return the given socket.""" - if self.context is None: - self.context = self.get_context() - conn = SSLConnection(self.context, sock) - self._environ = self.get_environ() - return conn + """ + Prepare the server socket. + + Ensures that the SSL context object is created + and fully configured. For Method One the caller + supplies the context at ``__init()__`` but for + Method Two we construct from certificate files. + """ + if self._context is None: + # Method Two + _ = self.context # triggers initialization via property + return sock def wrap(self, sock): - """Wrap and return the given socket, plus WSGI environ entries.""" - # pyOpenSSL doesn't perform the handshake until the first read/write - # forcing the handshake to complete tends to result in the connection - # closing so we can't reliably access protocol/client cert for the env - return sock, self._environ.copy() + """Wrap client socket with SSL and return environ entries.""" + tls_socket = self._wrap_with_pyopenssl(sock) + ssl_environ = self.get_environ(tls_socket) + return tls_socket, ssl_environ + + def _wrap_with_pyopenssl(self, raw_socket, server_side=True): + """Create a TLSSocket wrapping a PyOpenSSL connection.""" + pyopenssl_ssl_object = self._create_pyopenssl_connection(raw_socket) + self._configure_connection_state(pyopenssl_ssl_object, server_side) + self._perform_handshake(pyopenssl_ssl_object, raw_socket) + + # lgtm[py/insecure-protocol] + return TLSSocket( + ssl_socket=pyopenssl_ssl_object, + raw_socket=raw_socket, + context=self.context, + ) + + def _create_pyopenssl_connection(self, raw_socket): + """Create PyOpenSSL connection object.""" + try: + ssl_object = ssl_conn_type(self.context, raw_socket) + except SSL.Error as e: + raise errors.FatalSSLAlert( + f'Error creating pyOpenSSL connection: {e}', + ) from e + + ssl_object.setblocking(True) + return ssl_object + + def _configure_connection_state(self, ssl_object, server_side): + """Set connection to server or client mode.""" + if server_side: + ssl_object.set_accept_state() + else: + ssl_object.set_connect_state() + + def _perform_handshake(self, ssl_object, raw_socket): + """Perform SSL handshake with error handling.""" + while True: + try: + ssl_object.do_handshake() + return + except SSL.WantReadError: + self._wait_for_handshake_data(raw_socket) + except SSL.ZeroReturnError as e: + raise errors.NoSSLError( + 'Peer closed connection during handshake.', + ) from e + except SSL.Error as e: + self._handle_ssl_error(e) + except OSError as e: + raise errors.FatalSSLAlert( + f'TCP error during handshake: {e}', + ) from e + + def _wait_for_handshake_data(self, raw_socket): + """Wait for peer to send data during handshake.""" + HANDSHAKE_TIMEOUT = 5.0 + fileno = raw_socket.fileno() + ready_to_read, _, _ = select.select( + [fileno], + [], + [], + HANDSHAKE_TIMEOUT, + ) + + if not ready_to_read: + raise TimeoutError( + 'Handshake failed: Peer did not send expected data.', + ) + + def _handle_ssl_error(self, error): + """Handle SSL errors during handshake.""" + err_str = str(error) + if 'http request' in err_str: + raise errors.NoSSLError( + 'Client sent plain HTTP request', + ) from error + raise errors.FatalSSLAlert( + f'Fatal SSL error during handshake: {err_str}', + ) from error + + @property + def context(self): + """Get the SSL context.""" + if self._context is None: + # Method Two: auto-create from certificate/private_key + self._context = self.get_context() + return self._context + + @context.setter + def context(self, value): + """Set the SSL context (Method One).""" + self._context = value + + def get_context(self): + """Return an SSL.Context from self attributes. + + Uses TLS_SERVER_METHOD which supports TLS 1.0-1.3, but immediately + disables insecure protocols (SSLv2, SSLv3, TLSv1.0, TLSv1.1) via + set_options(), ensuring only TLS 1.2+ is accepted. + """ + c = SSL.Context(SSL.TLS_SERVER_METHOD) # nosec B502 + + # Disable all insecure protocols (SSLv2, SSLv3, TLSv1.0, TLSv1.1) + c.set_options( + SSL.OP_NO_SSLv2 + | SSL.OP_NO_SSLv3 + | SSL.OP_NO_TLSv1 + | SSL.OP_NO_TLSv1_1, + ) + + c.set_passwd_cb(self._password_callback, self.private_key_password) + c.use_privatekey_file(self.private_key) + if self.certificate_chain: + c.load_verify_locations(self.certificate_chain) + c.use_certificate_file(self.certificate) + return c def _password_callback( self, @@ -343,7 +269,7 @@ def _password_callback( /, ): """Pass a passphrase to password protected private key.""" - b_password = b'' # returning a falsy value communicates an error + b_password = b'' if isinstance(password, str): b_password = password.encode('utf-8') elif isinstance(password, bytes): @@ -357,95 +283,81 @@ def _password_callback( UserWarning, stacklevel=1, ) - return b_password - def get_context(self): - """Return an ``SSL.Context`` from self attributes. - - Ref: :py:class:`SSL.Context ` - """ - # See https://code.activestate.com/recipes/442473/ - c = SSL.Context(SSL.SSLv23_METHOD) - c.set_passwd_cb(self._password_callback, self.private_key_password) - c.use_privatekey_file(self.private_key) - if self.certificate_chain: - c.load_verify_locations(self.certificate_chain) - c.use_certificate_file(self.certificate) - return c + # ======================================================================== + # Adapter-specific environment variable methods + # ======================================================================== - def get_environ(self): - """Return WSGI environ entries to be merged into each request.""" - ssl_environ = { - 'wsgi.url_scheme': 'https', - 'HTTPS': 'on', + def _get_library_version_environ(self): + """Get SSL library version information for pyOpenSSL.""" + return { 'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % ( - cheroot_server.HTTPServer.version, + 'Cheroot', OpenSSL.version.__title__, OpenSSL.version.__version__, sys.version, ), - 'SSL_VERSION_LIBRARY': SSL.SSLeay_version( - SSL.SSLEAY_VERSION, - ).decode(), + 'SSL_VERSION_LIBRARY': SSL.OpenSSL_version( + SSL.OPENSSL_VERSION, + ).decode('ascii'), } - if self.certificate: - # Server certificate attributes + def _get_optional_environ(self, conn): + """Get optional environment variables for pyOpenSSL.""" + # pyOpenSSL doesn't easily expose SNI or compression info + # Could be extended in the future + return {} + + def _get_server_cert_environ(self): + """Get server certificate info using pyOpenSSL certificate parsing.""" + if not self.certificate or crypto is None: + return {} + + try: with open(self.certificate, 'rb') as cert_file: - cert = crypto.load_certificate( - crypto.FILETYPE_PEM, - cert_file.read(), + cert_data = cert_file.read() + + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_data) + return parse_pyopenssl_cert_to_environ(cert, 'SSL_SERVER') + + except Exception: + # If certificate parsing fails, return empty dict + return {} + + def _get_client_cert_environ(self, conn, ssl_environ): + """Add client certificate details using pyOpenSSL.""" + with suppress(Exception): + # Get the peer certificate from the pyOpenSSL connection + # conn is a TLSSocket, so we need to access the + # underlying SSL socket + ssl_socket = conn._ssl_socket + + # Check if peer verification was enabled + if ssl_socket.get_context().get_verify_mode() == SSL.VERIFY_NONE: + return ssl_environ + + client_cert = ssl_socket.get_peer_certificate() + + if client_cert: + ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' + ssl_environ.update( + parse_pyopenssl_cert_to_environ( + client_cert, + 'SSL_CLIENT', + ), ) - ssl_environ.update( - { - 'SSL_SERVER_M_VERSION': cert.get_version(), - 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), - # 'SSL_SERVER_V_START': - # Validity of server's certificate (start time), - # 'SSL_SERVER_V_END': - # Validity of server's certificate (end time), - }, - ) - - for prefix, dn in [ - ('I', cert.get_issuer()), - ('S', cert.get_subject()), - ]: - # X509Name objects don't seem to have a way to get the - # complete DN string. Use str() and slice it instead, - # because str(dn) == "" - dnstr = str(dn)[18:-2] - - wsgikey = 'SSL_SERVER_%s_DN' % prefix - ssl_environ[wsgikey] = dnstr - - # The DN should be of the form: /k1=v1/k2=v2, but we must allow - # for any value to contain slashes itself (in a URL). - while dnstr: - pos = dnstr.rfind('=') - dnstr, value = dnstr[:pos], dnstr[pos + 1 :] - pos = dnstr.rfind('/') - dnstr, key = dnstr[:pos], dnstr[pos + 1 :] - if key and value: - wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) - ssl_environ[wsgikey] = value + # Get PEM representation of certificate + pem_cert = ( + crypto.dump_certificate( + crypto.FILETYPE_PEM, + client_cert, + ) + .decode('ascii') + .strip() + ) + ssl_environ['SSL_CLIENT_CERT'] = pem_cert return ssl_environ - - def makefile(self, sock, mode='r', bufsize=-1): - """Return socket file object.""" - cls = ( - SSLFileobjectStreamReader - if 'r' in mode - else SSLFileobjectStreamWriter - ) - if SSL and isinstance(sock, ssl_conn_type): - wrapped_socket = cls(sock, mode, bufsize) - wrapped_socket.ssl_timeout = sock.gettimeout() - return wrapped_socket - # This is from past: - # TODO: figure out what it's meant for - return cheroot_server.CP_fileobject(sock, mode, bufsize) diff --git a/cheroot/ssl/pyopenssl.pyi b/cheroot/ssl/pyopenssl.pyi index 59dae05ae7..f1472bdc08 100644 --- a/cheroot/ssl/pyopenssl.pyi +++ b/cheroot/ssl/pyopenssl.pyi @@ -7,22 +7,8 @@ from . import Adapter ssl_conn_type: Type[SSL.Connection] -class SSLFileobjectMixin: - ssl_timeout: int - ssl_retry: float - def recv(self, size): ... - def readline(self, size: int = ...): ... - def sendall(self, *args, **kwargs): ... - def send(self, *args, **kwargs): ... - -class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): ... # type:ignore[misc] -class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): ... # type:ignore[misc] - -class SSLConnectionProxyMeta: - def __new__(mcl, name, bases, nmspc): ... - -class SSLConnection: - def __init__(self, *args) -> None: ... +class SSLFileobjectStreamReader(StreamReader): ... # type:ignore[misc] +class SSLFileobjectStreamWriter(StreamWriter): ... # type:ignore[misc] class pyOpenSSLAdapter(Adapter): def __init__( @@ -43,6 +29,5 @@ class pyOpenSSLAdapter(Adapter): password: bytes | str | None, /, ) -> bytes: ... - def get_environ(self): ... - def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... + def get_environ(self, conn) -> dict: ... def get_context(self) -> SSL.Context: ... diff --git a/cheroot/ssl/tls_socket.py b/cheroot/ssl/tls_socket.py new file mode 100644 index 0000000000..c0051ed4aa --- /dev/null +++ b/cheroot/ssl/tls_socket.py @@ -0,0 +1,492 @@ +""" +A unified SSL/TLS socket layer for Cheroot. + +This module provides a TLSSocket class that abstracts over +different SSL/TLS implementations, such as Python's built-in ssl module +and pyOpenSSL. It offers a consistent interface for the rest of +the Cheroot server code. +""" + +import errno +import io +import os +import socket +import ssl +import time +from contextlib import suppress + +from .. import errors + + +try: + from OpenSSL import SSL, crypto +except ImportError: + SSL = None # type: ignore[assignment] + crypto = None # type: ignore[assignment] + ssl_conn_type = None # type: ignore[misc] +else: + # If the import succeeded, proceed with secondary checks + # Use a separate try/except block for the connection type logic + try: # noqa: WPS505 + ssl_conn_type = SSL.Connection # type: ignore[misc] + except AttributeError: + # Fallback to older name if 'Connection' is not found + ssl_conn_type = SSL.ConnectionType # type: ignore[attr-defined] + +_OPENSSL_PROTOCOL_MAP = { + 769: 'TLSv1', + 770: 'TLSv1.1', + 771: 'TLSv1.2', + 772: 'TLSv1.2', + 773: 'TLSv1.3', +} + + +class TLSSocket(io.RawIOBase): # noqa: PLR0904 # pylint: disable=too-many-public-methods + """ + Lightweight wrapper around SSL/TLS sockets. + + Provides a uniform interface over both :class:`ssl.SSLSocket` and + :class:`OpenSSL.SSL.Connection` objects, ensuring consistent I/O + stream handling for Cheroot with proper SSL error handling. + """ + + def __init__(self, ssl_socket, raw_socket, context): + """ + Initialize TLS socket wrapper. + + Args: + ssl_socket: SSL/TLS wrapped socket (SSLSocket or pyOpenSSL Conn) + raw_socket: The underlying raw socket + context: The SSL context (SSLContext or pyOpenSSL Context) + """ + self._ssl_socket = ssl_socket + self._sock = raw_socket + self.context = context + self.ssl_retry = 0.01 # Retry interval for SSL errors + self.ssl_retry_max = 0.1 # Maximum time to retry before timeout + self._is_closed = False + + super().__init__() # Initialize RawIOBase + + # ================================================================ + # Properties - delegate to underlying socket or return stored values + # ================================================================ + + @property + def family(self): + """Get socket family.""" + return getattr(self._sock, 'family', socket.AF_INET) + + @property + def type(self): + """Get socket type.""" + return getattr(self._sock, 'type', socket.SOCK_STREAM) + + @property + def proto(self): + """Get socket protocol.""" + return getattr(self._sock, 'proto', 0) + + @property + def _closed(self): + """Check if the connection is closed.""" + if self._sock is None: + return True + + try: + fd = self._sock.fileno() + os.fstat(fd) + return False + except (OSError, AttributeError) as sockerr: + if isinstance(sockerr, OSError) and sockerr.errno == errno.EBADF: + return True + return True + + @property + def closed(self): + """Public closed property.""" + return self._is_closed + + @property + def scheme(self): + """Signal to Cheroot that this is an HTTPS connection.""" + return 'https' + + # ================================================================ + # SSL Error Handling + # ================================================================ + + def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 + r""" + Wrap the given call with TLS error-trapping. + + This handles transient SSL errors like WantReadError and WantWriteError + by retrying with a small sleep interval. + """ + if not SSL: + # If pyOpenSSL not available, just call directly + return call(*args, **kwargs) + + start = time.time() + + while True: + try: + return call(*args, **kwargs) + + except SSL.WantReadError: + # SSL needs more data to complete operation + time.sleep(self.ssl_retry) + if time.time() - start > self.ssl_retry_max: + raise socket.timeout('SSL WantReadError retry timeout') + + except SSL.WantWriteError: + # SSL needs to write data before continuing + time.sleep(self.ssl_retry) + if time.time() - start > self.ssl_retry_max: + raise socket.timeout('SSL WantWriteError retry timeout') + + except SSL.SysCallError as sys_err: + # System call error - check if it's ignorable + if is_reader and sys_err.args == (-1, 'Unexpected EOF'): + return b'' + + errnum = sys_err.args[0] if sys_err.args else -1 + if is_reader and errnum in errors.socket_errors_to_ignore: + return b'' + + raise socket.error(errnum) + + except SSL.Error as ssl_err: + # General SSL error - check for specific known errors + if not ssl_err.args: + raise + + error_list = ssl_err.args[0] + if not isinstance(error_list, list): + error_list = [error_list] + + for error_tuple in error_list: + if ( + not isinstance(error_tuple, tuple) + or len(error_tuple) < 3 + ): + continue + + error_message = error_tuple[2].lower() + + # HTTP request on HTTPS port + if 'http request' in error_message: + raise errors.NoSSLError + + # Fatal SSL alert + if 'alert' in error_message: + raise errors.FatalSSLAlert(str(ssl_err)) + + # Unknown SSL error + raise + + # ================================================================ + # Socket I/O methods with error handling + # ================================================================ + + def readable(self): + """Return True - this I/O object supports reading.""" + return True + + def writable(self): + """Return True - this I/O object supports writing.""" + return True + + def seekable(self): + """Return False - sockets are not seekable.""" + return False + + def recv(self, size): + """Receive data from the connection with SSL error handling.""" + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + return self._safe_call(True, self._ssl_socket.recv, size) + # For ssl.SSLSocket, just call recv directly + # (it handles its own errors) + return self._ssl_socket.recv(size) + + def send(self, data, flags=0): + """Send data with SSL error handling.""" + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + return self._safe_call(False, self._ssl_socket.send, data, flags) + return self._ssl_socket.send(data, flags) + + def sendall(self, data, flags=0): + """Send all data with SSL error handling.""" + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + return self._safe_call( + False, + self._ssl_socket.sendall, + data, + flags, + ) + return self._ssl_socket.sendall(data, flags) + + def readinto(self, buff): + """ + Read data into a buffer - called by :class:`io.BufferedReader`. + + This is the key method that ``BufferedReader`` calls when reading. + By implementing this with error handling, we ensure SSL errors + are properly handled in the buffered I/O path. + + Args: + buff: Buffer to read data into (bytearray or memoryview) + + Returns: + Number of bytes read, or None for EOF + """ + data = self.recv(len(buff)) + if not data: + return 0 # EOF + num_bytes = len(data) + view = memoryview(buff) + view[:num_bytes] = data + return num_bytes + + def read(self, size): + """Read data from the connection. Used by StreamReader.""" + return self.recv(size) + + def write(self, data): + """Write data to the connection with SSL error handling.""" + return self.send(data) + + # ================================================================ + # Unified SSL/TLS methods - handle both backends + # ================================================================ + + def get_cipher_info(self): + """ + Get the current cipher information in a unified format. + + Returns: + tuple: (cipher_name, protocol_version, secret_bits) or None + """ + ssl_socket = self._ssl_socket + + if isinstance(ssl_socket, ssl.SSLSocket): + # Returns tuple: (cipher_name, protocol_version, secret_bits) + return ssl_socket.cipher() + + if SSL and isinstance(ssl_socket, ssl_conn_type): + try: + protocol_constant = ssl_socket.get_protocol_version() + protocol = _OPENSSL_PROTOCOL_MAP.get( + protocol_constant, + 'UNKNOWN', + ) + cipher_name = ssl_socket.get_cipher_name() + active_bits = ssl_socket.get_cipher_bits() + + return (cipher_name, protocol, active_bits) + + except SSL.Error: + return None + except Exception as err: + raise errors.FatalSSLAlert( + 'Error retrieving cipher info from pyOpenSSL connection: ' + + str(err), + ) from err + + return None + + def getpeercert(self, binary_form=False): + """ + Get the peer's certificate. + + Args: + binary_form: If True, return DER-encoded bytes; + else return dict/object + + Returns: + Certificate in requested format, or None/empty dict if unavailable + """ + if not hasattr(self, '_ssl_socket') or self._ssl_socket is None: + return None if binary_form else {} + + # Handle PyOpenSSL Connection + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + try: + cert = self._ssl_socket.get_peer_certificate() + if cert is None: + return None if binary_form else {} + + if binary_form: + return crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) + return cert + except Exception: + return None if binary_form else {} + + # Handle builtin ssl.SSLSocket + return self._ssl_socket.getpeercert(binary_form) + + def get_verify_mode(self): + """ + Get the certificate verification mode. + + Returns: + int: ssl.CERT_NONE, ssl.CERT_OPTIONAL, or ssl.CERT_REQUIRED + """ + ssl_socket = self._ssl_socket + + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.context.verify_mode + + if SSL and isinstance(ssl_socket, ssl_conn_type): + verify_mode = self.context.get_verify_mode() + + # Map PyOpenSSL constants to ssl module constants + if verify_mode == SSL.VERIFY_NONE: + return ssl.CERT_NONE + if verify_mode & SSL.VERIFY_PEER: + # Check if cert is actually required or just optional + if verify_mode & SSL.VERIFY_FAIL_IF_NO_PEER_CERT: + return ssl.CERT_REQUIRED # Cert must be provided + return ssl.CERT_OPTIONAL # Cert requested but not required + + return ssl.CERT_NONE + + # ================================================================ + # Socket control methods - explicit delegation + # ================================================================ + + def fileno(self): + """Return the file descriptor of the underlying socket.""" + if self._is_closed: + raise ValueError('Socket is closed') + if self._sock is None: + raise ValueError('Socket is not initialized') + + fd = self._sock.fileno() + if fd == -1: + raise OSError('Socket has been closed') + + return fd + + def getpeername(self): + """Return the address of the remote peer.""" + return self._sock.getpeername() + + def getsockname(self): + """Return the address of the local machine.""" + return self._sock.getsockname() + + def gettimeout(self): + """Get the timeout value.""" + return self._sock.gettimeout() + + def settimeout(self, timeout): + """Set timeout on the connection.""" + return self._ssl_socket.settimeout(timeout) + + def setblocking(self, flag): + """Set blocking mode.""" + return self._sock.setblocking(flag) + + def getsockopt(self, level, optname, buflen=None): + """Get socket option.""" + if buflen is not None: + return self._sock.getsockopt(level, optname, buflen) + return self._sock.getsockopt(level, optname) + + def makefile(self, *args, **kwargs): + """Create a file-like object from the connection.""" + return self._ssl_socket.makefile(*args, **kwargs) + + def shutdown(self, how): + """Perform a clean SSL shutdown.""" + if isinstance(self._ssl_socket, ssl.SSLSocket): + with suppress(Exception): + self._ssl_socket.unwrap() + + with suppress(Exception): + return self._sock.shutdown(how) + + def sock_shutdown(self, how): + """Shutdown the raw socket (TCP level), bypassing SSL shutdown.""" + # Windows error code for "not a socket" + WSAENOTSOCK = 10038 + try: + # Attempt to shutdown the underlying kernel socket + return self._sock.shutdown(how) + except OSError as err: + # errno 9 is EBADF (Bad file descriptor) + if err.errno in {errno.EBADF, WSAENOTSOCK}: + # The underlying socket was already closed, + # which is fine during cleanup. + # Silently ignore the error. + return None + # If it's another OSError, re-raise it. + raise + + def close(self): + """Close the connection.""" + if self._is_closed: + return + + self._is_closed = True + with suppress(Exception): + self._ssl_socket.close() + + # ================================================================ + # Additional SSL methods that might be called + # ================================================================ + + def compression(self): + """Get compression method (usually None for modern TLS).""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.compression() + # pyOpenSSL doesn't support this easily, return None + return None + + @property + def sni(self): + """Get SNI hostname if available.""" + ssl_socket = self._ssl_socket + + # SSLSocket doesn't expose SNI directly, return None + if isinstance(ssl_socket, ssl.SSLSocket): + return None + + # pyOpenSSL might have it via get_servername() + if SSL and isinstance(ssl_socket, ssl_conn_type): + with suppress(Exception): + return ssl_socket.get_servername() + return None + + def version(self): + """Get TLS version.""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.version() + if SSL and isinstance(ssl_socket, ssl_conn_type): + # Return string version + protocol_constant = ssl_socket.get_protocol_version() + return _OPENSSL_PROTOCOL_MAP.get(protocol_constant, 'UNKNOWN') + return None + + def get_session(self): + """Get SSL session for reuse (method form).""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.session # Property on SSLSocket + if SSL and isinstance(ssl_socket, ssl_conn_type): + return ssl_socket.get_session() # Method on pyOpenSSL + return None + + @property + def session(self): + """Get SSL session for reuse.""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.session + if SSL and isinstance(ssl_socket, ssl_conn_type): + return ssl_socket.get_session() + return None diff --git a/cheroot/test/conftest.py b/cheroot/test/conftest.py index 657997c3f2..8b534e7806 100644 --- a/cheroot/test/conftest.py +++ b/cheroot/test/conftest.py @@ -9,9 +9,20 @@ import pytest -from .._compat import IS_MACOS, IS_WINDOWS +import trustme +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import ( + BestAvailableEncryption, + Encoding, + PrivateFormat, + load_pem_private_key, +) + +from .._compat import IS_MACOS, IS_WINDOWS, ntou from ..server import Gateway, HTTPServer from ..testing import ( # noqa: F401 # pylint: disable=unused-import + ANY_INTERFACE_IPV4, + _get_conn_data, get_server_client, native_server, thread_and_native_server, @@ -107,3 +118,90 @@ def make_http_server(bind_addr): time.sleep(0.1) return httpserver + + +@pytest.fixture +def ca(): + """Provide a certificate authority via fixture.""" + return trustme.CA() + + +@pytest.fixture +def tls_ca_certificate_pem_path(ca): + """Provide a certificate authority certificate file via fixture.""" + with ca.cert_pem.tempfile() as ca_cert_pem: + yield ca_cert_pem + + +@pytest.fixture +def tls_certificate(ca): + """ + Generate a TLS server certificate for testing. + + Creates a certificate valid for 'test-server.local', 'localhost', + and '127.0.0.1' with ``CN=localhost``. + """ + interface, _host, _port = _get_conn_data(ANY_INTERFACE_IPV4) + identities = [ + 'test-server.local', + 'localhost', # This will be used for CN and SAN + ntou(interface), # This is '127.0.0.1' for SAN + ] + return ca.issue_server_cert(*identities, common_name='localhost') + + +@pytest.fixture +def tls_certificate_pem_path(tls_certificate): + """ + Return path to temp file containing the server certificate in PEM format. + + The file is automatically cleaned up after the test completes. + """ + # The 'cert_pem' property holds the leaf certificate data. + leaf_cert_blob = tls_certificate.cert_chain_pems[0] + + # Write to a file that persists for the test duration + with leaf_cert_blob.tempfile() as cert_pem_path: + yield cert_pem_path + + +@pytest.fixture +def tls_certificate_chain_pem_path(tls_certificate): + """Provide a certificate chain PEM file path via fixture.""" + with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: + yield cert_pem + + +@pytest.fixture +def tls_certificate_private_key_pem_path(tls_certificate): + """Provide a certificate private key PEM file path via fixture.""" + with tls_certificate.private_key_pem.tempfile() as cert_key_pem: + yield cert_key_pem + + +@pytest.fixture +def tls_certificate_passwd_private_key_pem_path( + tls_certificate, + private_key_password, + tmp_path, +): + """Return a certificate private key PEM file path.""" + key_as_bytes = tls_certificate.private_key_pem.bytes() + private_key_object = load_pem_private_key( + key_as_bytes, + password=None, + backend=default_backend(), + ) + + encrypted_key_as_bytes = private_key_object.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.PKCS8, + encryption_algorithm=BestAvailableEncryption( + password=private_key_password.encode('utf-8'), + ), + ) + + key_file = tmp_path / 'encrypted-private-key.pem' + key_file.write_bytes(encrypted_key_as_bytes) + + return key_file diff --git a/cheroot/test/ssl/test_ssl_builtin.py b/cheroot/test/ssl/test_ssl_builtin.py new file mode 100644 index 0000000000..80c22dd66b --- /dev/null +++ b/cheroot/test/ssl/test_ssl_builtin.py @@ -0,0 +1,443 @@ +"""Tests for ``cheroot.ssl.builtin``.""" + +import socket +import ssl +import threading +import time +from contextlib import closing + +import pytest + +from cheroot import errors +from cheroot.makefile import StreamReader, StreamWriter +from cheroot.ssl.builtin import BuiltinSSLAdapter +from cheroot.ssl.tls_socket import TLSSocket + + +_CONNECTION_TIMEOUT_SECONDS = 5.0 +_SOCKET_BUFFER_SIZE = 4096 + + +@pytest.mark.usefixtures('mocker') +def test_full_builtin_environ_population( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that builtin adapter populates all SSL environ variables correctly.""" + captured = {'environ': {}} + + def capture_wsgi_app(environ, start_response): + captured['environ'].update(environ) + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'Hello, secure world!'] + + bind_host = '127.0.0.1' + port = 0 + + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + + from cheroot.wsgi import Server as WSGIServer + + server = WSGIServer( + bind_addr=(bind_host, port), + wsgi_app=capture_wsgi_app, + ) + server.ssl_adapter = adapter + + server.prepare() + actual_port = server.bind_addr[1] + + server_thread = threading.Thread(target=server.serve, daemon=True) + server_thread.start() + + time.sleep(1) + + try: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + sock = socket.create_connection( + (bind_host, actual_port), + timeout=_CONNECTION_TIMEOUT_SECONDS, + ) + client_sock = context.wrap_socket(sock, server_hostname=bind_host) + + request = ( + b'GET /test/path?q=1 HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: close\r\n\r\n' + ) + client_sock.sendall(request) + + client_sock.settimeout(0.5) + response = b'' + try: + while True: + chunk = client_sock.recv(_SOCKET_BUFFER_SIZE) + if not chunk: + break + response += chunk + except socket.timeout: + # Possible timeout when the server closes + # the connection after sending the response. + pass + + client_sock.close() + + finally: + time.sleep(0.5) + server.stop() + server_thread.join(timeout=2) + + captured_environ = captured['environ'] + + # HTTP Request variables + assert captured_environ.get('REQUEST_METHOD') == 'GET' + assert captured_environ.get('PATH_INFO') == '/test/path' + assert captured_environ.get('QUERY_STRING') == 'q=1' + assert captured_environ.get('SERVER_PROTOCOL') == 'HTTP/1.1' + + # SSL variables + assert 'SSL_PROTOCOL' in captured_environ + assert 'TLS' in captured_environ['SSL_PROTOCOL'] + + assert 'SSL_CIPHER' in captured_environ + + assert 'SSL_VERSION_LIBRARY' in captured_environ + assert 'OpenSSL' in captured_environ['SSL_VERSION_LIBRARY'] + + assert 'SSL_VERSION_INTERFACE' in captured_environ + assert 'Python' in captured_environ['SSL_VERSION_INTERFACE'] + + assert 'SSL_SERVER_M_SERIAL' in captured_environ + + assert 'SSL_SERVER_S_DN_CN' in captured_environ + assert captured_environ['SSL_SERVER_S_DN_CN'] == 'localhost' + + +@pytest.mark.usefixtures('mocker') +def test_wrap_with_builtin_ssl_wrap_fails( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that SSL wrap_socket error raises FatalSSLAlert.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + try: + + def failing_wrap_socket(sock, *args, **kwargs): + raise ssl.SSLError('SSL wrap failed') + + mocker.patch.object( + adapter.context, + 'wrap_socket', + side_effect=failing_wrap_socket, + ) + + with pytest.raises( + errors.FatalSSLAlert, + match='Error creating SSL socket', + ): + adapter._wrap_with_builtin(server_sock) + + finally: + server_sock.close() + client_sock.close() + + +@pytest.mark.parametrize( + ('error_class', 'error_msg', 'expected_exception', 'match_text'), + ( + ( + ssl.SSLWantReadError, + 'The operation did not complete', + None, # Should retry, not raise + None, + ), + ( + ssl.SSLWantWriteError, + 'The operation did not complete', + None, # Should retry, not raise + None, + ), + ( + OSError, + 'Connection reset by peer', + errors.FatalSSLAlert, + 'TCP error during handshake', + ), + ( + ssl.SSLError, + 'wrong version number', + errors.NoSSLError, + 'Client sent plain HTTP request', + ), + ( + ssl.SSLError, + 'certificate verify failed', + errors.FatalSSLAlert, + 'Fatal SSL error during handshake', + ), + ( + ssl.SSLEOFError, + 'EOF occurred', + errors.NoSSLError, + 'Peer closed connection during handshake', + ), + ), + ids=[ + 'SSLWantReadError_retry', + 'SSLWantWriteError_retry', + 'OSError_connection_reset', + 'SSLError_wrong_version', + 'SSLError_cert_failed', + 'SSLEOFError', + ], +) +def test_builtin_handshake_error_handling( # pylint: disable=too-many-positional-arguments + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, + error_class, + error_msg, + expected_exception, + match_text, +): + """Test various error conditions during builtin SSL handshake.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + try: + handshake_attempts = {'count': 0} + + def failing_do_handshake(): + handshake_attempts['count'] += 1 + if handshake_attempts['count'] == 1: + raise error_class(error_msg) + # Second attempt succeeds (for retry cases) + + # Mock wrap_socket to return a mock SSL socket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.do_handshake = failing_do_handshake + mock_ssl_socket.context = adapter.context + + mocker.patch.object( + adapter.context, + 'wrap_socket', + return_value=mock_ssl_socket, + ) + + # Mock select for WantRead/WantWrite cases + if error_class in {ssl.SSLWantReadError, ssl.SSLWantWriteError}: + mocker.patch( + 'select.select', + return_value=( + [server_sock.fileno()], + [server_sock.fileno()], + [], + ), + ) + + if expected_exception: + with pytest.raises(expected_exception, match=match_text): + adapter._wrap_with_builtin(server_sock) + else: + # Should retry and succeed + tls_sock = adapter._wrap_with_builtin(server_sock) + assert isinstance(tls_sock, TLSSocket) + assert handshake_attempts['count'] >= 2, ( + f'Expected retry but only {handshake_attempts["count"]} attempts' + ) + + finally: + server_sock.close() + client_sock.close() + + +def test_builtin_handshake_timeout( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that handshake times out on repeated WantRead errors.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + with closing(server_sock), closing(client_sock): + # Mock wrap_socket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.do_handshake.side_effect = ssl.SSLWantReadError() + + mocker.patch.object( + adapter.context, + 'wrap_socket', + return_value=mock_ssl_socket, + ) + + # Mock select to return empty (timeout) + mocker.patch('select.select', return_value=([], [], [])) + + with pytest.raises(TimeoutError, match='Handshake failed'): + adapter._wrap_with_builtin(server_sock) + + +def test_builtin_create_ssl_socket_error( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that wrap_socket errors are caught and converted to FatalSSLAlert.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + # Create the context first so it has valid certs + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + with closing(server_sock), closing(client_sock): + # Now patch wrap_socket AFTER context is created + mocker.patch.object( + adapter.context, + 'wrap_socket', + side_effect=ssl.SSLError('Failed to create SSL socket'), + ) + + with pytest.raises( + errors.FatalSSLAlert, + match='Error creating SSL socket', + ): + adapter._wrap_with_builtin(server_sock) + + +def test_builtin_adapter_get_server_cert_environ_no_cert( + tls_certificate_private_key_pem_path, +): + """Test that _get_server_cert_environ returns empty dict when no certificate.""" + adapter = BuiltinSSLAdapter( + None, # No certificate + tls_certificate_private_key_pem_path, + ) + + environ = adapter._get_server_cert_environ() + assert len(environ) == 0 + + +def test_builtin_adapter_get_server_cert_environ_invalid_file( + tls_certificate_private_key_pem_path, +): + """Test that _get_server_cert_environ handles invalid cert file gracefully.""" + with pytest.raises( + FileNotFoundError, + match='SSL certificate file not found', + ): + BuiltinSSLAdapter( + '/nonexistent/cert.pem', + tls_certificate_private_key_pem_path, + ) + + +def test_builtin_adapter_client_cert_no_verification( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that client cert environ is not populated when verification is disabled.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + + # Mock TLSSocket with no client cert verification + mock_conn = mocker.MagicMock() + + # IMPORTANT: Mock _sock, not _ssl_socket! + mock_conn._sock.getpeercert = mocker.MagicMock( + side_effect=lambda binary_form=False: None, + ) + mock_conn.context.verify_mode = ssl.CERT_NONE + + ssl_environ = {'HTTPS': 'on'} + environ = adapter._get_client_cert_environ(mock_conn, ssl_environ) + + assert environ['SSL_CLIENT_VERIFY'] == 'NONE' + assert 'SSL_CLIENT_CERT' not in environ + + +def test_streamreader_with_tls_and_regular_sockets(mocker): + """Test StreamReader works with both TLSSocket and regular sockets.""" + # Test with TLSSocket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_tls_socket = TLSSocket( + ssl_socket=mock_ssl_socket, + raw_socket=mocker.MagicMock(), + context=mocker.MagicMock(), + ) + + buffered_reader = StreamReader(mock_tls_socket) + assert isinstance(buffered_reader, StreamReader) + assert buffered_reader.bytes_read == 0 + assert hasattr(buffered_reader, 'read') + + # Test with regular socket + regular_sock = mocker.MagicMock() + mock_socket_io = mocker.patch('socket.SocketIO') + mock_socket_io.return_value = mocker.MagicMock() + + buffered_reader = StreamReader(regular_sock) + + # Verify SocketIO was called to wrap regular socket + mock_socket_io.assert_called_once_with(regular_sock, 'rb') + assert isinstance(buffered_reader, StreamReader) + assert buffered_reader.bytes_read == 0 + + +def test_streamwriter_with_tls_and_regular_sockets(mocker): + """Test StreamWriter works with both TLSSocket and regular sockets.""" + # Test with TLSSocket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_tls_socket = TLSSocket( + ssl_socket=mock_ssl_socket, + raw_socket=mocker.MagicMock(), + context=mocker.MagicMock(), + ) + + buffered_writer = StreamWriter(mock_tls_socket) + assert isinstance(buffered_writer, StreamWriter) + assert buffered_writer.bytes_written == 0 + assert hasattr(buffered_writer, 'write') + + # Test with regular socket + regular_sock = mocker.MagicMock() + mock_socket_io = mocker.patch('socket.SocketIO') + mock_socket_io.return_value = mocker.MagicMock() + + buffered_writer = StreamWriter(regular_sock) + + # Verify SocketIO was called to wrap regular socket + mock_socket_io.assert_called_once_with(regular_sock, 'wb') + assert isinstance(buffered_writer, StreamWriter) + assert buffered_writer.bytes_written == 0 diff --git a/cheroot/test/ssl/test_ssl_pyopenssl.py b/cheroot/test/ssl/test_ssl_pyopenssl.py new file mode 100644 index 0000000000..f7fbfb0d03 --- /dev/null +++ b/cheroot/test/ssl/test_ssl_pyopenssl.py @@ -0,0 +1,703 @@ +"""Tests for ``cheroot.ssl.pyopenssl``.""" +# Assuming OpenSSL is imported as 'SSL' in your test module + +import errno +import io +import socket +import ssl +import threading +import time +from contextlib import suppress + +import pytest + +from OpenSSL import SSL + +from cheroot import errors +from cheroot.makefile import StreamReader, StreamWriter +from cheroot.ssl.pyopenssl import pyOpenSSLAdapter +from cheroot.ssl.tls_socket import TLSSocket, ssl_conn_type + +# --- The Main Integration Test --- +from cheroot.wsgi import Server as WSGIServer + + +_CONNECTION_TIMEOUT_SECONDS = 5.0 +_SOCKET_BUFFER_SIZE = 4096 + + +@pytest.mark.usefixtures('mocker') +def test_full_pyopenssl_environ_population( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """ + Test pyOpenSSL adapter populates WSGI environ with HTTP and SSL variables. + + Performs an end-to-end test by: + - Starting a WSGI server with pyOpenSSL TLS adapter + - Making an SSL client connection + - Sending an HTTP request over TLS + - Verifying environ contains correct HTTP vars (METHOD, PATH, QUERY, etc.) + - Verifying environ contains SSL vars (PROTOCOL, CIPHER, VERSION, DN, etc.) + + This ensures the pyOpenSSL integration correctly exposes SSL connection + details to WSGI applications through the environ dictionary. + """ + captured = {'environ': {}} + + def capture_wsgi_app(environ, start_response): + captured['environ'].update(environ) + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'Hello, secure world!'] + + bind_host = '127.0.0.1' + port = 0 + + adapter = pyOpenSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ciphers='ALL', + ) + + # Use WSGIServer instead of HTTPServer + server = WSGIServer( + bind_addr=(bind_host, port), + wsgi_app=capture_wsgi_app, + ) + server.ssl_adapter = adapter + + # Prepare the server (binds the socket) + server.prepare() + actual_port = server.bind_addr[1] + + # Start the server in a thread + server_thread = threading.Thread(target=server.serve, daemon=True) + server_thread.start() + + time.sleep(1) # Give it time to start + + try: + # Connect using SSL + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + # === CLIENT SIDE === + # 1. Creating socket connection + sock = socket.create_connection( + (bind_host, actual_port), + timeout=_CONNECTION_TIMEOUT_SECONDS, + ) + + # 2. Wrapping with SSL and completing handshake + client_sock = context.wrap_socket(sock, server_hostname=bind_host) + + # 3. Create request + request = ( + b'GET /test/path?q=1 HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: close\r\n\r\n' + ) + + # 4. Send request + client_sock.sendall(request) + + # 5. Read response + client_sock.settimeout(0.5) + response = b'' + try: + while True: + chunk = client_sock.recv(_SOCKET_BUFFER_SIZE) + if not chunk: + break + response += chunk + except socket.timeout: + # Possible timeout when the server closes + # the connection after sending the response. + pass + + client_sock.close() + + finally: + # Give server time to process + time.sleep(0.5) + + server.stop() + server_thread.join(timeout=2) + + captured_environ = captured['environ'] + + # Assertions - HTTP Request variables should be populated + assert captured_environ.get('REQUEST_METHOD') == 'GET', ( + f"Expected REQUEST_METHOD='GET', got '{captured_environ.get('REQUEST_METHOD')}'" + ) + + assert captured_environ.get('PATH_INFO') == '/test/path', ( + f"Expected PATH_INFO='/test/path', got '{captured_environ.get('PATH_INFO')}'" + ) + + assert captured_environ.get('QUERY_STRING') == 'q=1', ( + f"Expected empty QUERY_STRING, got '{captured_environ.get('QUERY_STRING')}'" + ) + + assert captured_environ.get('SERVER_PROTOCOL') == 'HTTP/1.1', ( + f"Expected SERVER_PROTOCOL='HTTP/1.1', got '{captured_environ.get('SERVER_PROTOCOL')}'" + ) + + # SSL variables should be populated + assert 'SSL_PROTOCOL' in captured_environ, 'SSL_PROTOCOL not in environ' + assert 'TLSv1' in captured_environ['SSL_PROTOCOL'] + + assert 'SSL_CIPHER' in captured_environ, 'SSL_CIPHER not in environ' + + assert 'SSL_VERSION_LIBRARY' in captured_environ + assert 'OpenSSL' in captured_environ['SSL_VERSION_LIBRARY'] + + assert 'SSL_VERSION_INTERFACE' in captured_environ + assert 'pyOpenSSL' in captured_environ['SSL_VERSION_INTERFACE'] + + assert 'SSL_SERVER_M_SERIAL' in captured_environ + + assert 'SSL_SERVER_S_DN_CN' in captured_environ, ( + 'SSL_SERVER_S_DN_CN not in environ' + ) + assert captured_environ['SSL_SERVER_S_DN_CN'] == 'localhost', ( + f"Expected CN='localhost', got '{captured_environ['SSL_SERVER_S_DN_CN']}'" + ) + + +# --- Parameterized Test Cases --- +test_cases = [ + ( + 'Success after WantReadError', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.WantReadError, b'OK'], + None, + b'OK', + 2, + ), + ( + 'Success after WantWriteError', + 'send', # Changed from 'safe_send' + 'send', + [SSL.WantWriteError, b'OK'], + None, + b'OK', + 2, + ), + ( + 'Timeout', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.WantReadError] * 4, + socket.timeout, + None, + 2, + ), + ( + 'SysCallError: Unexpected EOF', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.SysCallError(-1, 'Unexpected EOF')], + None, + b'', + 1, + ), + ( + 'SysCallError: Ignorable Socket Error (e.g., Broken Pipe)', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.SysCallError(errno.EPIPE, 'Broken pipe')], + None, + b'', + 1, + ), + ( + 'SysCallError: Non-Ignorable Error (Connection Reset)', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.SysCallError(errno.ENOTCONN, 'Socket is not connected')], + socket.error, + None, + 1, + ), + ( + 'SysCallError: Non-Ignorable Error (Writer)', + 'send', # Changed from 'safe_send' + 'send', + [SSL.SysCallError(999, 'Fatal system error')], + socket.error, + None, + 1, + ), + ( + 'SSL.Error: NoSSLError (HTTP Request)', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.Error([(-1, 'SSL routines', 'http request')])], + errors.NoSSLError, + None, + 1, + ), + ( + 'SSL.Error: Fatal SSL Alert', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.Error([(-1, 'SSL routines', 'generic alert')])], + errors.FatalSSLAlert, + None, + 1, + ), +] + + +@pytest.mark.parametrize( + ( + 'test_case_name', + 'call_method', + 'mock_target', + 'side_effects', + 'expected_exception', + 'expected_result', + 'expected_call_count', + ), + test_cases, + ids=[case[0] for case in test_cases], # Extract names at module level +) +def test_safe_call_coverage( # pylint: disable=too-many-positional-arguments + test_case_name, + call_method, + mock_target, + side_effects, + expected_exception, + expected_result, + expected_call_count, + mocker, +): + """Test all critical success, retry, and error-mapping paths in _safe_call.""" + # Create mocks + mock_ssl_conn = mocker.MagicMock() + mock_ssl_conn.__class__ = ssl_conn_type + getattr(mock_ssl_conn, mock_target).side_effect = side_effects + + mock_raw_socket = mocker.MagicMock() + mock_raw_socket.gettimeout.return_value = 1.0 + mock_context = mocker.MagicMock() + + # Configure time mocks for the timeout case + mock_sleep = mocker.patch('time.sleep') + + if expected_exception is socket.timeout: + start_time = 1000.0 + mocker.patch('time.time').side_effect = [ + start_time, # Start time + start_time + 0.05, # First check (0.05 < 0.1) + start_time + 0.15, # Second check (0.15 > 0.1) -> Timeout + ] + else: + mocker.patch('time.time').return_value = 1000.0 + + # Create TLSSocket + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Call the method + call_func = getattr(tls_socket, call_method) + if expected_exception: + # Test Case expects an exception + with pytest.raises(expected_exception): + # Pass dummy args for recv/send + call_func(1024) if call_method == 'recv' else call_func( + b'test data', + ) + + else: + # Test Case expects a successful return + actual_result = ( + call_func(1024) + if call_method == 'recv' + else call_func(b'test data') + ) + assert actual_result == expected_result + + # Final check on call count + assert ( + getattr(mock_ssl_conn, mock_target).call_count == expected_call_count + ) + + # Ensure sleep was called if errors occurred that require retries + if expected_exception is not None and expected_exception not in { + socket.error, + errors.NoSSLError, + errors.FatalSSLAlert, + }: + assert mock_sleep.called + + +def test_tlssocket_is_readable(mocker): + """Test that TLSSocket properly declares itself as readable.""" + mock_ssl_conn = mocker.MagicMock(spec=SSL.Connection) + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + assert isinstance(tls_socket, io.RawIOBase) + + # The real test - if these are False, our methods aren't being called + if tls_socket.readable() is False: + pytest.fail( + 'TLSSocket.readable() is not working - check if changes were applied to the actual file', + ) + + assert tls_socket.readable() is True + assert hasattr(tls_socket, 'readinto') + + +@pytest.mark.parametrize( + ( + 'io_method', + 'error_class', + 'call_target', + 'test_input', + 'expected_output', + ), + ( + ('readinto', SSL.WantReadError, 'recv', 100, b'Hello World'), + ('write', SSL.WantWriteError, 'send', b'Test data', 9), + ), + ids=['readinto_WantReadError', 'write_WantWriteError'], +) +def test_tlssocket_io_handles_want_errors( # pylint: disable=too-many-positional-arguments + mocker, + io_method, + error_class, + call_target, + test_input, + expected_output, +): + """Test that TLSSocket I/O methods handle WantRead/WantWrite errors with retry.""" + # Setup mocks + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + mock_raw_socket.gettimeout.return_value = 1.0 + + # Track call count + call_count = {'count': 0} + + # Configure mock to raise error first, then succeed + def io_operation(*args, **kwargs): + call_count['count'] += 1 + + if call_count['count'] == 1: + raise error_class() + if call_target == 'recv': + return expected_output + # send + return len(args[0]) if args else expected_output + + # Attach mock to the appropriate method + setattr(mock_ssl_conn, call_target, io_operation) + + # Create TLSSocket + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Execute the I/O operation + if io_method == 'readinto': + buffer = bytearray(test_input) + bytes_read = tls_socket.readinto(buffer) + assert bytes(buffer[:bytes_read]) == expected_output + assert bytes_read == len(expected_output) + else: # write + bytes_written = tls_socket.write(test_input) + assert bytes_written == expected_output + + # Verify retry happened + assert call_count['count'] == 2, ( + f'Should have retried after {error_class.__name__}' + ) + + +def test_tlssocket_readinto_handles_syscallerror_eof(mocker): + """Test that TLSSocket.readinto() handles SysCallError with Unexpected EOF.""" + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + # SysCallError with Unexpected EOF should return empty + mock_ssl_conn.recv = mocker.MagicMock( + side_effect=SSL.SysCallError(-1, 'Unexpected EOF'), + ) + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + buffer = bytearray(100) + bytes_read = tls_socket.readinto(buffer) + + assert bytes_read == 0, 'Unexpected EOF should return 0 bytes (EOF)' + + +def test_tlssocket_with_buffered_reader(mocker): + """Test that TLSSocket works correctly with :class:`io.BufferedReader`.""" + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + call_count = {'count': 0} + + def recv_with_retry(size): + call_count['count'] += 1 + if call_count['count'] == 1: + raise SSL.WantReadError + if call_count['count'] == 2: + return b'Data from BufferedReader' + # Return empty to signal EOF for subsequent calls + return b'' + + mock_ssl_conn.recv = recv_with_retry + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Create BufferedReader with TLSSocket as the raw I/O + reader = io.BufferedReader(tls_socket, buffer_size=8192) + + # Read through BufferedReader + read_data = reader.read(100) + + assert read_data == b'Data from BufferedReader' + assert call_count['count'] >= 2, 'Should have retried after WantReadError' + + +def test_tlssocket_timeout_on_repeated_errors(mocker): + """Test that repeated SSL errors eventually timeout.""" + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + # Always raise WantReadError + mock_ssl_conn.recv = mocker.MagicMock(side_effect=SSL.WantReadError()) + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + tls_socket.ssl_retry_max = 0.05 # Short timeout for testing + + buffer = bytearray(100) + + with pytest.raises(socket.timeout): + tls_socket.readinto(buffer) + + +def test_tlssocket_sock_shutdown(mocker): + """Test that sock_shutdown calls the raw socket's shutdown method.""" + mock_ssl_conn = mocker.MagicMock() + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Call sock_shutdown + tls_socket.sock_shutdown(socket.SHUT_RDWR) + + # Verify it called the raw socket's shutdown, not the SSL connection's + mock_raw_socket.shutdown.assert_called_once_with(socket.SHUT_RDWR) + mock_ssl_conn.shutdown.assert_not_called() + + +@pytest.mark.usefixtures('mocker') +def test_wrap_with_pyopenssl_ssl_connection_creation_fails( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that SSL.Connection creation error raises FatalSSLAlert.""" + adapter = pyOpenSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter.get_context() + + # Use a real socket + server_sock, client_sock = socket.socketpair() + + try: + # Patch where it's imported IN THE MODULE + call_count = {'count': 0} + + def failing_connection(context, sock): + call_count['count'] += 1 + raise SSL.Error([(-1, 'SSL routines', 'initialization failed')]) + + # Patch ssl_conn_type, not SSL.Connection + mocker.patch( + 'cheroot.ssl.pyopenssl.ssl_conn_type', + side_effect=failing_connection, + ) + + with pytest.raises( + errors.FatalSSLAlert, + match='Error creating pyOpenSSL connection', + ): + adapter._wrap_with_pyopenssl(server_sock) + + finally: + server_sock.close() + client_sock.close() + + +@pytest.mark.parametrize( + ('error_class', 'error_args', 'expected_exception', 'match_text'), + ( + ( + SSL.WantReadError, + None, + None, # Should retry, not raise + None, + ), + ( + OSError, + (errno.ECONNRESET, 'Connection reset by peer'), + errors.FatalSSLAlert, + 'TCP error during handshake', + ), + ( + SSL.Error, + [(-1, 'SSL routines', 'http request')], + errors.NoSSLError, + 'Client sent plain HTTP request', + ), + ( + SSL.Error, + [(-1, 'SSL routines', 'certificate verify failed')], + errors.FatalSSLAlert, + 'Fatal SSL error during handshake', + ), + ( + SSL.ZeroReturnError, # NEW + None, + errors.NoSSLError, + 'Peer closed connection during handshake', + ), + ), + ids=[ + 'WantReadError_retry', + 'OSError_ECONNRESET', + 'SSL_http_request', + 'SSL_cert_failed', + 'ZeroReturnError', + ], +) +def test_handshake_error_handling( # pylint: disable=too-many-positional-arguments + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, + error_class, + error_args, + expected_exception, + match_text, +): + """Test various error conditions during SSL handshake.""" + adapter = pyOpenSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter.get_context() + + server_sock, client_sock = socket.socketpair() + + try: + original_connection = SSL.Connection + handshake_attempts = {'count': 0} + + def patched_connection(context, sock): + conn = original_connection(context, sock) + + def failing_handshake(): + handshake_attempts['count'] += 1 + if handshake_attempts['count'] == 1: + if error_args: + raise ( + error_class(*error_args) + if isinstance(error_args, tuple) + else error_class(error_args) + ) + raise error_class() + # Second attempt for retry cases + raise SSL.Error([(-1, 'SSL routines', 'test completed')]) + + conn.do_handshake = failing_handshake + return conn + + mocker.patch( + 'cheroot.ssl.pyopenssl.ssl_conn_type', + side_effect=patched_connection, + ) + mocker.patch( + 'select.select', + return_value=([server_sock.fileno()], [], []), + ) + + if expected_exception: + with pytest.raises(expected_exception, match=match_text): + adapter._wrap_with_pyopenssl(server_sock) + else: + with suppress(Exception): + adapter._wrap_with_pyopenssl(server_sock) + assert handshake_attempts['count'] >= 2 + + finally: + server_sock.close() + client_sock.close() + + +def test_streamreader_with_tls_socket(mocker): + """Test StreamReader works correctly with TLSSocket.""" + # Setup a TLSSocket instance + mock_ssl_socket = mocker.MagicMock() + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + mock_tls_socket = TLSSocket( + ssl_socket=mock_ssl_socket, + raw_socket=mock_raw_socket, + context=mock_context, + ) + + # Create StreamReader with TLSSocket + buffered_reader = StreamReader(mock_tls_socket, bufsize=4096) + + # Assert it's a StreamReader instance + assert isinstance(buffered_reader, StreamReader) + assert buffered_reader.bytes_read == 0 + + # Verify TLSSocket was used directly (not wrapped with SocketIO) + # The _wrapped attribute would be the TLSSocket itself + assert buffered_reader.raw is mock_tls_socket + + +def test_streamwriter_with_regular_socket(mocker): + """Test StreamWriter works correctly with regular socket.""" + regular_sock = mocker.MagicMock() + + # Mock SocketIO to verify it's called for regular sockets + mock_socket_io = mocker.patch('socket.SocketIO') + mock_socket_io.return_value = mocker.MagicMock() + + writer = StreamWriter(regular_sock, bufsize=1024) + + # Verify SocketIO was called to wrap the regular socket + mock_socket_io.assert_called_once_with(regular_sock, 'wb') + + # Verify it's a StreamWriter instance + assert isinstance(writer, StreamWriter) + assert writer.bytes_written == 0 diff --git a/cheroot/test/test_makefile.py b/cheroot/test/test_makefile.py index d65d4ea268..c6fba66405 100644 --- a/cheroot/test/test_makefile.py +++ b/cheroot/test/test_makefile.py @@ -1,45 +1,46 @@ """Tests for :py:mod:`cheroot.makefile`.""" -from cheroot import makefile +import io +from cheroot.makefile import StreamReader, StreamWriter -class MockSocket: - """A mock socket.""" + +class MockSocket(io.RawIOBase): + """A mock socket for testing stream I/O.""" def __init__(self): - """Initialize :py:class:`MockSocket`.""" + """Initialize MockSocket.""" + super().__init__() self.messages = [] - def recv_into(self, buf): - """Simulate ``recv_into`` for Python 3.""" - if not self.messages: - return 0 - msg = self.messages.pop(0) - for index, byte in enumerate(msg): - buf[index] = byte - return len(msg) + def readable(self): + """Return True - supports reading.""" + return True - def recv(self, size): - """Simulate ``recv`` for Python 2.""" - try: - return self.messages.pop(0) - except IndexError: - return '' + def writable(self): + """Return True - supports writing.""" + return True - def send(self, val): - """Simulate a send.""" - return len(val) + def readinto(self, buf): + """Read data into buffer.""" + if not self.messages: + return 0 # EOF - def _decref_socketios(self): - """Emulate socket I/O reference decrement.""" - # Ref: https://github.com/cherrypy/cheroot/issues/734 + msg = self.messages.pop(0) + num_bytes = min(len(msg), len(buf)) + buf[:num_bytes] = msg[:num_bytes] # noqa: WPS362 + return num_bytes + + def write(self, data): + """Write data (returns length written).""" + return len(data) def test_bytes_read(): """Reader should capture bytes read.""" sock = MockSocket() sock.messages.append(b'foo') - rfile = makefile.MakeFile(sock, 'r') + rfile = StreamReader(sock) rfile.read() assert rfile.bytes_read == 3 @@ -47,7 +48,6 @@ def test_bytes_read(): def test_bytes_written(): """Writer should capture bytes written.""" sock = MockSocket() - sock.messages.append(b'foo') - wfile = makefile.MakeFile(sock, 'w') + wfile = StreamWriter(sock) wfile.write(b'bar') assert wfile.bytes_written == 3 diff --git a/cheroot/test/test_ssl.py b/cheroot/test/test_ssl.py index 115237f4f7..152267a5f9 100644 --- a/cheroot/test/test_ssl.py +++ b/cheroot/test/test_ssl.py @@ -32,9 +32,7 @@ IS_LINUX, IS_MACOS, IS_PYPY, - IS_SOLARIS, IS_WINDOWS, - SYS_PLATFORM, bton, ntob, ntou, @@ -723,36 +721,6 @@ def test_http_over_https_error( tls_certificate_private_key_pem_path, ): """Ensure that connecting over HTTP to HTTPS port is handled.""" - # disable some flaky tests - # https://github.com/cherrypy/cheroot/issues/225 - issue_225 = IS_MACOS and adapter_type == 'builtin' - if issue_225: - pytest.xfail('Test fails in Travis-CI') - - if IS_LINUX: - expected_error_code, expected_error_text = ( - 104, - 'Connection reset by peer', - ) - elif IS_MACOS: - expected_error_code, expected_error_text = ( - 54, - 'Connection reset by peer', - ) - elif IS_SOLARIS: - expected_error_code, expected_error_text = ( - None, - 'Remote end closed connection without response', - ) - elif IS_WINDOWS: - expected_error_code, expected_error_text = ( - 10054, - 'An existing connection was forcibly closed by the remote host', - ) - else: - expected_error_code, expected_error_text = None, None - pytest.skip(f'{SYS_PLATFORM} is unsupported') # pragma: no cover - tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) tls_adapter = tls_adapter_cls( tls_certificate_chain_pem_path, @@ -774,31 +742,15 @@ def test_http_over_https_error( if ip_addr is ANY_INTERFACE_IPV6: fqdn = '[{fqdn}]'.format(**locals()) - expect_fallback_response_over_plain_http = adapter_type == 'pyopenssl' - if expect_fallback_response_over_plain_http: - resp = requests.get( - f'http://{fqdn!s}:{port!s}/', - timeout=http_request_timeout, - ) - assert resp.status_code == 400 - assert resp.text == ( - 'The client sent a plain HTTP request, ' - 'but this server only speaks HTTPS on this port.' - ) - return - - with pytest.raises(requests.exceptions.ConnectionError) as ssl_err: - requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL - f'http://{fqdn!s}:{port!s}/', - timeout=http_request_timeout, - ) - - underlying_error = ssl_err.value.args[0].args[-1] - err_text = str(underlying_error) - assert underlying_error.errno == expected_error_code, ( - 'The underlying error is {underlying_error!r}'.format(**locals()) + resp = requests.get( + f'http://{fqdn!s}:{port!s}/', + timeout=http_request_timeout, + ) + assert resp.status_code == 400 + assert resp.text == ( + 'The client sent a plain HTTP request, ' + 'but this server only speaks HTTPS on this port.' ) - assert expected_error_text in err_text @pytest.mark.parametrize('adapter_type', ('builtin', 'pyopenssl')) @@ -851,6 +803,8 @@ def test_ssl_adapters_with_private_key_password( ).bind_addr, ) + time.sleep(2) # allow time for the server to start + resp = requests.get( f'https://{interface!s}:{port!s}/', timeout=http_request_timeout, diff --git a/cheroot/wsgi.py b/cheroot/wsgi.py index 4ac4b16764..154818f484 100644 --- a/cheroot/wsgi.py +++ b/cheroot/wsgi.py @@ -228,6 +228,8 @@ def write(self, chunk): # Dang. We have probably already sent data. Truncate the chunk # to fit (so the client doesn't hang) and raise an error later. chunk = chunk[:rbo] + # Update chunklen to match the truncated size + chunklen = len(chunk) self.req.ensure_headers_sent() diff --git a/docs/changelog-fragments.d/799.contrib.rst b/docs/changelog-fragments.d/799.contrib.rst new file mode 100644 index 0000000000..f2068194fb --- /dev/null +++ b/docs/changelog-fragments.d/799.contrib.rst @@ -0,0 +1,3 @@ +Added :py:class:`cheroot.ssl.tls_socket.TLSSocket` to provide uniform SSL/TLS handling + +-- by :user:`julianz-`. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1e1022c82f..7d23c08f6e 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -2,6 +2,7 @@ AppVeyor Args backend backports +bufsize bugfixes builtin b'xb @@ -46,12 +47,14 @@ preconfigure py pytest pythonic +pyOpenSSL readonly rebase Refactor Refactored refactoring refactorings +seekable Sep sep signalling @@ -75,6 +78,7 @@ tuples unbuffered unclosed unfilterable +unhandled unregister unregisters uptime diff --git a/stubtest_allowlist.txt b/stubtest_allowlist.txt index cc6e94a5a5..b2ab351a6a 100644 --- a/stubtest_allowlist.txt +++ b/stubtest_allowlist.txt @@ -1,40 +1,3 @@ -# generated members by metaclass -cheroot.ssl.pyopenssl.SSLConnection.accept -cheroot.ssl.pyopenssl.SSLConnection.bind -cheroot.ssl.pyopenssl.SSLConnection.close -cheroot.ssl.pyopenssl.SSLConnection.connect -cheroot.ssl.pyopenssl.SSLConnection.connect_ex -cheroot.ssl.pyopenssl.SSLConnection.family -cheroot.ssl.pyopenssl.SSLConnection.fileno -cheroot.ssl.pyopenssl.SSLConnection.get_app_data -cheroot.ssl.pyopenssl.SSLConnection.get_cipher_list -cheroot.ssl.pyopenssl.SSLConnection.get_context -cheroot.ssl.pyopenssl.SSLConnection.get_peer_certificate -cheroot.ssl.pyopenssl.SSLConnection.getpeername -cheroot.ssl.pyopenssl.SSLConnection.getsockname -cheroot.ssl.pyopenssl.SSLConnection.getsockopt -cheroot.ssl.pyopenssl.SSLConnection.gettimeout -cheroot.ssl.pyopenssl.SSLConnection.listen -cheroot.ssl.pyopenssl.SSLConnection.makefile -cheroot.ssl.pyopenssl.SSLConnection.pending -cheroot.ssl.pyopenssl.SSLConnection.read -cheroot.ssl.pyopenssl.SSLConnection.recv -cheroot.ssl.pyopenssl.SSLConnection.renegotiate -cheroot.ssl.pyopenssl.SSLConnection.send -cheroot.ssl.pyopenssl.SSLConnection.sendall -cheroot.ssl.pyopenssl.SSLConnection.set_accept_state -cheroot.ssl.pyopenssl.SSLConnection.set_app_data -cheroot.ssl.pyopenssl.SSLConnection.set_connect_state -cheroot.ssl.pyopenssl.SSLConnection.setblocking -cheroot.ssl.pyopenssl.SSLConnection.setsockopt -cheroot.ssl.pyopenssl.SSLConnection.settimeout -cheroot.ssl.pyopenssl.SSLConnection.shutdown -cheroot.ssl.pyopenssl.SSLConnection.sock_shutdown -cheroot.ssl.pyopenssl.SSLConnection.state_string -cheroot.ssl.pyopenssl.SSLConnection.want_read -cheroot.ssl.pyopenssl.SSLConnection.want_write -cheroot.ssl.pyopenssl.SSLConnection.write - # suppress is both a function and class cheroot._compat.suppress