diff --git a/.flake8 b/.flake8 index aca9b2fffe..aa2bf47fe3 100644 --- a/.flake8 +++ b/.flake8 @@ -124,16 +124,17 @@ per-file-ignores = cheroot/__main__.py: WPS130 cheroot/_compat.py: DAR101, DAR201, DAR301, DAR401, I003, RST304, WPS100, WPS111, WPS123, WPS226, WPS229, WPS420, WPS422, WPS432, WPS504, WPS505 cheroot/cli.py: DAR101, DAR201, DAR401, I001, I004, I005, WPS100, WPS110, WPS120, WPS130, WPS202, WPS226, WPS229, WPS338, WPS420, WPS421 - cheroot/connections.py: DAR101, DAR201, DAR301, DAR401, I001, I003, I004, I005, RST304, S104, WPS100, WPS110, WPS111, WPS121, WPS122, WPS130, WPS201, WPS204, WPS210, WPS212, WPS214, WPS220, WPS229, WPS231, WPS301, WPS324, WPS338, WPS420, WPS421, WPS422, WPS432, WPS501, WPS504, WPS505 + cheroot/connections.py: DAR101, DAR201, DAR301, DAR401, I001, I003, I004, I005, RST304, S104, WPS100, WPS110, WPS111, WPS121, WPS122, WPS130, WPS201, WPS204, WPS210, WPS212, WPS214, WPS220, WPS229, WPS231, WPS237, WPS301, WPS324, WPS338, WPS420, WPS421, WPS422, WPS432, WPS501, WPS504, WPS505 cheroot/errors.py: DAR101, DAR201, I003, RST304, WPS111, WPS121, WPS422 - cheroot/makefile.py: DAR101, DAR201, DAR401, E800, I003, I004, N801, N802, S101, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS122, WPS123, WPS130, WPS204, WPS210, WPS212, WPS213, WPS220, WPS229, WPS231, WPS232, WPS338, WPS420, WPS422, WPS429, WPS431, WPS504, WPS604, WPS606 + cheroot/makefile.py: DAR101, DAR201, DAR401, E800, I003, I004, N801, N802, S101, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS122, WPS123, WPS130, WPS204, WPS210, WPS212, WPS213, WPS226, WPS220, WPS229, WPS231, WPS232, WPS338, WPS420, WPS422, WPS429, WPS431, WPS504, WPS604, WPS606 cheroot/server.py: DAR003, DAR101, DAR201, DAR202, DAR301, DAR401, E800, I001, I003, I004, I005, N806, RST201, RST301, RST303, RST304, WPS100, WPS110, WPS111, WPS115, WPS120, WPS121, WPS122, WPS130, WPS132, WPS201, WPS202, WPS204, WPS210, WPS211, WPS212, WPS213, WPS214, WPS220, WPS221, WPS225, WPS226, WPS229, WPS230, WPS231, WPS236, WPS237, WPS238, WPS301, WPS338, WPS342, WPS410, WPS420, WPS421, WPS422, WPS429, WPS432, WPS504, WPS505, WPS601, WPS602, WPS608, WPS617 - cheroot/ssl/builtin.py: DAR101, DAR201, DAR401, I001, I003, N806, RST304, WPS110, WPS111, WPS115, WPS117, WPS120, WPS121, WPS122, WPS130, WPS201, WPS210, WPS214, WPS229, WPS231, WPS338, WPS422, WPS501, WPS505, WPS529, WPS608, WPS612 - cheroot/ssl/pyopenssl.py: C815, DAR101, DAR201, DAR401, I001, I003, I005, N801, N804, RST304, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS130, WPS210, WPS220, WPS221, WPS225, WPS229, WPS231, WPS238, WPS301, WPS335, WPS338, WPS420, WPS422, WPS430, WPS432, WPS501, WPS504, WPS505, WPS601, WPS608, WPS615 - cheroot/test/conftest.py: DAR101, DAR201, DAR301, I001, I003, I005, WPS100, WPS130, WPS325, WPS354, WPS420, WPS422, WPS430, WPS457 + cheroot/ssl/builtin.py: DAR101, DAR201, DAR401, I001, I003, N806, RST304, WPS110, WPS111, WPS115, WPS117, WPS120, WPS121, WPS122, WPS130, WPS201, WPS204, WPS210, WPS220, WPS214, WPS226, WPS229, WPS231, WPS338, WPS421, WPS422, WPS501, WPS505, WPS529, WPS608, WPS612 + cheroot/ssl/pyopenssl.py: C815, DAR101, DAR201, DAR401, I001, I003, I005, N801, N804, RST304, WPS100, WPS110, WPS111, WPS117, WPS120, WPS121, WPS122, WPS130, WPS210, WPS220, WPS221, WPS225, WPS229, WPS231, WPS238, WPS301, WPS335, WPS338, WPS420, WPS422, WPS430, WPS432, WPS501, WPS504, WPS505, WPS601, WPS608, WPS615 + cheroot/ssl/tls_socket.py: DAR101, DAR201, DAR401, WPS110, WPS122, WPS210, WPS212, WPS214, WPS220, WPS225, WPS226, WPS229, WPS231, WPS238, WPS338, WPS362, WPS407 + cheroot/test/conftest.py: DAR101, DAR201, DAR301, I001, I003, I005, WPS100, WPS130, WPS202, WPS325, WPS354, WPS420, WPS422, WPS430, WPS457 cheroot/test/helper.py: DAR101, DAR201, DAR401, I001, I003, I004, N802, WPS110, WPS111, WPS121, WPS201, WPS220, WPS231, WPS301, WPS414, WPS421, WPS422, WPS505 cheroot/test/test_cli.py: DAR101, DAR201, I001, I005, N802, S101, S108, WPS110, WPS421, WPS431, WPS473 - cheroot/test/test_makefile.py: DAR101, DAR201, I004, RST304, S101, WPS110, WPS122 + cheroot/test/test_makefile.py: DAR101, DAR201, I004, RST304, S101, WPS110, WPS122, WPS362 cheroot/test/test_wsgi.py: DAR101, DAR301, I001, I004, S101, WPS110, WPS111, WPS117, WPS118, WPS121, WPS210, WPS421, WPS430, WPS432, WPS441, WPS509 cheroot/test/test_core.py: C815, DAR101, DAR201, DAR401, I003, I004, N805, N806, S101, WPS110, WPS111, WPS114, WPS121, WPS202, WPS204, WPS226, WPS229, WPS324, WPS421, WPS422, WPS432, WPS602 cheroot/test/test_dispatch.py: DAR101, DAR201, S101, WPS111, WPS121, WPS422, WPS430 @@ -144,7 +145,9 @@ per-file-ignores = cheroot/testing.py: C815, DAR101, DAR201, DAR301, I001, I003, S104, WPS100, WPS202, WPS211, WPS229, WPS301, WPS414, WPS420, WPS422, WPS430 cheroot/workers/threadpool.py: DAR101, DAR201, E800, I001, I003, I004, RST201, RST203, RST301, WPS100, WPS110, WPS111, WPS121, WPS122, WPS210, WPS211, WPS214, WPS220, WPS229, WPS230, WPS231, WPS335, WPS338, WPS362, WPS363, WPS410, WPS414, WPS420, WPS422, WPS432, WPS501, WPS505, WPS601, WPS602, WPS617 cheroot/wsgi.py: DAR101, DAR201, DAR401, I001, I003, I005, N801, RST201, RST301, WPS100, WPS110, WPS111, WPS114, WPS121, WPS122, WPS130, WPS210, WPS211, WPS226, WPS229, WPS231, WPS338, WPS420, WPS421, WPS422, WPS430, WPS501, WPS504, WPS602, WPS608 - cheroot/ssl/__init__.py: DAR101, DAR201, I003, WPS412, WPS422 + cheroot/ssl/__init__.py: DAR101, DAR201, I003, WPS210, WPS412, WPS422 + cheroot/test/ssl/test_ssl_builtin.py: DAR101, DAR201, I003, WPS118, WPS201, WPS202, WPS210, WPS213, WPS218, WPS211, WPS226, WPS229, WPS231, WPS243, WPS412, WPS420, WPS422, WPS430, WPS505 + cheroot/test/ssl/test_ssl_pyopenssl.py: DAR101, DAR201, I003, WPS118, WPS201, WPS202, WPS204, WPS210, WPS220, WPS213, WPS218, WPS211, WPS226, WPS229, WPS231, WPS243, WPS412, WPS420, WPS422, WPS430, WPS432, WPS435, WPS505 cheroot/test/_pytest_plugin.py: DAR101, I003, I004, WPS422 cheroot/test/test__compat.py: DAR101, I001, I003, I005, WPS116, WPS226, WPS422, S101 cheroot/test/test_errors.py: DAR101, WPS509, S101 diff --git a/.mypy.ini b/.mypy.ini index fa1389a44b..8e9badbc25 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.8 +python_version = 3.9 color_output = true error_summary = true files = diff --git a/cheroot/connections.py b/cheroot/connections.py index 04df0d42a3..8a66f799f0 100644 --- a/cheroot/connections.py +++ b/cheroot/connections.py @@ -1,6 +1,5 @@ """Utilities to manage open connections.""" -import io import os import selectors import socket @@ -10,7 +9,6 @@ from . import errors from ._compat import IS_WINDOWS -from .makefile import MakeFile try: @@ -293,50 +291,22 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME if hasattr(s, 'settimeout'): s.settimeout(self.server.timeout) - mf = MakeFile ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server if self.server.ssl_adapter is not None: try: s, ssl_env = self.server.ssl_adapter.wrap(s) - except errors.FatalSSLAlert as tls_connection_drop_error: - self.server.error_log( - f'Client {addr!s} lost — peer dropped the TLS ' - 'connection suddenly, during handshake: ' - f'{tls_connection_drop_error!s}', - ) - return None - except errors.NoSSLError as http_over_https_err: + except errors.FatalSSLAlert as tls_connection_error: self.server.error_log( - f'Client {addr!s} attempted to speak plain HTTP into ' - 'a TCP connection configured for TLS-only traffic — ' - 'trying to send back a plain HTTP error response: ' - f'{http_over_https_err!s}', + f'Failed to establish SSL connection with {addr!s}: ' + f'{tls_connection_error!s}', ) - msg = ( - 'The client sent a plain HTTP request, but ' - 'this server only speaks HTTPS on this port.' - ) - buf = [ - '%s 400 Bad Request\r\n' % self.server.protocol, - 'Content-Length: %s\r\n' % len(msg), - 'Content-Type: text/plain\r\n\r\n', - msg, - ] - - wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE) - try: - wfile.write(''.join(buf).encode('ISO-8859-1')) - except OSError as ex: - if ex.args[0] not in errors.socket_errors_to_ignore: - raise return None - mf = self.server.ssl_adapter.makefile - # Re-apply our timeout since we may have a new socket object - if hasattr(s, 'settimeout'): - s.settimeout(self.server.timeout) + except errors.NoSSLError: + return self._send_bad_request_plain_http_error(s, addr) - conn = self.server.ConnectionClass(self.server, s, mf) + conn = self.server.ConnectionClass(self.server, s) if not isinstance(self.server.bind_addr, (str, bytes)): # optional values @@ -381,6 +351,43 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME return None raise + def _send_bad_request_plain_http_error(self, sock, addr): + """Send Bad Request 400 response, and close the socket.""" + self.server.error_log( + f'Client {addr!s} attempted to speak plain HTTP into ' + 'a TCP connection configured for TLS-only traffic — ' + 'Sending 400 Bad Request.', + ) + + msg = ( + 'The client sent a plain HTTP request, but this server ' + 'only speaks HTTPS on this port.' + ) + + response_parts = [ + f'{self.server.protocol} 400 Bad Request\r\n', + 'Content-Type: text/plain\r\n', + f'Content-Length: {len(msg)}\r\n', + 'Connection: close\r\n', + '\r\n', + msg, + ] + response_bytes = ''.join(response_parts).encode('ISO-8859-1') + + try: + # Handle both raw sockets and SSL connections + if hasattr(sock, 'sendall'): + sock.sendall(response_bytes) + else: + # Fallback for older PyOpenSSL or SSL objects + sock.send(response_bytes) + sock.shutdown(socket.SHUT_WR) + except OSError as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + + sock.close() + def close(self): """Close all monitored connections.""" for _, conn in self._selector.connections: diff --git a/cheroot/makefile.py b/cheroot/makefile.py index f5780a1ede..b1981d34d5 100644 --- a/cheroot/makefile.py +++ b/cheroot/makefile.py @@ -2,6 +2,7 @@ # prefer slower Python-based io module import _pyio as io +import io as stdlib_io import socket @@ -38,9 +39,16 @@ def _flush_unlocked(self): class StreamReader(io.BufferedReader): """Socket stream reader.""" - def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): - """Initialize socket stream reader.""" - super().__init__(socket.SocketIO(sock, mode), bufsize) + def __init__(self, sock, bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize with socket or raw IO object.""" + # If already a RawIOBase (like TLSSocket), use directly + if isinstance(sock, (io.RawIOBase, stdlib_io.RawIOBase)): + raw_io = sock + else: + # Wrap raw socket with SocketIO + raw_io = socket.SocketIO(sock, 'rb') + + super().__init__(raw_io, bufsize) self.bytes_read = 0 def read(self, *args, **kwargs): @@ -57,9 +65,16 @@ def has_data(self): class StreamWriter(BufferedWriter): """Socket stream writer.""" - def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE): - """Initialize socket stream writer.""" - super().__init__(socket.SocketIO(sock, mode), bufsize) + def __init__(self, sock, bufsize=io.DEFAULT_BUFFER_SIZE): + """Initialize with socket or raw IO object.""" + # If already a RawIOBase (like TLSSocket), use directly + if isinstance(sock, (io.RawIOBase, stdlib_io.RawIOBase)): + raw_io = sock + else: + # Wrap raw socket with SocketIO + raw_io = socket.SocketIO(sock, 'wb') + + super().__init__(raw_io, bufsize) self.bytes_written = 0 def write(self, val, *args, **kwargs): @@ -67,9 +82,3 @@ def write(self, val, *args, **kwargs): res = super().write(val, *args, **kwargs) self.bytes_written += len(val) return res - - -def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): - """File object attached to a socket object.""" - cls = StreamReader if 'r' in mode else StreamWriter - return cls(sock, mode, bufsize) diff --git a/cheroot/makefile.pyi b/cheroot/makefile.pyi index 3f5ea2756b..592ef2bf91 100644 --- a/cheroot/makefile.pyi +++ b/cheroot/makefile.pyi @@ -7,13 +7,11 @@ class BufferedWriter(io.BufferedWriter): class StreamReader(io.BufferedReader): bytes_read: int - def __init__(self, sock, mode: str = ..., bufsize=...) -> None: ... + def __init__(self, sock, bufsize=...) -> None: ... def read(self, *args, **kwargs): ... def has_data(self): ... class StreamWriter(BufferedWriter): bytes_written: int - def __init__(self, sock, mode: str = ..., bufsize=...) -> None: ... + def __init__(self, sock, bufsize=...) -> None: ... def write(self, val, *args, **kwargs): ... - -def MakeFile(sock, mode: str = ..., bufsize=...): ... diff --git a/cheroot/server.py b/cheroot/server.py index 9a288539c7..cb79316554 100644 --- a/cheroot/server.py +++ b/cheroot/server.py @@ -83,7 +83,7 @@ from . import __version__, connections, errors from ._compat import IS_PPC, bton -from .makefile import MakeFile, StreamWriter +from .makefile import StreamReader, StreamWriter from .workers import threadpool @@ -1275,19 +1275,18 @@ class HTTPConnection: # Fields set by ConnectionManager. last_used = None - def __init__(self, server, sock, makefile=MakeFile): + def __init__(self, server, sock): """Initialize HTTPConnection instance. Args: server (HTTPServer): web server object receiving this request sock (socket._socketobject): the raw socket object (usually TCP) for this connection - makefile (file): a fileobject class for reading from the socket """ self.server = server self.socket = sock - self.rfile = makefile(sock, 'rb', self.rbufsize) - self.wfile = makefile(sock, 'wb', self.wbufsize) + self.rfile = StreamReader(sock, self.rbufsize) + self.wfile = StreamWriter(sock, self.wbufsize) self.requests_seen = 0 self.peercreds_enabled = self.server.peercreds_enabled @@ -1363,7 +1362,7 @@ def _handle_no_ssl(self, req): except AttributeError: # self.socket is of OpenSSL.SSL.Connection type resp_sock = self.socket._socket - self.wfile = StreamWriter(resp_sock, 'wb', self.wbufsize) + self.wfile = StreamWriter(resp_sock, self.wbufsize) msg = ( 'The client sent a plain HTTP request, but ' 'this server only speaks HTTPS on this port.' diff --git a/cheroot/server.pyi b/cheroot/server.pyi index c5c0f517f6..569ec94225 100644 --- a/cheroot/server.pyi +++ b/cheroot/server.pyi @@ -112,7 +112,7 @@ class HTTPConnection: rfile: Any wfile: Any requests_seen: int - def __init__(self, server, sock, makefile=...) -> None: ... + def __init__(self, server, sock) -> None: ... def communicate(self): ... linger: bool def close(self) -> None: ... diff --git a/cheroot/ssl/__init__.py b/cheroot/ssl/__init__.py index c0072c0557..ad4ff40f5e 100644 --- a/cheroot/ssl/__init__.py +++ b/cheroot/ssl/__init__.py @@ -1,16 +1,162 @@ -"""Implementation of the SSL adapter base interface.""" +"""Implementation of the SSL adapter base interface. + +.. spelling:: + dn +""" from abc import ABC, abstractmethod +from contextlib import suppress + + +DN_SEPARATOR = '/' # Distinguished Name separator +UTF8_ENCODING = 'utf-8' + + +def _parse_dn_components(components, key_prefix, dn_type): + """ + Parse Distinguished Name components into environ dict. + + Args: + components: Iterable of (key, value) tuples + key_prefix: 'SSL_CLIENT' or 'SSL_SERVER' + dn_type: 'S' for subject or 'I' for issuer + + Returns: + dict: ``DN`` and ``CN`` environment variables + """ + env = {} + dn_parts = [] + + for key, attr_value in components: + dn_parts.append(f'{key}={attr_value}') + if key in {'CN', 'commonName'}: + env[f'{key_prefix}_{dn_type}_DN_CN'] = attr_value + + if dn_parts: + dn_string = DN_SEPARATOR.join(dn_parts) + env[f'{key_prefix}_{dn_type}_DN'] = f'{DN_SEPARATOR}{dn_string}' + + return env + + +def parse_pyopenssl_cert_to_environ(cert, key_prefix): + """Parse a pyOpenSSL X509 certificate into WSGI environ dict.""" + env = {} + if not cert: + return env + + # Subject + subject = cert.get_subject() + if subject: + components = [ + (key.decode(UTF8_ENCODING), attr_value.decode(UTF8_ENCODING)) + for key, attr_value in subject.get_components() + ] + env.update(_parse_dn_components(components, key_prefix, 'S')) + + # Issuer + issuer = cert.get_issuer() + if issuer: + components = [ + (key.decode(UTF8_ENCODING), attr_value.decode(UTF8_ENCODING)) + for key, attr_value in issuer.get_components() + ] + env.update(_parse_dn_components(components, key_prefix, 'I')) + + # Version and Serial + env[f'{key_prefix}_M_VERSION'] = str(cert.get_version()) + env[f'{key_prefix}_M_SERIAL'] = str(cert.get_serial_number()) + + return env + + +def parse_x509_cert_to_environ(cert, key_prefix): + """Parse a cryptography x509 certificate into environ dict.""" + env = {} + + # Subject + with suppress(Exception): + subject = cert.subject + components = [(attr.oid._name, attr.value) for attr in subject] + env.update(_parse_dn_components(components, key_prefix, 'S')) + + # Issuer + with suppress(Exception): + issuer = cert.issuer + components = [(attr.oid._name, attr.value) for attr in issuer] + env.update(_parse_dn_components(components, key_prefix, 'I')) + + # Version and Serial + with suppress(Exception): + env[f'{key_prefix}_M_VERSION'] = str(cert.version.value) + env[f'{key_prefix}_M_SERIAL'] = str(cert.serial_number) + + return env + + +class SSLEnvironMixin: + """ + Mixin class providing methods for generating WSGI environment variables. + + This mixin handles GENERIC SSL environment variable generation that works + across all SSL implementations. Adapter-specific logic (like certificate + parsing) is delegated to subclass implementations. + """ + + def _get_core_tls_environ(self, conn): + """ + Add core TLS version and cipher info to the environment. + + This is generic and works for all SSL adapters since TLSSocket + provides a uniform get_cipher_info() interface. + """ + cipher_info = conn.get_cipher_info() + + # Early exit if no cipher info (not a secure connection) + if cipher_info is None: + return {'wsgi.url_scheme': 'http'} + + cipher_name, protocol, cipher_keysize = cipher_info + + return { + 'wsgi.url_scheme': 'https', + 'HTTPS': 'on', + 'SSL_PROTOCOL': protocol, + 'SSL_CIPHER': cipher_name, + 'SSL_CIPHER_EXPORT': '', + 'SSL_CIPHER_USEKEYSIZE': cipher_keysize, + 'SSL_CLIENT_VERIFY': 'NONE', + } + + def _get_server_cert_environ(self): + """ + Get server certificate info from the connection. + + MUST be overridden by subclasses to provide adapter-specific parsing. + Returns dict of SSL_SERVER_* environ variables. + + Default implementation returns empty dict. + """ + return {} + + def _get_client_cert_environ(self, conn, ssl_environ): + """ + Add client certificate details to the environment. + + SHOULD be overridden by subclasses for adapter-specific handling. + Default implementation does nothing. + """ + return ssl_environ -class Adapter(ABC): +class Adapter(SSLEnvironMixin, ABC): """Base class for SSL driver library adapters. Required methods: * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` - * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> - socket file object`` + * ``_get_library_version_environ() -> dict`` + * ``_get_optional_environ(conn) -> dict`` """ @abstractmethod @@ -39,14 +185,75 @@ def bind(self, sock): @abstractmethod def wrap(self, sock): """Wrap and return the given socket, plus WSGI environ entries.""" - raise NotImplementedError # pragma: no cover + raise NotImplementedError + + def get_environ(self, conn): + """ + Return WSGI environ entries to be merged into each request. + + Unified implementation used by all subclasses. This orchestrates + the collection of SSL environment variables from various sources: + - Core TLS info (protocol, cipher) + - Library versions + - Optional fields (SNI, etc.) + - Session info + - Client certificate + - Server certificate + + Note: This returns only SSL-specific variables. General server + variables (``SERVER_NAME``, ``SERVER_PORT``, etc.) are added by + the Gateway when building the complete WSGI environ for each request. + """ + # 1. Handle basic TLS info + ssl_environ = self._get_core_tls_environ(conn) + if 'HTTPS' not in ssl_environ: + # Core TLS failed (returned 'http' env) + return ssl_environ + + # 2. Update with library-specific version strings + ssl_environ.update(self._get_library_version_environ()) + + # 3. Handle optional/platform-specific fields (SNI, compression) + ssl_environ.update(self._get_optional_environ(conn)) + + # 4. Handle Session ID + with suppress(AttributeError): + session = conn.get_session() + if session and hasattr(session, 'id'): + ssl_environ['SSL_SESSION_ID'] = session.id.hex() + + # 5. Handle Client certificate (adapter-specific) + ssl_environ = self._get_client_cert_environ(conn, ssl_environ) + + # 6. Server certificate (adapter-specific) + server_cert_info = self._get_server_cert_environ() + if server_cert_info: + ssl_environ.update(server_cert_info) + + return ssl_environ @abstractmethod - def get_environ(self): - """Return WSGI environ entries to be merged into each request.""" - raise NotImplementedError # pragma: no cover + def _get_library_version_environ(self): + """ + Get SSL library version information. + + Must be implemented by subclasses to provide adapter-specific + version strings. + + Returns: + dict: SSL_VERSION_INTERFACE and SSL_VERSION_LIBRARY + """ + raise NotImplementedError @abstractmethod - def makefile(self, sock, mode='r', bufsize=-1): - """Return socket file object.""" - raise NotImplementedError # pragma: no cover + def _get_optional_environ(self, conn): + """ + Get optional environment variables. + + Must be implemented by subclasses for adapter-specific handling + of optional fields like SNI, compression, etc. + + Returns: + dict: Optional SSL environment variables + """ + raise NotImplementedError diff --git a/cheroot/ssl/__init__.pyi b/cheroot/ssl/__init__.pyi index c595121546..2582632314 100644 --- a/cheroot/ssl/__init__.pyi +++ b/cheroot/ssl/__init__.pyi @@ -1,6 +1,30 @@ from abc import ABC, abstractmethod from typing import Any +DN_SEPARATOR: str +UTF8_ENCODING: str + +def parse_pyopenssl_cert_to_environ( + cert: Any, + key_prefix: str, +) -> dict[str, Any]: ... +def parse_x509_cert_to_environ( + cert: Any, + key_prefix: str, +) -> dict[str, Any]: ... + +class SSLEnvironMixin: + def _get_core_tls_environ( + self, + conn: Any, + ) -> dict[str, Any]: ... + def _get_server_cert_environ(self) -> dict[str, Any]: ... + def _get_client_cert_environ( + self, + conn: Any, + ssl_environ: dict[str, Any], + ) -> dict[str, Any]: ... + class Adapter(ABC): certificate: Any private_key: Any @@ -23,6 +47,4 @@ class Adapter(ABC): @abstractmethod def wrap(self, sock): ... @abstractmethod - def get_environ(self): ... - @abstractmethod - def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... + def get_environ(self, conn) -> dict[str, Any]: ... diff --git a/cheroot/ssl/builtin.py b/cheroot/ssl/builtin.py index ed747ab6e3..f0ef58e66c 100644 --- a/cheroot/ssl/builtin.py +++ b/cheroot/ssl/builtin.py @@ -9,9 +9,11 @@ import socket import sys -import threading from contextlib import suppress +from . import Adapter +from .tls_socket import TLSSocket + try: import ssl @@ -19,148 +21,24 @@ ssl = None try: - from _pyio import DEFAULT_BUFFER_SIZE + from cryptography import x509 + from cryptography.hazmat.backends import default_backend except ImportError: - try: - from io import DEFAULT_BUFFER_SIZE - except ImportError: - DEFAULT_BUFFER_SIZE = -1 - -from .. import errors -from ..makefile import StreamReader, StreamWriter -from ..server import HTTPServer -from . import Adapter + x509 = None + default_backend = None -def _assert_ssl_exc_contains(exc, *msgs): - """Check whether SSL exception contains either of messages provided.""" - if len(msgs) < 1: - raise TypeError( - '_assert_ssl_exc_contains() requires ' - 'at least one message to be passed.', - ) - err_msg_lower = str(exc).lower() - return any(m.lower() in err_msg_lower for m in msgs) - - -def _loopback_for_cert_thread(context, server): - """Wrap a socket in ssl and perform the server-side handshake.""" - # As we only care about parsing the certificate, the failure of - # which will cause an exception in ``_loopback_for_cert``, - # we can safely ignore connection and ssl related exceptions. Ref: - # https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030 - with suppress(ssl.SSLError, OSError): - with context.wrap_socket( - server, - do_handshake_on_connect=True, - server_side=True, - ) as ssl_sock: - # in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server - # sends the client session tickets that can be used to - # resume the TLS session on a new connection without - # performing the full handshake again. session tickets are - # sent as a post-handshake message at some _unspecified_ - # time and thus a successful connection may be closed - # without the client having received the tickets. - # Unfortunately, on Windows (Python 3.8+), this is treated - # as an incomplete handshake on the server side and a - # ``ConnectionAbortedError`` is raised. - # TLS 1.3 support is still incomplete in Python 3.8; - # there is no way for the client to wait for tickets. - # While not necessary for retrieving the parsed certificate, - # we send a tiny bit of data over the connection in an - # attempt to give the server a chance to send the session - # tickets and close the connection cleanly. - # Note that, as this is essentially a race condition, - # the error may still occur ocasionally. - ssl_sock.send(b'0000') - - -def _loopback_for_cert( - certificate, - private_key, - certificate_chain, - *, - private_key_password=None, -): - """Create a loopback connection to parse a cert with a private key.""" - context = ssl.create_default_context(cafile=certificate_chain) - context.load_cert_chain( - certificate, - private_key, - password=private_key_password, - ) - context.check_hostname = False - context.verify_mode = ssl.CERT_NONE - - # Python 3+ Unix, Python 3.5+ Windows - client, server = socket.socketpair() - try: - # `wrap_socket` will block until the ssl handshake is complete. - # it must be called on both ends at the same time -> thread - # openssl will cache the peer's cert during a successful handshake - # and return it via `getpeercert` even after the socket is closed. - # when `close` is called, the SSL shutdown notice will be sent - # and then python will wait to receive the corollary shutdown. - thread = threading.Thread( - target=_loopback_for_cert_thread, - args=(context, server), - ) - try: - thread.start() - with context.wrap_socket( - client, - do_handshake_on_connect=True, - server_side=False, - ) as ssl_sock: - ssl_sock.recv(4) - return ssl_sock.getpeercert() - finally: - thread.join() - finally: - client.close() - server.close() - - -def _parse_cert( - certificate, - private_key, - certificate_chain, - *, - private_key_password=None, -): - """Parse a certificate.""" - # loopback_for_cert uses socket.socketpair which was only - # introduced in Python 3.0 for *nix and 3.5 for Windows - # and requires OS support (AttributeError, OSError) - # it also requires a private key either in its own file - # or combined with the cert (SSLError) - with suppress(AttributeError, ssl.SSLError, OSError): - return _loopback_for_cert( - certificate, - private_key, - certificate_chain, - private_key_password=private_key_password, - ) - - # KLUDGE: using an undocumented, private, test method to parse a cert - # unfortunately, it is the only built-in way without a connection - # as a private, undocumented method, it may change at any time - # so be tolerant of *any* possible errors it may raise - with suppress(Exception): - return ssl._ssl._test_decode_cert(certificate) - - return {} - - -def _sni_callback(sock, sni, context): - """Handle the SNI callback to tag the socket with the SNI.""" - sock.sni = sni - # return None to allow the TLS negotiation to continue +from .. import errors +from . import parse_x509_cert_to_environ class BuiltinSSLAdapter(Adapter): - """Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot.""" + """ + Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot. + + This adapter uses TLSSocket internally to provide a consistent + interface for SSL/TLS connections. + """ certificate = None """The file name of the server SSL certificate.""" @@ -235,11 +113,13 @@ def __init__( *, private_key_password=None, ): - """Set up context in addition to base class properties if available.""" + """Initialize builtin SSL Adapter instance.""" if ssl is None: - raise ImportError('You must install the ssl module to use HTTPS.') + raise ImportError( + 'You must have ssl module available to use HTTPS.', + ) - super(BuiltinSSLAdapter, self).__init__( + super().__init__( certificate, private_key, certificate_chain, @@ -247,162 +127,378 @@ def __init__( private_key_password=private_key_password, ) - self.context = ssl.create_default_context( - purpose=ssl.Purpose.CLIENT_AUTH, - cafile=certificate_chain, - ) - self.context.load_cert_chain( - certificate, - private_key, - password=private_key_password, - ) - if self.ciphers is not None: - self.context.set_ciphers(ciphers) - - self._server_env = self._make_env_cert_dict( - 'SSL_SERVER', - _parse_cert( - certificate, - private_key, - self.certificate_chain, - private_key_password=private_key_password, - ), - ) - if not self._server_env: - return - cert = None - with open(certificate) as f: - cert = f.read() - - # strip off any keys by only taking the first certificate - cert_start = cert.find(ssl.PEM_HEADER) - if cert_start == -1: - return - cert_end = cert.find(ssl.PEM_FOOTER, cert_start) - if cert_end == -1: - return - cert_end += len(ssl.PEM_FOOTER) - self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end] + self._context = None + self._context = self._create_context() + + def bind(self, sock): + """Prepare the server socket.""" + return sock # Context already created + + def wrap(self, sock): + """ + Wrap client socket with SSL and return environ entries. + + Args: + sock: Raw socket to wrap with TLS + + Returns: + tuple: (TLSSocket, ssl_environ_dict) + """ + if self._check_for_plain_http(sock): + raise errors.NoSSLError + + tls_socket = self._wrap_with_builtin(sock) + ssl_environ = self.get_environ(tls_socket) + return tls_socket, ssl_environ + + def _wrap_with_builtin(self, raw_socket, server_side=True): + """ + Create a TLSSocket using Python's built-in ssl module. + + Args: + raw_socket: The raw socket to wrap + server_side: True if this is the server side + + Returns: + TLSSocket: Wrapped socket ready for secure I/O + """ + try: + wrapped_ssl_socket = self._create_ssl_socket( + raw_socket, + server_side, + ) + self._perform_handshake(wrapped_ssl_socket, raw_socket) + + underlying_socket = wrapped_ssl_socket + if hasattr(wrapped_ssl_socket, '_sock'): + underlying_socket = wrapped_ssl_socket._sock + + return TLSSocket( + ssl_socket=wrapped_ssl_socket, + raw_socket=underlying_socket, + context=self.context, + ) + except errors.NoSSLError: + # Plain HTTP detected, let it propagate for proper error response + raise + except TimeoutError: + # Handshake timeout, let it propagate + with suppress(Exception): + raw_socket.close() + raise + except (ssl.SSLError, OSError) as e: + # SSL handshake or socket error - clean up and raise + with suppress(Exception): + raw_socket.close() + raise errors.FatalSSLAlert(f'SSL wrapping failed: {e}') from e + + def _check_for_plain_http(self, raw_socket): + """Check if the client sent plain HTTP by peeking at first bytes. + + This is a best-effort check to provide a helpful error message when + clients accidentally use HTTP on an HTTPS port. If we can't detect + plain HTTP (timeout, no data yet, etc), we return False and let the + SSL handshake proceed, which will fail with its own error. + + Returns: + bool: True if plain HTTP is detected, False otherwise + """ + PEEK_BYTES = 16 + PEEK_TIMEOUT = 0.5 + + original_timeout = raw_socket.gettimeout() + try: + raw_socket.settimeout(PEEK_TIMEOUT) + first_bytes = raw_socket.recv(PEEK_BYTES, socket.MSG_PEEK) + + if not first_bytes: + return False + + http_methods = ( + b'GET ', + b'POST ', + b'PUT ', + b'DELETE ', + b'HEAD ', + b'OPTIONS ', + b'PATCH ', + b'CONNECT ', + b'TRACE ', + ) + return any( + first_bytes.startswith(method) for method in http_methods + ) + except (OSError, socket.timeout): + return False + finally: + raw_socket.settimeout(original_timeout) + + def _create_ssl_socket(self, raw_socket, server_side): + """Create SSL socket without handshake.""" + try: + # Manual handshake for error handling + return self.context.wrap_socket( + raw_socket, + do_handshake_on_connect=False, + server_side=server_side, + ) + except ssl.SSLError as e: + raise errors.FatalSSLAlert( + f'Error creating SSL socket: {e}', + ) from e + + def _perform_handshake(self, ssl_socket, raw_socket): + """Perform SSL handshake with error handling and retries.""" + HANDSHAKE_TIMEOUT = 5.0 + + # Set timeout on the SSL socket for the handshake + original_timeout = ssl_socket.gettimeout() + ssl_socket.settimeout(HANDSHAKE_TIMEOUT) + + try: + while True: + try: # noqa: WPS225 + ssl_socket.do_handshake() + return + except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as e: + direction = ( + 'read' + if isinstance(e, ssl.SSLWantReadError) + else 'write' + ) + self._wait_for_handshake_data(raw_socket, direction) + except socket.timeout as e: + raise errors.NoSSLError( + 'SSL handshake timeout.', + ) from e + except ssl.SSLEOFError as e: + raise errors.NoSSLError( + 'Peer closed connection during handshake.', + ) from e + except ssl.SSLError as e: + self._handle_ssl_error(e) + except OSError as e: + raise errors.FatalSSLAlert( + f'TCP error during handshake: {e}', + ) from e + finally: + # Restore original timeout + with suppress(Exception): + ssl_socket.settimeout(original_timeout) + + def _wait_for_handshake_data(self, raw_socket, direction): + """Wait for socket to be ready for read or write during handshake.""" + import select + + HANDSHAKE_TIMEOUT = 5.0 + fileno = raw_socket.fileno() + + if direction == 'read': + ready = select.select([fileno], [], [], HANDSHAKE_TIMEOUT)[0] + else: # write + ready = select.select([], [fileno], [], HANDSHAKE_TIMEOUT)[1] + + if not ready: + raise TimeoutError( + f'Handshake failed: Peer did not send expected data ({direction}).', + ) + + def _handle_ssl_error(self, error): + """Handle SSL errors during handshake.""" + err_str = str(error).lower() + + # Check for common patterns indicating plain HTTP + if any( + pattern in err_str + for pattern in ( + 'wrong version number', + 'http request', + 'unknown protocol', + ) + ): + raise errors.NoSSLError( + 'Client sent plain HTTP request', + ) from error + + raise errors.FatalSSLAlert( + f'Fatal SSL error during handshake: {error}', + ) from error @property def context(self): - """:py:class:`~ssl.SSLContext` that will be used to wrap sockets.""" + """Get the SSL context.""" return self._context @context.setter - def context(self, context): - """Set the ssl ``context`` to use.""" - self._context = context - # Python 3.7+ - # if a context is provided via `cherrypy.config.update` then - # `self.context` will be set after `__init__` - # use a property to intercept it to add an SNI callback - # but don't override the user's callback - # TODO: chain callbacks - with suppress(AttributeError): - if ssl.HAS_SNI and context.sni_callback is None: - context.sni_callback = _sni_callback + def context(self, value): + """Set the SSL context (for testing).""" + self._context = value + + def _create_context(self): + """Return an py:class:`ssl.SSLContext` from self attributes.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + + # Only attempt to load the key/cert chain if a certificate + # path is available. + if self.certificate: + try: + ctx.load_cert_chain( + self.certificate, + self.private_key, + self.private_key_password, + ) + except FileNotFoundError as file_err: + raise FileNotFoundError( + f'SSL certificate file not found: {file_err}', + ) from file_err - def bind(self, sock): - """Wrap and return the given socket.""" - return super(BuiltinSSLAdapter, self).bind(sock) + # Load CA Trust/Verification Chain + # (Optional, independent of server cert) + # This is needed for verifying client certificates + # or connecting to other services. + if self.certificate_chain: + ctx.load_verify_locations(cafile=self.certificate_chain) + + # Set Ciphers (Optional, independent of server cert) + if self.ciphers: + ctx.set_ciphers(self.ciphers) + + return ctx + + # ======================================================================== + # Adapter-specific environment variable methods + # ======================================================================== + + def _get_library_version_environ(self): + """ + Get SSL library version information. + + Overrides base class method to provide builtin ssl module version. + """ + python_version = sys.version.split()[0] + return { + 'SSL_VERSION_INTERFACE': f'Python/{python_version} {ssl.OPENSSL_VERSION}', + 'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION, + } + + def _get_optional_environ(self, conn): + """ + Get optional environment variables. + + Overrides base class method for builtin ssl-specific handling. + """ + environ = {} + + # Compression (note: most modern OpenSSL builds disable compression) + try: + compression = conn.compression() + if compression: + environ['SSL_COMPRESS_METHOD'] = compression + except AttributeError: + # TLSSocket might not have compression method + ... + + # SNI (Server Name Indication) if available + try: + server_hostname = conn.server_hostname + if server_hostname: + environ['SSL_TLS_SNI'] = server_hostname + except AttributeError: + ... + + return environ + + def _get_server_cert_environ(self): + """Get server certificate info using builtin ssl certificate parsing.""" + if not self.certificate: + return {} + + # Check if cryptography is available + if x509 is None: + return {} - def wrap(self, sock): - """Wrap and return the given socket, plus WSGI environ entries.""" try: - s = self.context.wrap_socket( - sock, - do_handshake_on_connect=True, - server_side=True, + with open(self.certificate, 'rb') as cert_file: + cert_data = cert_file.read() + + cert = x509.load_pem_x509_certificate( + cert_data, + default_backend(), ) - except ( - ssl.SSLEOFError, - ssl.SSLZeroReturnError, - ) as tls_connection_drop_error: - raise errors.FatalSSLAlert( - *tls_connection_drop_error.args, - ) from tls_connection_drop_error - except ssl.SSLError as generic_tls_error: - peer_speaks_plain_http_over_https = ( - generic_tls_error.errno == ssl.SSL_ERROR_SSL - and _assert_ssl_exc_contains(generic_tls_error, 'http request') + + return parse_x509_cert_to_environ(cert, 'SSL_SERVER') + + except Exception: + return {} + + def _get_client_cert_environ(self, conn, ssl_environ): + """Populate the WSGI environment with client certificate details.""" + # 1. Access the raw ssl.SSLSocket object + try: + # 'conn' is the TLSSocket wrapper; '_sock' is the raw ssl.SSLSocket + raw_ssl_socket = conn._sock + except AttributeError: + # If the socket is missing or already closed + ssl_environ['SSL_CLIENT_VERIFY'] = 'NONE' + return ssl_environ + + # 2. Get the peer certificate details + try: + # getpeercert() returns a dict if a cert was presented, + # None otherwise. + peer_cert_details = raw_ssl_socket.getpeercert(binary_form=False) + # Also get the binary (DER) form to convert to PEM + peer_cert_binary = raw_ssl_socket.getpeercert(binary_form=True) + except ssl.SSLError: + # This occurs if verification failed during the handshake + ssl_environ['SSL_CLIENT_VERIFY'] = 'FAILURE' + return ssl_environ + except Exception: + # Catch any other socket errors + ssl_environ['SSL_CLIENT_VERIFY'] = 'NONE' + return ssl_environ + + # --- Check Verification Status --- + if peer_cert_details: + ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' + else: + # No cert presented + ssl_environ['SSL_CLIENT_VERIFY'] = 'NONE' + return ssl_environ + + # --- Add the PEM-encoded certificate --- + if peer_cert_binary: + ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert( + peer_cert_binary, ) - if peer_speaks_plain_http_over_https: - reraised_connection_drop_exc_cls = errors.NoSSLError - else: - reraised_connection_drop_exc_cls = errors.FatalSSLAlert - raise reraised_connection_drop_exc_cls( - *generic_tls_error.args, - ) from generic_tls_error - except OSError as tcp_connection_drop_error: - raise errors.FatalSSLAlert( - *tcp_connection_drop_error.args, - ) from tcp_connection_drop_error - - return s, self.get_environ(s) - - def get_environ(self, sock): - """Create WSGI environ entries to be merged into each request.""" - cipher = sock.cipher() - ssl_environ = { - 'wsgi.url_scheme': 'https', - 'HTTPS': 'on', - 'SSL_PROTOCOL': cipher[1], - 'SSL_CIPHER': cipher[0], - 'SSL_CIPHER_EXPORT': '', - 'SSL_CIPHER_USEKEYSIZE': cipher[2], - 'SSL_VERSION_INTERFACE': '%s Python/%s' - % ( - HTTPServer.version, - sys.version, - ), - 'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION, - 'SSL_CLIENT_VERIFY': 'NONE', - # 'NONE' - client did not provide a cert (overriden below) - } + # --- Populate Metadata using existing utility methods --- - # Python 3.3+ - with suppress(AttributeError): - compression = sock.compression() - if compression is not None: - ssl_environ['SSL_COMPRESS_METHOD'] = compression - - # Python 3.6+ - with suppress(AttributeError): - ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex() - with suppress(AttributeError): - target_cipher = cipher[:2] - for cip in sock.context.get_ciphers(): - if target_cipher == (cip['name'], cip['protocol']): - ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits'] - break - - # Python 3.7+ sni_callback - with suppress(AttributeError): - ssl_environ['SSL_TLS_SNI'] = sock.sni - - if self.context and self.context.verify_mode != ssl.CERT_NONE: - client_cert = sock.getpeercert() - if client_cert: - # builtin ssl **ALWAYS** validates client certificates - # and terminates the connection on failure - ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' - ssl_environ.update( - self._make_env_cert_dict('SSL_CLIENT', client_cert), - ) - ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert( - sock.getpeercert(binary_form=True), - ).strip() + # 3. Populate Subject DN + subject_dn_nested = peer_cert_details.get('subject', []) + subject_env = self._make_env_dn_dict( + env_prefix='SSL_CLIENT_S_DN', + cert_value=subject_dn_nested, + ) + ssl_environ.update(subject_env) - ssl_environ.update(self._server_env) + # 4. Populate Issuer DN + issuer_dn_nested = peer_cert_details.get('issuer', []) + issuer_env = self._make_env_dn_dict( + env_prefix='SSL_CLIENT_I_DN', + cert_value=issuer_dn_nested, + ) + ssl_environ.update(issuer_env) - # not supplied by the Python standard library (as of 3.8) - # - SSL_SESSION_RESUMED - # - SSL_SECURE_RENEG - # - SSL_CLIENT_CERT_CHAIN_n - # - SRP_USER - # - SRP_USERINFO + # 5. Populate other cert details + ssl_environ['SSL_CLIENT_M_VERSION'] = str( + peer_cert_details.get('version', ''), + ) + ssl_environ['SSL_CLIENT_M_SERIAL'] = str( + peer_cert_details.get('serialNumber', ''), + ) return ssl_environ @@ -482,8 +578,10 @@ def _make_env_dn_dict(self, env_prefix, cert_value): dn_attrs.setdefault(attr_code, []) dn_attrs[attr_code].append(val) + dn_string = '/'.join(dn) + env = { - env_prefix: ','.join(dn), + env_prefix: '/%s' % (dn_string,), } for attr_code, values in dn_attrs.items(): env['%s_%s' % (env_prefix, attr_code)] = ','.join(values) @@ -492,8 +590,3 @@ def _make_env_dn_dict(self, env_prefix, cert_value): for i, val in enumerate(values): env['%s_%s_%i' % (env_prefix, attr_code, i)] = val return env - - def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): - """Return socket file object.""" - cls = StreamReader if 'r' in mode else StreamWriter - return cls(sock, mode, bufsize) diff --git a/cheroot/ssl/builtin.pyi b/cheroot/ssl/builtin.pyi index b05aaf5ad7..cb5f508d26 100644 --- a/cheroot/ssl/builtin.pyi +++ b/cheroot/ssl/builtin.pyi @@ -2,8 +2,6 @@ from typing import Any from . import Adapter -DEFAULT_BUFFER_SIZE: int - class BuiltinSSLAdapter(Adapter): CERT_KEY_TO_ENV: Any CERT_KEY_TO_LDAP_CODE: Any @@ -22,5 +20,4 @@ class BuiltinSSLAdapter(Adapter): def context(self, context) -> None: ... def bind(self, sock): ... def wrap(self, sock): ... - def get_environ(self, sock): ... - def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... + def get_environ(self, conn) -> dict: ... diff --git a/cheroot/ssl/pyopenssl.py b/cheroot/ssl/pyopenssl.py index 8a041a79b8..61cc1dd7c1 100644 --- a/cheroot/ssl/pyopenssl.py +++ b/cheroot/ssl/pyopenssl.py @@ -13,8 +13,7 @@ Method One ---------- - * :py:attr:`ssl_adapter.context - `: an instance of + * ``ssl_adapter.context``: an instance of :py:class:`SSL.Context `. If this is not None, it is assumed to be an :py:class:`SSL.Context @@ -40,22 +39,24 @@ `: the file name of the server's private key file. -Both are :py:data:`None` by default. If :py:attr:`ssl_adapter.context -` is :py:data:`None`, -but ``.private_key`` and ``.certificate`` are both given and valid, they -will be read, and the context will be automatically created from them. +Both are :py:data:`None` by default. If ``ssl_adapter.context`` +is :py:data:`None`, but ``.private_key`` and ``.certificate`` are both +given and valid, they will be read, and the context will be automatically +created from them. .. spelling:: pyopenssl """ -import socket +import select import sys -import threading -import time +from contextlib import suppress from warnings import warn as _warn +from . import Adapter, parse_pyopenssl_cert_to_environ +from .tls_socket import TLSSocket + try: import OpenSSL.version @@ -67,211 +68,21 @@ ssl_conn_type = SSL.ConnectionType except ImportError: SSL = None + crypto = None -import contextlib - -from .. import ( - errors, - server as cheroot_server, -) +from .. import errors from ..makefile import StreamReader, StreamWriter -from . import Adapter - - -class SSLFileobjectMixin: - """Base mixin for a TLS socket stream.""" - - ssl_timeout = 3 - ssl_retry = 0.01 - - # FIXME: - def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 - """Wrap the given call with TLS error-trapping. - - is_reader: if False EOF errors will be raised. If True, EOF errors - will return "" (to emulate normal sockets). - """ - start = time.time() - while True: - try: - return call(*args, **kwargs) - except SSL.WantReadError: - # Sleep and try again. This is dangerous, because it means - # the rest of the stack has no way of differentiating - # between a "new handshake" error and "client dropped". - # Note this isn't an endless loop: there's a timeout below. - # Ref: https://stackoverflow.com/a/5133568/595220 - time.sleep(self.ssl_retry) - except SSL.WantWriteError: - time.sleep(self.ssl_retry) - except SSL.SysCallError as e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return b'' - - errnum = e.args[0] - if is_reader and errnum in errors.socket_errors_to_ignore: - return b'' - raise socket.error(errnum) - except SSL.Error as e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return b'' - thirdarg = None - with contextlib.suppress(IndexError): - thirdarg = e.args[0][0][2] - - if thirdarg == 'http request': - # The client is talking HTTP to an HTTPS server. - raise errors.NoSSLError - - raise errors.FatalSSLAlert(*e.args) - - if time.time() - start > self.ssl_timeout: - raise socket.timeout('timed out') - - def recv(self, size): - """Receive message of a size from the socket.""" - return self._safe_call( - True, - super(SSLFileobjectMixin, self).recv, - size, - ) - - def readline(self, size=-1): - """Receive message of a size from the socket. - - Matches the following interface: - https://docs.python.org/3/library/io.html#io.IOBase.readline - """ - return self._safe_call( - True, - super(SSLFileobjectMixin, self).readline, - size, - ) - - def sendall(self, *args, **kwargs): - """Send whole message to the socket.""" - return self._safe_call( - False, - super(SSLFileobjectMixin, self).sendall, - *args, - **kwargs, - ) - def send(self, *args, **kwargs): - """Send some part of message to the socket.""" - return self._safe_call( - False, - super(SSLFileobjectMixin, self).send, - *args, - **kwargs, - ) - - -class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): +class SSLFileobjectStreamReader(StreamReader): """SSL file object attached to a socket object.""" -class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): +class SSLFileobjectStreamWriter(StreamWriter): """SSL file object attached to a socket object.""" -class SSLConnectionProxyMeta: - """Metaclass for generating a bunch of proxy methods.""" - - def __new__(mcl, name, bases, nmspc): - """Attach a list of proxy methods to a new class.""" - proxy_methods = ( - 'get_context', - 'pending', - 'send', - 'write', - 'recv', - 'read', - 'renegotiate', - 'bind', - 'listen', - 'connect', - 'accept', - 'setblocking', - 'fileno', - 'close', - 'get_cipher_list', - 'getpeername', - 'getsockname', - 'getsockopt', - 'setsockopt', - 'makefile', - 'get_app_data', - 'set_app_data', - 'state_string', - 'sock_shutdown', - 'get_peer_certificate', - 'want_read', - 'want_write', - 'set_connect_state', - 'set_accept_state', - 'connect_ex', - 'sendall', - 'settimeout', - 'gettimeout', - 'shutdown', - ) - proxy_methods_no_args = ('shutdown',) - - proxy_props = ('family',) - - def lock_decorator(method): - """Create a proxy method for a new class.""" - - def proxy_wrapper(self, *args): - self._lock.acquire() - try: - new_args = ( - args[:] if method not in proxy_methods_no_args else [] - ) - return getattr(self._ssl_conn, method)(*new_args) - finally: - self._lock.release() - - return proxy_wrapper - - for m in proxy_methods: - nmspc[m] = lock_decorator(m) - nmspc[m].__name__ = m - - def make_property(property_): - """Create a proxy method for a new class.""" - - def proxy_prop_wrapper(self): - return getattr(self._ssl_conn, property_) - - proxy_prop_wrapper.__name__ = property_ - return property(proxy_prop_wrapper) - - for p in proxy_props: - nmspc[p] = make_property(p) - - # Doesn't work via super() for some reason. - # Falling back to type() instead: - return type(name, bases, nmspc) - - -class SSLConnection(metaclass=SSLConnectionProxyMeta): - r"""A thread-safe wrapper for an ``SSL.Connection``. - - :param tuple args: the arguments to create the wrapped \ - :py:class:`SSL.Connection(*args) \ - ` - """ - - def __init__(self, *args): - """Initialize SSLConnection instance.""" - self._ssl_conn = SSL.Connection(*args) - self._lock = threading.RLock() - - -class pyOpenSSLAdapter(Adapter): +class pyOpenSSLAdapter(Adapter): # noqa: WPS214 """A wrapper for integrating :doc:`pyOpenSSL `.""" certificate = None @@ -286,11 +97,6 @@ class pyOpenSSLAdapter(Adapter): This is needed for cheaper "chained root" TLS certificates, and should be left as :py:data:`None` if not required.""" - context = None - """ - An instance of :py:class:`SSL.Context `. - """ - ciphers = None """The ciphers list of TLS.""" @@ -310,30 +116,150 @@ def __init__( if SSL is None: raise ImportError('You must install pyOpenSSL to use HTTPS.') - super(pyOpenSSLAdapter, self).__init__( + super().__init__( certificate, private_key, certificate_chain, ciphers, private_key_password=private_key_password, ) - self._environ = None + self.context = None def bind(self, sock): - """Wrap and return the given socket.""" - if self.context is None: - self.context = self.get_context() - conn = SSLConnection(self.context, sock) - self._environ = self.get_environ() - return conn + """ + Prepare the server socket. + + Ensures that the SSL context object is created + and fully configured. For Method One the caller + supplies the context at ``__init()__`` but for + Method Two we construct from certificate files. + """ + if self._context is None: + # Method Two + _ = self.context # triggers initialization via property + return sock def wrap(self, sock): - """Wrap and return the given socket, plus WSGI environ entries.""" - # pyOpenSSL doesn't perform the handshake until the first read/write - # forcing the handshake to complete tends to result in the connection - # closing so we can't reliably access protocol/client cert for the env - return sock, self._environ.copy() + """Wrap client socket with SSL and return environ entries.""" + tls_socket = self._wrap_with_pyopenssl(sock) + ssl_environ = self.get_environ(tls_socket) + return tls_socket, ssl_environ + + def _wrap_with_pyopenssl(self, raw_socket, server_side=True): + """Create a TLSSocket wrapping a PyOpenSSL connection.""" + pyopenssl_ssl_object = self._create_pyopenssl_connection(raw_socket) + self._configure_connection_state(pyopenssl_ssl_object, server_side) + self._perform_handshake(pyopenssl_ssl_object, raw_socket) + + # lgtm[py/insecure-protocol] + return TLSSocket( + ssl_socket=pyopenssl_ssl_object, + raw_socket=raw_socket, + context=self.context, + ) + + def _create_pyopenssl_connection(self, raw_socket): + """Create PyOpenSSL connection object.""" + try: + ssl_object = ssl_conn_type(self.context, raw_socket) + except SSL.Error as e: + raise errors.FatalSSLAlert( + f'Error creating pyOpenSSL connection: {e}', + ) from e + + ssl_object.setblocking(True) + return ssl_object + + def _configure_connection_state(self, ssl_object, server_side): + """Set connection to server or client mode.""" + if server_side: + ssl_object.set_accept_state() + else: + ssl_object.set_connect_state() + + def _perform_handshake(self, ssl_object, raw_socket): + """Perform SSL handshake with error handling.""" + while True: + try: + ssl_object.do_handshake() + return + except SSL.WantReadError: + self._wait_for_handshake_data(raw_socket) + except SSL.ZeroReturnError as e: + raise errors.NoSSLError( + 'Peer closed connection during handshake.', + ) from e + except SSL.Error as e: + self._handle_ssl_error(e) + except OSError as e: + raise errors.FatalSSLAlert( + f'TCP error during handshake: {e}', + ) from e + + def _wait_for_handshake_data(self, raw_socket): + """Wait for peer to send data during handshake.""" + HANDSHAKE_TIMEOUT = 5.0 + fileno = raw_socket.fileno() + ready_to_read, _, _ = select.select( + [fileno], + [], + [], + HANDSHAKE_TIMEOUT, + ) + + if not ready_to_read: + raise TimeoutError( + 'Handshake failed: Peer did not send expected data.', + ) + + def _handle_ssl_error(self, error): + """Handle SSL errors during handshake.""" + err_str = str(error) + if 'http request' in err_str: + raise errors.NoSSLError( + 'Client sent plain HTTP request', + ) from error + raise errors.FatalSSLAlert( + f'Fatal SSL error during handshake: {err_str}', + ) from error + + @property + def context(self): + """Get the SSL context.""" + if self._context is None: + # Method Two: auto-create from certificate/private_key + self._context = self.get_context() + return self._context + + @context.setter + def context(self, value): + """Set the SSL context (Method One).""" + self._context = value + + def get_context(self): + """Return an SSL.Context from self attributes. + + Uses TLS_SERVER_METHOD which supports TLS 1.0-1.3, but immediately + disables insecure protocols (SSLv2, SSLv3, TLSv1.0, TLSv1.1) via + set_options(), ensuring only TLS 1.2+ is accepted. + """ + c = SSL.Context(SSL.TLS_SERVER_METHOD) # nosec B502 + + # Disable all insecure protocols (SSLv2, SSLv3, TLSv1.0, TLSv1.1) + c.set_options( + SSL.OP_NO_SSLv2 + | SSL.OP_NO_SSLv3 + | SSL.OP_NO_TLSv1 + | SSL.OP_NO_TLSv1_1, + ) + + c.set_passwd_cb(self._password_callback, self.private_key_password) + c.use_privatekey_file(self.private_key) + if self.certificate_chain: + c.load_verify_locations(self.certificate_chain) + c.use_certificate_file(self.certificate) + return c def _password_callback( self, @@ -343,7 +269,7 @@ def _password_callback( /, ): """Pass a passphrase to password protected private key.""" - b_password = b'' # returning a falsy value communicates an error + b_password = b'' if isinstance(password, str): b_password = password.encode('utf-8') elif isinstance(password, bytes): @@ -357,95 +283,81 @@ def _password_callback( UserWarning, stacklevel=1, ) - return b_password - def get_context(self): - """Return an ``SSL.Context`` from self attributes. - - Ref: :py:class:`SSL.Context ` - """ - # See https://code.activestate.com/recipes/442473/ - c = SSL.Context(SSL.SSLv23_METHOD) - c.set_passwd_cb(self._password_callback, self.private_key_password) - c.use_privatekey_file(self.private_key) - if self.certificate_chain: - c.load_verify_locations(self.certificate_chain) - c.use_certificate_file(self.certificate) - return c + # ======================================================================== + # Adapter-specific environment variable methods + # ======================================================================== - def get_environ(self): - """Return WSGI environ entries to be merged into each request.""" - ssl_environ = { - 'wsgi.url_scheme': 'https', - 'HTTPS': 'on', + def _get_library_version_environ(self): + """Get SSL library version information for pyOpenSSL.""" + return { 'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % ( - cheroot_server.HTTPServer.version, + 'Cheroot', OpenSSL.version.__title__, OpenSSL.version.__version__, sys.version, ), - 'SSL_VERSION_LIBRARY': SSL.SSLeay_version( - SSL.SSLEAY_VERSION, - ).decode(), + 'SSL_VERSION_LIBRARY': SSL.OpenSSL_version( + SSL.OPENSSL_VERSION, + ).decode('ascii'), } - if self.certificate: - # Server certificate attributes + def _get_optional_environ(self, conn): + """Get optional environment variables for pyOpenSSL.""" + # pyOpenSSL doesn't easily expose SNI or compression info + # Could be extended in the future + return {} + + def _get_server_cert_environ(self): + """Get server certificate info using pyOpenSSL certificate parsing.""" + if not self.certificate or crypto is None: + return {} + + try: with open(self.certificate, 'rb') as cert_file: - cert = crypto.load_certificate( - crypto.FILETYPE_PEM, - cert_file.read(), + cert_data = cert_file.read() + + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_data) + return parse_pyopenssl_cert_to_environ(cert, 'SSL_SERVER') + + except Exception: + # If certificate parsing fails, return empty dict + return {} + + def _get_client_cert_environ(self, conn, ssl_environ): + """Add client certificate details using pyOpenSSL.""" + with suppress(Exception): + # Get the peer certificate from the pyOpenSSL connection + # conn is a TLSSocket, so we need to access the + # underlying SSL socket + ssl_socket = conn._ssl_socket + + # Check if peer verification was enabled + if ssl_socket.get_context().get_verify_mode() == SSL.VERIFY_NONE: + return ssl_environ + + client_cert = ssl_socket.get_peer_certificate() + + if client_cert: + ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' + ssl_environ.update( + parse_pyopenssl_cert_to_environ( + client_cert, + 'SSL_CLIENT', + ), ) - ssl_environ.update( - { - 'SSL_SERVER_M_VERSION': cert.get_version(), - 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), - # 'SSL_SERVER_V_START': - # Validity of server's certificate (start time), - # 'SSL_SERVER_V_END': - # Validity of server's certificate (end time), - }, - ) - - for prefix, dn in [ - ('I', cert.get_issuer()), - ('S', cert.get_subject()), - ]: - # X509Name objects don't seem to have a way to get the - # complete DN string. Use str() and slice it instead, - # because str(dn) == "" - dnstr = str(dn)[18:-2] - - wsgikey = 'SSL_SERVER_%s_DN' % prefix - ssl_environ[wsgikey] = dnstr - - # The DN should be of the form: /k1=v1/k2=v2, but we must allow - # for any value to contain slashes itself (in a URL). - while dnstr: - pos = dnstr.rfind('=') - dnstr, value = dnstr[:pos], dnstr[pos + 1 :] - pos = dnstr.rfind('/') - dnstr, key = dnstr[:pos], dnstr[pos + 1 :] - if key and value: - wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) - ssl_environ[wsgikey] = value + # Get PEM representation of certificate + pem_cert = ( + crypto.dump_certificate( + crypto.FILETYPE_PEM, + client_cert, + ) + .decode('ascii') + .strip() + ) + ssl_environ['SSL_CLIENT_CERT'] = pem_cert return ssl_environ - - def makefile(self, sock, mode='r', bufsize=-1): - """Return socket file object.""" - cls = ( - SSLFileobjectStreamReader - if 'r' in mode - else SSLFileobjectStreamWriter - ) - if SSL and isinstance(sock, ssl_conn_type): - wrapped_socket = cls(sock, mode, bufsize) - wrapped_socket.ssl_timeout = sock.gettimeout() - return wrapped_socket - # This is from past: - # TODO: figure out what it's meant for - return cheroot_server.CP_fileobject(sock, mode, bufsize) diff --git a/cheroot/ssl/pyopenssl.pyi b/cheroot/ssl/pyopenssl.pyi index 59dae05ae7..f1472bdc08 100644 --- a/cheroot/ssl/pyopenssl.pyi +++ b/cheroot/ssl/pyopenssl.pyi @@ -7,22 +7,8 @@ from . import Adapter ssl_conn_type: Type[SSL.Connection] -class SSLFileobjectMixin: - ssl_timeout: int - ssl_retry: float - def recv(self, size): ... - def readline(self, size: int = ...): ... - def sendall(self, *args, **kwargs): ... - def send(self, *args, **kwargs): ... - -class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): ... # type:ignore[misc] -class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): ... # type:ignore[misc] - -class SSLConnectionProxyMeta: - def __new__(mcl, name, bases, nmspc): ... - -class SSLConnection: - def __init__(self, *args) -> None: ... +class SSLFileobjectStreamReader(StreamReader): ... # type:ignore[misc] +class SSLFileobjectStreamWriter(StreamWriter): ... # type:ignore[misc] class pyOpenSSLAdapter(Adapter): def __init__( @@ -43,6 +29,5 @@ class pyOpenSSLAdapter(Adapter): password: bytes | str | None, /, ) -> bytes: ... - def get_environ(self): ... - def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... + def get_environ(self, conn) -> dict: ... def get_context(self) -> SSL.Context: ... diff --git a/cheroot/ssl/tls_socket.py b/cheroot/ssl/tls_socket.py new file mode 100644 index 0000000000..c0051ed4aa --- /dev/null +++ b/cheroot/ssl/tls_socket.py @@ -0,0 +1,492 @@ +""" +A unified SSL/TLS socket layer for Cheroot. + +This module provides a TLSSocket class that abstracts over +different SSL/TLS implementations, such as Python's built-in ssl module +and pyOpenSSL. It offers a consistent interface for the rest of +the Cheroot server code. +""" + +import errno +import io +import os +import socket +import ssl +import time +from contextlib import suppress + +from .. import errors + + +try: + from OpenSSL import SSL, crypto +except ImportError: + SSL = None # type: ignore[assignment] + crypto = None # type: ignore[assignment] + ssl_conn_type = None # type: ignore[misc] +else: + # If the import succeeded, proceed with secondary checks + # Use a separate try/except block for the connection type logic + try: # noqa: WPS505 + ssl_conn_type = SSL.Connection # type: ignore[misc] + except AttributeError: + # Fallback to older name if 'Connection' is not found + ssl_conn_type = SSL.ConnectionType # type: ignore[attr-defined] + +_OPENSSL_PROTOCOL_MAP = { + 769: 'TLSv1', + 770: 'TLSv1.1', + 771: 'TLSv1.2', + 772: 'TLSv1.2', + 773: 'TLSv1.3', +} + + +class TLSSocket(io.RawIOBase): # noqa: PLR0904 # pylint: disable=too-many-public-methods + """ + Lightweight wrapper around SSL/TLS sockets. + + Provides a uniform interface over both :class:`ssl.SSLSocket` and + :class:`OpenSSL.SSL.Connection` objects, ensuring consistent I/O + stream handling for Cheroot with proper SSL error handling. + """ + + def __init__(self, ssl_socket, raw_socket, context): + """ + Initialize TLS socket wrapper. + + Args: + ssl_socket: SSL/TLS wrapped socket (SSLSocket or pyOpenSSL Conn) + raw_socket: The underlying raw socket + context: The SSL context (SSLContext or pyOpenSSL Context) + """ + self._ssl_socket = ssl_socket + self._sock = raw_socket + self.context = context + self.ssl_retry = 0.01 # Retry interval for SSL errors + self.ssl_retry_max = 0.1 # Maximum time to retry before timeout + self._is_closed = False + + super().__init__() # Initialize RawIOBase + + # ================================================================ + # Properties - delegate to underlying socket or return stored values + # ================================================================ + + @property + def family(self): + """Get socket family.""" + return getattr(self._sock, 'family', socket.AF_INET) + + @property + def type(self): + """Get socket type.""" + return getattr(self._sock, 'type', socket.SOCK_STREAM) + + @property + def proto(self): + """Get socket protocol.""" + return getattr(self._sock, 'proto', 0) + + @property + def _closed(self): + """Check if the connection is closed.""" + if self._sock is None: + return True + + try: + fd = self._sock.fileno() + os.fstat(fd) + return False + except (OSError, AttributeError) as sockerr: + if isinstance(sockerr, OSError) and sockerr.errno == errno.EBADF: + return True + return True + + @property + def closed(self): + """Public closed property.""" + return self._is_closed + + @property + def scheme(self): + """Signal to Cheroot that this is an HTTPS connection.""" + return 'https' + + # ================================================================ + # SSL Error Handling + # ================================================================ + + def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 + r""" + Wrap the given call with TLS error-trapping. + + This handles transient SSL errors like WantReadError and WantWriteError + by retrying with a small sleep interval. + """ + if not SSL: + # If pyOpenSSL not available, just call directly + return call(*args, **kwargs) + + start = time.time() + + while True: + try: + return call(*args, **kwargs) + + except SSL.WantReadError: + # SSL needs more data to complete operation + time.sleep(self.ssl_retry) + if time.time() - start > self.ssl_retry_max: + raise socket.timeout('SSL WantReadError retry timeout') + + except SSL.WantWriteError: + # SSL needs to write data before continuing + time.sleep(self.ssl_retry) + if time.time() - start > self.ssl_retry_max: + raise socket.timeout('SSL WantWriteError retry timeout') + + except SSL.SysCallError as sys_err: + # System call error - check if it's ignorable + if is_reader and sys_err.args == (-1, 'Unexpected EOF'): + return b'' + + errnum = sys_err.args[0] if sys_err.args else -1 + if is_reader and errnum in errors.socket_errors_to_ignore: + return b'' + + raise socket.error(errnum) + + except SSL.Error as ssl_err: + # General SSL error - check for specific known errors + if not ssl_err.args: + raise + + error_list = ssl_err.args[0] + if not isinstance(error_list, list): + error_list = [error_list] + + for error_tuple in error_list: + if ( + not isinstance(error_tuple, tuple) + or len(error_tuple) < 3 + ): + continue + + error_message = error_tuple[2].lower() + + # HTTP request on HTTPS port + if 'http request' in error_message: + raise errors.NoSSLError + + # Fatal SSL alert + if 'alert' in error_message: + raise errors.FatalSSLAlert(str(ssl_err)) + + # Unknown SSL error + raise + + # ================================================================ + # Socket I/O methods with error handling + # ================================================================ + + def readable(self): + """Return True - this I/O object supports reading.""" + return True + + def writable(self): + """Return True - this I/O object supports writing.""" + return True + + def seekable(self): + """Return False - sockets are not seekable.""" + return False + + def recv(self, size): + """Receive data from the connection with SSL error handling.""" + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + return self._safe_call(True, self._ssl_socket.recv, size) + # For ssl.SSLSocket, just call recv directly + # (it handles its own errors) + return self._ssl_socket.recv(size) + + def send(self, data, flags=0): + """Send data with SSL error handling.""" + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + return self._safe_call(False, self._ssl_socket.send, data, flags) + return self._ssl_socket.send(data, flags) + + def sendall(self, data, flags=0): + """Send all data with SSL error handling.""" + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + return self._safe_call( + False, + self._ssl_socket.sendall, + data, + flags, + ) + return self._ssl_socket.sendall(data, flags) + + def readinto(self, buff): + """ + Read data into a buffer - called by :class:`io.BufferedReader`. + + This is the key method that ``BufferedReader`` calls when reading. + By implementing this with error handling, we ensure SSL errors + are properly handled in the buffered I/O path. + + Args: + buff: Buffer to read data into (bytearray or memoryview) + + Returns: + Number of bytes read, or None for EOF + """ + data = self.recv(len(buff)) + if not data: + return 0 # EOF + num_bytes = len(data) + view = memoryview(buff) + view[:num_bytes] = data + return num_bytes + + def read(self, size): + """Read data from the connection. Used by StreamReader.""" + return self.recv(size) + + def write(self, data): + """Write data to the connection with SSL error handling.""" + return self.send(data) + + # ================================================================ + # Unified SSL/TLS methods - handle both backends + # ================================================================ + + def get_cipher_info(self): + """ + Get the current cipher information in a unified format. + + Returns: + tuple: (cipher_name, protocol_version, secret_bits) or None + """ + ssl_socket = self._ssl_socket + + if isinstance(ssl_socket, ssl.SSLSocket): + # Returns tuple: (cipher_name, protocol_version, secret_bits) + return ssl_socket.cipher() + + if SSL and isinstance(ssl_socket, ssl_conn_type): + try: + protocol_constant = ssl_socket.get_protocol_version() + protocol = _OPENSSL_PROTOCOL_MAP.get( + protocol_constant, + 'UNKNOWN', + ) + cipher_name = ssl_socket.get_cipher_name() + active_bits = ssl_socket.get_cipher_bits() + + return (cipher_name, protocol, active_bits) + + except SSL.Error: + return None + except Exception as err: + raise errors.FatalSSLAlert( + 'Error retrieving cipher info from pyOpenSSL connection: ' + + str(err), + ) from err + + return None + + def getpeercert(self, binary_form=False): + """ + Get the peer's certificate. + + Args: + binary_form: If True, return DER-encoded bytes; + else return dict/object + + Returns: + Certificate in requested format, or None/empty dict if unavailable + """ + if not hasattr(self, '_ssl_socket') or self._ssl_socket is None: + return None if binary_form else {} + + # Handle PyOpenSSL Connection + if SSL and isinstance(self._ssl_socket, ssl_conn_type): + try: + cert = self._ssl_socket.get_peer_certificate() + if cert is None: + return None if binary_form else {} + + if binary_form: + return crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) + return cert + except Exception: + return None if binary_form else {} + + # Handle builtin ssl.SSLSocket + return self._ssl_socket.getpeercert(binary_form) + + def get_verify_mode(self): + """ + Get the certificate verification mode. + + Returns: + int: ssl.CERT_NONE, ssl.CERT_OPTIONAL, or ssl.CERT_REQUIRED + """ + ssl_socket = self._ssl_socket + + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.context.verify_mode + + if SSL and isinstance(ssl_socket, ssl_conn_type): + verify_mode = self.context.get_verify_mode() + + # Map PyOpenSSL constants to ssl module constants + if verify_mode == SSL.VERIFY_NONE: + return ssl.CERT_NONE + if verify_mode & SSL.VERIFY_PEER: + # Check if cert is actually required or just optional + if verify_mode & SSL.VERIFY_FAIL_IF_NO_PEER_CERT: + return ssl.CERT_REQUIRED # Cert must be provided + return ssl.CERT_OPTIONAL # Cert requested but not required + + return ssl.CERT_NONE + + # ================================================================ + # Socket control methods - explicit delegation + # ================================================================ + + def fileno(self): + """Return the file descriptor of the underlying socket.""" + if self._is_closed: + raise ValueError('Socket is closed') + if self._sock is None: + raise ValueError('Socket is not initialized') + + fd = self._sock.fileno() + if fd == -1: + raise OSError('Socket has been closed') + + return fd + + def getpeername(self): + """Return the address of the remote peer.""" + return self._sock.getpeername() + + def getsockname(self): + """Return the address of the local machine.""" + return self._sock.getsockname() + + def gettimeout(self): + """Get the timeout value.""" + return self._sock.gettimeout() + + def settimeout(self, timeout): + """Set timeout on the connection.""" + return self._ssl_socket.settimeout(timeout) + + def setblocking(self, flag): + """Set blocking mode.""" + return self._sock.setblocking(flag) + + def getsockopt(self, level, optname, buflen=None): + """Get socket option.""" + if buflen is not None: + return self._sock.getsockopt(level, optname, buflen) + return self._sock.getsockopt(level, optname) + + def makefile(self, *args, **kwargs): + """Create a file-like object from the connection.""" + return self._ssl_socket.makefile(*args, **kwargs) + + def shutdown(self, how): + """Perform a clean SSL shutdown.""" + if isinstance(self._ssl_socket, ssl.SSLSocket): + with suppress(Exception): + self._ssl_socket.unwrap() + + with suppress(Exception): + return self._sock.shutdown(how) + + def sock_shutdown(self, how): + """Shutdown the raw socket (TCP level), bypassing SSL shutdown.""" + # Windows error code for "not a socket" + WSAENOTSOCK = 10038 + try: + # Attempt to shutdown the underlying kernel socket + return self._sock.shutdown(how) + except OSError as err: + # errno 9 is EBADF (Bad file descriptor) + if err.errno in {errno.EBADF, WSAENOTSOCK}: + # The underlying socket was already closed, + # which is fine during cleanup. + # Silently ignore the error. + return None + # If it's another OSError, re-raise it. + raise + + def close(self): + """Close the connection.""" + if self._is_closed: + return + + self._is_closed = True + with suppress(Exception): + self._ssl_socket.close() + + # ================================================================ + # Additional SSL methods that might be called + # ================================================================ + + def compression(self): + """Get compression method (usually None for modern TLS).""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.compression() + # pyOpenSSL doesn't support this easily, return None + return None + + @property + def sni(self): + """Get SNI hostname if available.""" + ssl_socket = self._ssl_socket + + # SSLSocket doesn't expose SNI directly, return None + if isinstance(ssl_socket, ssl.SSLSocket): + return None + + # pyOpenSSL might have it via get_servername() + if SSL and isinstance(ssl_socket, ssl_conn_type): + with suppress(Exception): + return ssl_socket.get_servername() + return None + + def version(self): + """Get TLS version.""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.version() + if SSL and isinstance(ssl_socket, ssl_conn_type): + # Return string version + protocol_constant = ssl_socket.get_protocol_version() + return _OPENSSL_PROTOCOL_MAP.get(protocol_constant, 'UNKNOWN') + return None + + def get_session(self): + """Get SSL session for reuse (method form).""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.session # Property on SSLSocket + if SSL and isinstance(ssl_socket, ssl_conn_type): + return ssl_socket.get_session() # Method on pyOpenSSL + return None + + @property + def session(self): + """Get SSL session for reuse.""" + ssl_socket = self._ssl_socket + if isinstance(ssl_socket, ssl.SSLSocket): + return ssl_socket.session + if SSL and isinstance(ssl_socket, ssl_conn_type): + return ssl_socket.get_session() + return None diff --git a/cheroot/test/conftest.py b/cheroot/test/conftest.py index 657997c3f2..8b534e7806 100644 --- a/cheroot/test/conftest.py +++ b/cheroot/test/conftest.py @@ -9,9 +9,20 @@ import pytest -from .._compat import IS_MACOS, IS_WINDOWS +import trustme +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import ( + BestAvailableEncryption, + Encoding, + PrivateFormat, + load_pem_private_key, +) + +from .._compat import IS_MACOS, IS_WINDOWS, ntou from ..server import Gateway, HTTPServer from ..testing import ( # noqa: F401 # pylint: disable=unused-import + ANY_INTERFACE_IPV4, + _get_conn_data, get_server_client, native_server, thread_and_native_server, @@ -107,3 +118,90 @@ def make_http_server(bind_addr): time.sleep(0.1) return httpserver + + +@pytest.fixture +def ca(): + """Provide a certificate authority via fixture.""" + return trustme.CA() + + +@pytest.fixture +def tls_ca_certificate_pem_path(ca): + """Provide a certificate authority certificate file via fixture.""" + with ca.cert_pem.tempfile() as ca_cert_pem: + yield ca_cert_pem + + +@pytest.fixture +def tls_certificate(ca): + """ + Generate a TLS server certificate for testing. + + Creates a certificate valid for 'test-server.local', 'localhost', + and '127.0.0.1' with ``CN=localhost``. + """ + interface, _host, _port = _get_conn_data(ANY_INTERFACE_IPV4) + identities = [ + 'test-server.local', + 'localhost', # This will be used for CN and SAN + ntou(interface), # This is '127.0.0.1' for SAN + ] + return ca.issue_server_cert(*identities, common_name='localhost') + + +@pytest.fixture +def tls_certificate_pem_path(tls_certificate): + """ + Return path to temp file containing the server certificate in PEM format. + + The file is automatically cleaned up after the test completes. + """ + # The 'cert_pem' property holds the leaf certificate data. + leaf_cert_blob = tls_certificate.cert_chain_pems[0] + + # Write to a file that persists for the test duration + with leaf_cert_blob.tempfile() as cert_pem_path: + yield cert_pem_path + + +@pytest.fixture +def tls_certificate_chain_pem_path(tls_certificate): + """Provide a certificate chain PEM file path via fixture.""" + with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: + yield cert_pem + + +@pytest.fixture +def tls_certificate_private_key_pem_path(tls_certificate): + """Provide a certificate private key PEM file path via fixture.""" + with tls_certificate.private_key_pem.tempfile() as cert_key_pem: + yield cert_key_pem + + +@pytest.fixture +def tls_certificate_passwd_private_key_pem_path( + tls_certificate, + private_key_password, + tmp_path, +): + """Return a certificate private key PEM file path.""" + key_as_bytes = tls_certificate.private_key_pem.bytes() + private_key_object = load_pem_private_key( + key_as_bytes, + password=None, + backend=default_backend(), + ) + + encrypted_key_as_bytes = private_key_object.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.PKCS8, + encryption_algorithm=BestAvailableEncryption( + password=private_key_password.encode('utf-8'), + ), + ) + + key_file = tmp_path / 'encrypted-private-key.pem' + key_file.write_bytes(encrypted_key_as_bytes) + + return key_file diff --git a/cheroot/test/ssl/test_ssl_builtin.py b/cheroot/test/ssl/test_ssl_builtin.py new file mode 100644 index 0000000000..80c22dd66b --- /dev/null +++ b/cheroot/test/ssl/test_ssl_builtin.py @@ -0,0 +1,443 @@ +"""Tests for ``cheroot.ssl.builtin``.""" + +import socket +import ssl +import threading +import time +from contextlib import closing + +import pytest + +from cheroot import errors +from cheroot.makefile import StreamReader, StreamWriter +from cheroot.ssl.builtin import BuiltinSSLAdapter +from cheroot.ssl.tls_socket import TLSSocket + + +_CONNECTION_TIMEOUT_SECONDS = 5.0 +_SOCKET_BUFFER_SIZE = 4096 + + +@pytest.mark.usefixtures('mocker') +def test_full_builtin_environ_population( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that builtin adapter populates all SSL environ variables correctly.""" + captured = {'environ': {}} + + def capture_wsgi_app(environ, start_response): + captured['environ'].update(environ) + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'Hello, secure world!'] + + bind_host = '127.0.0.1' + port = 0 + + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + + from cheroot.wsgi import Server as WSGIServer + + server = WSGIServer( + bind_addr=(bind_host, port), + wsgi_app=capture_wsgi_app, + ) + server.ssl_adapter = adapter + + server.prepare() + actual_port = server.bind_addr[1] + + server_thread = threading.Thread(target=server.serve, daemon=True) + server_thread.start() + + time.sleep(1) + + try: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + sock = socket.create_connection( + (bind_host, actual_port), + timeout=_CONNECTION_TIMEOUT_SECONDS, + ) + client_sock = context.wrap_socket(sock, server_hostname=bind_host) + + request = ( + b'GET /test/path?q=1 HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: close\r\n\r\n' + ) + client_sock.sendall(request) + + client_sock.settimeout(0.5) + response = b'' + try: + while True: + chunk = client_sock.recv(_SOCKET_BUFFER_SIZE) + if not chunk: + break + response += chunk + except socket.timeout: + # Possible timeout when the server closes + # the connection after sending the response. + pass + + client_sock.close() + + finally: + time.sleep(0.5) + server.stop() + server_thread.join(timeout=2) + + captured_environ = captured['environ'] + + # HTTP Request variables + assert captured_environ.get('REQUEST_METHOD') == 'GET' + assert captured_environ.get('PATH_INFO') == '/test/path' + assert captured_environ.get('QUERY_STRING') == 'q=1' + assert captured_environ.get('SERVER_PROTOCOL') == 'HTTP/1.1' + + # SSL variables + assert 'SSL_PROTOCOL' in captured_environ + assert 'TLS' in captured_environ['SSL_PROTOCOL'] + + assert 'SSL_CIPHER' in captured_environ + + assert 'SSL_VERSION_LIBRARY' in captured_environ + assert 'OpenSSL' in captured_environ['SSL_VERSION_LIBRARY'] + + assert 'SSL_VERSION_INTERFACE' in captured_environ + assert 'Python' in captured_environ['SSL_VERSION_INTERFACE'] + + assert 'SSL_SERVER_M_SERIAL' in captured_environ + + assert 'SSL_SERVER_S_DN_CN' in captured_environ + assert captured_environ['SSL_SERVER_S_DN_CN'] == 'localhost' + + +@pytest.mark.usefixtures('mocker') +def test_wrap_with_builtin_ssl_wrap_fails( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that SSL wrap_socket error raises FatalSSLAlert.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + try: + + def failing_wrap_socket(sock, *args, **kwargs): + raise ssl.SSLError('SSL wrap failed') + + mocker.patch.object( + adapter.context, + 'wrap_socket', + side_effect=failing_wrap_socket, + ) + + with pytest.raises( + errors.FatalSSLAlert, + match='Error creating SSL socket', + ): + adapter._wrap_with_builtin(server_sock) + + finally: + server_sock.close() + client_sock.close() + + +@pytest.mark.parametrize( + ('error_class', 'error_msg', 'expected_exception', 'match_text'), + ( + ( + ssl.SSLWantReadError, + 'The operation did not complete', + None, # Should retry, not raise + None, + ), + ( + ssl.SSLWantWriteError, + 'The operation did not complete', + None, # Should retry, not raise + None, + ), + ( + OSError, + 'Connection reset by peer', + errors.FatalSSLAlert, + 'TCP error during handshake', + ), + ( + ssl.SSLError, + 'wrong version number', + errors.NoSSLError, + 'Client sent plain HTTP request', + ), + ( + ssl.SSLError, + 'certificate verify failed', + errors.FatalSSLAlert, + 'Fatal SSL error during handshake', + ), + ( + ssl.SSLEOFError, + 'EOF occurred', + errors.NoSSLError, + 'Peer closed connection during handshake', + ), + ), + ids=[ + 'SSLWantReadError_retry', + 'SSLWantWriteError_retry', + 'OSError_connection_reset', + 'SSLError_wrong_version', + 'SSLError_cert_failed', + 'SSLEOFError', + ], +) +def test_builtin_handshake_error_handling( # pylint: disable=too-many-positional-arguments + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, + error_class, + error_msg, + expected_exception, + match_text, +): + """Test various error conditions during builtin SSL handshake.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + try: + handshake_attempts = {'count': 0} + + def failing_do_handshake(): + handshake_attempts['count'] += 1 + if handshake_attempts['count'] == 1: + raise error_class(error_msg) + # Second attempt succeeds (for retry cases) + + # Mock wrap_socket to return a mock SSL socket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.do_handshake = failing_do_handshake + mock_ssl_socket.context = adapter.context + + mocker.patch.object( + adapter.context, + 'wrap_socket', + return_value=mock_ssl_socket, + ) + + # Mock select for WantRead/WantWrite cases + if error_class in {ssl.SSLWantReadError, ssl.SSLWantWriteError}: + mocker.patch( + 'select.select', + return_value=( + [server_sock.fileno()], + [server_sock.fileno()], + [], + ), + ) + + if expected_exception: + with pytest.raises(expected_exception, match=match_text): + adapter._wrap_with_builtin(server_sock) + else: + # Should retry and succeed + tls_sock = adapter._wrap_with_builtin(server_sock) + assert isinstance(tls_sock, TLSSocket) + assert handshake_attempts['count'] >= 2, ( + f'Expected retry but only {handshake_attempts["count"]} attempts' + ) + + finally: + server_sock.close() + client_sock.close() + + +def test_builtin_handshake_timeout( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that handshake times out on repeated WantRead errors.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + with closing(server_sock), closing(client_sock): + # Mock wrap_socket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_ssl_socket.do_handshake.side_effect = ssl.SSLWantReadError() + + mocker.patch.object( + adapter.context, + 'wrap_socket', + return_value=mock_ssl_socket, + ) + + # Mock select to return empty (timeout) + mocker.patch('select.select', return_value=([], [], [])) + + with pytest.raises(TimeoutError, match='Handshake failed'): + adapter._wrap_with_builtin(server_sock) + + +def test_builtin_create_ssl_socket_error( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that wrap_socket errors are caught and converted to FatalSSLAlert.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + # Create the context first so it has valid certs + adapter.context = adapter._create_context() + + server_sock, client_sock = socket.socketpair() + + with closing(server_sock), closing(client_sock): + # Now patch wrap_socket AFTER context is created + mocker.patch.object( + adapter.context, + 'wrap_socket', + side_effect=ssl.SSLError('Failed to create SSL socket'), + ) + + with pytest.raises( + errors.FatalSSLAlert, + match='Error creating SSL socket', + ): + adapter._wrap_with_builtin(server_sock) + + +def test_builtin_adapter_get_server_cert_environ_no_cert( + tls_certificate_private_key_pem_path, +): + """Test that _get_server_cert_environ returns empty dict when no certificate.""" + adapter = BuiltinSSLAdapter( + None, # No certificate + tls_certificate_private_key_pem_path, + ) + + environ = adapter._get_server_cert_environ() + assert len(environ) == 0 + + +def test_builtin_adapter_get_server_cert_environ_invalid_file( + tls_certificate_private_key_pem_path, +): + """Test that _get_server_cert_environ handles invalid cert file gracefully.""" + with pytest.raises( + FileNotFoundError, + match='SSL certificate file not found', + ): + BuiltinSSLAdapter( + '/nonexistent/cert.pem', + tls_certificate_private_key_pem_path, + ) + + +def test_builtin_adapter_client_cert_no_verification( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that client cert environ is not populated when verification is disabled.""" + adapter = BuiltinSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + + # Mock TLSSocket with no client cert verification + mock_conn = mocker.MagicMock() + + # IMPORTANT: Mock _sock, not _ssl_socket! + mock_conn._sock.getpeercert = mocker.MagicMock( + side_effect=lambda binary_form=False: None, + ) + mock_conn.context.verify_mode = ssl.CERT_NONE + + ssl_environ = {'HTTPS': 'on'} + environ = adapter._get_client_cert_environ(mock_conn, ssl_environ) + + assert environ['SSL_CLIENT_VERIFY'] == 'NONE' + assert 'SSL_CLIENT_CERT' not in environ + + +def test_streamreader_with_tls_and_regular_sockets(mocker): + """Test StreamReader works with both TLSSocket and regular sockets.""" + # Test with TLSSocket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_tls_socket = TLSSocket( + ssl_socket=mock_ssl_socket, + raw_socket=mocker.MagicMock(), + context=mocker.MagicMock(), + ) + + buffered_reader = StreamReader(mock_tls_socket) + assert isinstance(buffered_reader, StreamReader) + assert buffered_reader.bytes_read == 0 + assert hasattr(buffered_reader, 'read') + + # Test with regular socket + regular_sock = mocker.MagicMock() + mock_socket_io = mocker.patch('socket.SocketIO') + mock_socket_io.return_value = mocker.MagicMock() + + buffered_reader = StreamReader(regular_sock) + + # Verify SocketIO was called to wrap regular socket + mock_socket_io.assert_called_once_with(regular_sock, 'rb') + assert isinstance(buffered_reader, StreamReader) + assert buffered_reader.bytes_read == 0 + + +def test_streamwriter_with_tls_and_regular_sockets(mocker): + """Test StreamWriter works with both TLSSocket and regular sockets.""" + # Test with TLSSocket + mock_ssl_socket = mocker.MagicMock(spec=ssl.SSLSocket) + mock_tls_socket = TLSSocket( + ssl_socket=mock_ssl_socket, + raw_socket=mocker.MagicMock(), + context=mocker.MagicMock(), + ) + + buffered_writer = StreamWriter(mock_tls_socket) + assert isinstance(buffered_writer, StreamWriter) + assert buffered_writer.bytes_written == 0 + assert hasattr(buffered_writer, 'write') + + # Test with regular socket + regular_sock = mocker.MagicMock() + mock_socket_io = mocker.patch('socket.SocketIO') + mock_socket_io.return_value = mocker.MagicMock() + + buffered_writer = StreamWriter(regular_sock) + + # Verify SocketIO was called to wrap regular socket + mock_socket_io.assert_called_once_with(regular_sock, 'wb') + assert isinstance(buffered_writer, StreamWriter) + assert buffered_writer.bytes_written == 0 diff --git a/cheroot/test/ssl/test_ssl_pyopenssl.py b/cheroot/test/ssl/test_ssl_pyopenssl.py new file mode 100644 index 0000000000..f7fbfb0d03 --- /dev/null +++ b/cheroot/test/ssl/test_ssl_pyopenssl.py @@ -0,0 +1,703 @@ +"""Tests for ``cheroot.ssl.pyopenssl``.""" +# Assuming OpenSSL is imported as 'SSL' in your test module + +import errno +import io +import socket +import ssl +import threading +import time +from contextlib import suppress + +import pytest + +from OpenSSL import SSL + +from cheroot import errors +from cheroot.makefile import StreamReader, StreamWriter +from cheroot.ssl.pyopenssl import pyOpenSSLAdapter +from cheroot.ssl.tls_socket import TLSSocket, ssl_conn_type + +# --- The Main Integration Test --- +from cheroot.wsgi import Server as WSGIServer + + +_CONNECTION_TIMEOUT_SECONDS = 5.0 +_SOCKET_BUFFER_SIZE = 4096 + + +@pytest.mark.usefixtures('mocker') +def test_full_pyopenssl_environ_population( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """ + Test pyOpenSSL adapter populates WSGI environ with HTTP and SSL variables. + + Performs an end-to-end test by: + - Starting a WSGI server with pyOpenSSL TLS adapter + - Making an SSL client connection + - Sending an HTTP request over TLS + - Verifying environ contains correct HTTP vars (METHOD, PATH, QUERY, etc.) + - Verifying environ contains SSL vars (PROTOCOL, CIPHER, VERSION, DN, etc.) + + This ensures the pyOpenSSL integration correctly exposes SSL connection + details to WSGI applications through the environ dictionary. + """ + captured = {'environ': {}} + + def capture_wsgi_app(environ, start_response): + captured['environ'].update(environ) + start_response('200 OK', [('Content-Type', 'text/plain')]) + return [b'Hello, secure world!'] + + bind_host = '127.0.0.1' + port = 0 + + adapter = pyOpenSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ciphers='ALL', + ) + + # Use WSGIServer instead of HTTPServer + server = WSGIServer( + bind_addr=(bind_host, port), + wsgi_app=capture_wsgi_app, + ) + server.ssl_adapter = adapter + + # Prepare the server (binds the socket) + server.prepare() + actual_port = server.bind_addr[1] + + # Start the server in a thread + server_thread = threading.Thread(target=server.serve, daemon=True) + server_thread.start() + + time.sleep(1) # Give it time to start + + try: + # Connect using SSL + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + # === CLIENT SIDE === + # 1. Creating socket connection + sock = socket.create_connection( + (bind_host, actual_port), + timeout=_CONNECTION_TIMEOUT_SECONDS, + ) + + # 2. Wrapping with SSL and completing handshake + client_sock = context.wrap_socket(sock, server_hostname=bind_host) + + # 3. Create request + request = ( + b'GET /test/path?q=1 HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: close\r\n\r\n' + ) + + # 4. Send request + client_sock.sendall(request) + + # 5. Read response + client_sock.settimeout(0.5) + response = b'' + try: + while True: + chunk = client_sock.recv(_SOCKET_BUFFER_SIZE) + if not chunk: + break + response += chunk + except socket.timeout: + # Possible timeout when the server closes + # the connection after sending the response. + pass + + client_sock.close() + + finally: + # Give server time to process + time.sleep(0.5) + + server.stop() + server_thread.join(timeout=2) + + captured_environ = captured['environ'] + + # Assertions - HTTP Request variables should be populated + assert captured_environ.get('REQUEST_METHOD') == 'GET', ( + f"Expected REQUEST_METHOD='GET', got '{captured_environ.get('REQUEST_METHOD')}'" + ) + + assert captured_environ.get('PATH_INFO') == '/test/path', ( + f"Expected PATH_INFO='/test/path', got '{captured_environ.get('PATH_INFO')}'" + ) + + assert captured_environ.get('QUERY_STRING') == 'q=1', ( + f"Expected empty QUERY_STRING, got '{captured_environ.get('QUERY_STRING')}'" + ) + + assert captured_environ.get('SERVER_PROTOCOL') == 'HTTP/1.1', ( + f"Expected SERVER_PROTOCOL='HTTP/1.1', got '{captured_environ.get('SERVER_PROTOCOL')}'" + ) + + # SSL variables should be populated + assert 'SSL_PROTOCOL' in captured_environ, 'SSL_PROTOCOL not in environ' + assert 'TLSv1' in captured_environ['SSL_PROTOCOL'] + + assert 'SSL_CIPHER' in captured_environ, 'SSL_CIPHER not in environ' + + assert 'SSL_VERSION_LIBRARY' in captured_environ + assert 'OpenSSL' in captured_environ['SSL_VERSION_LIBRARY'] + + assert 'SSL_VERSION_INTERFACE' in captured_environ + assert 'pyOpenSSL' in captured_environ['SSL_VERSION_INTERFACE'] + + assert 'SSL_SERVER_M_SERIAL' in captured_environ + + assert 'SSL_SERVER_S_DN_CN' in captured_environ, ( + 'SSL_SERVER_S_DN_CN not in environ' + ) + assert captured_environ['SSL_SERVER_S_DN_CN'] == 'localhost', ( + f"Expected CN='localhost', got '{captured_environ['SSL_SERVER_S_DN_CN']}'" + ) + + +# --- Parameterized Test Cases --- +test_cases = [ + ( + 'Success after WantReadError', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.WantReadError, b'OK'], + None, + b'OK', + 2, + ), + ( + 'Success after WantWriteError', + 'send', # Changed from 'safe_send' + 'send', + [SSL.WantWriteError, b'OK'], + None, + b'OK', + 2, + ), + ( + 'Timeout', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.WantReadError] * 4, + socket.timeout, + None, + 2, + ), + ( + 'SysCallError: Unexpected EOF', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.SysCallError(-1, 'Unexpected EOF')], + None, + b'', + 1, + ), + ( + 'SysCallError: Ignorable Socket Error (e.g., Broken Pipe)', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.SysCallError(errno.EPIPE, 'Broken pipe')], + None, + b'', + 1, + ), + ( + 'SysCallError: Non-Ignorable Error (Connection Reset)', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.SysCallError(errno.ENOTCONN, 'Socket is not connected')], + socket.error, + None, + 1, + ), + ( + 'SysCallError: Non-Ignorable Error (Writer)', + 'send', # Changed from 'safe_send' + 'send', + [SSL.SysCallError(999, 'Fatal system error')], + socket.error, + None, + 1, + ), + ( + 'SSL.Error: NoSSLError (HTTP Request)', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.Error([(-1, 'SSL routines', 'http request')])], + errors.NoSSLError, + None, + 1, + ), + ( + 'SSL.Error: Fatal SSL Alert', + 'recv', # Changed from 'safe_recv' + 'recv', + [SSL.Error([(-1, 'SSL routines', 'generic alert')])], + errors.FatalSSLAlert, + None, + 1, + ), +] + + +@pytest.mark.parametrize( + ( + 'test_case_name', + 'call_method', + 'mock_target', + 'side_effects', + 'expected_exception', + 'expected_result', + 'expected_call_count', + ), + test_cases, + ids=[case[0] for case in test_cases], # Extract names at module level +) +def test_safe_call_coverage( # pylint: disable=too-many-positional-arguments + test_case_name, + call_method, + mock_target, + side_effects, + expected_exception, + expected_result, + expected_call_count, + mocker, +): + """Test all critical success, retry, and error-mapping paths in _safe_call.""" + # Create mocks + mock_ssl_conn = mocker.MagicMock() + mock_ssl_conn.__class__ = ssl_conn_type + getattr(mock_ssl_conn, mock_target).side_effect = side_effects + + mock_raw_socket = mocker.MagicMock() + mock_raw_socket.gettimeout.return_value = 1.0 + mock_context = mocker.MagicMock() + + # Configure time mocks for the timeout case + mock_sleep = mocker.patch('time.sleep') + + if expected_exception is socket.timeout: + start_time = 1000.0 + mocker.patch('time.time').side_effect = [ + start_time, # Start time + start_time + 0.05, # First check (0.05 < 0.1) + start_time + 0.15, # Second check (0.15 > 0.1) -> Timeout + ] + else: + mocker.patch('time.time').return_value = 1000.0 + + # Create TLSSocket + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Call the method + call_func = getattr(tls_socket, call_method) + if expected_exception: + # Test Case expects an exception + with pytest.raises(expected_exception): + # Pass dummy args for recv/send + call_func(1024) if call_method == 'recv' else call_func( + b'test data', + ) + + else: + # Test Case expects a successful return + actual_result = ( + call_func(1024) + if call_method == 'recv' + else call_func(b'test data') + ) + assert actual_result == expected_result + + # Final check on call count + assert ( + getattr(mock_ssl_conn, mock_target).call_count == expected_call_count + ) + + # Ensure sleep was called if errors occurred that require retries + if expected_exception is not None and expected_exception not in { + socket.error, + errors.NoSSLError, + errors.FatalSSLAlert, + }: + assert mock_sleep.called + + +def test_tlssocket_is_readable(mocker): + """Test that TLSSocket properly declares itself as readable.""" + mock_ssl_conn = mocker.MagicMock(spec=SSL.Connection) + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + assert isinstance(tls_socket, io.RawIOBase) + + # The real test - if these are False, our methods aren't being called + if tls_socket.readable() is False: + pytest.fail( + 'TLSSocket.readable() is not working - check if changes were applied to the actual file', + ) + + assert tls_socket.readable() is True + assert hasattr(tls_socket, 'readinto') + + +@pytest.mark.parametrize( + ( + 'io_method', + 'error_class', + 'call_target', + 'test_input', + 'expected_output', + ), + ( + ('readinto', SSL.WantReadError, 'recv', 100, b'Hello World'), + ('write', SSL.WantWriteError, 'send', b'Test data', 9), + ), + ids=['readinto_WantReadError', 'write_WantWriteError'], +) +def test_tlssocket_io_handles_want_errors( # pylint: disable=too-many-positional-arguments + mocker, + io_method, + error_class, + call_target, + test_input, + expected_output, +): + """Test that TLSSocket I/O methods handle WantRead/WantWrite errors with retry.""" + # Setup mocks + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + mock_raw_socket.gettimeout.return_value = 1.0 + + # Track call count + call_count = {'count': 0} + + # Configure mock to raise error first, then succeed + def io_operation(*args, **kwargs): + call_count['count'] += 1 + + if call_count['count'] == 1: + raise error_class() + if call_target == 'recv': + return expected_output + # send + return len(args[0]) if args else expected_output + + # Attach mock to the appropriate method + setattr(mock_ssl_conn, call_target, io_operation) + + # Create TLSSocket + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Execute the I/O operation + if io_method == 'readinto': + buffer = bytearray(test_input) + bytes_read = tls_socket.readinto(buffer) + assert bytes(buffer[:bytes_read]) == expected_output + assert bytes_read == len(expected_output) + else: # write + bytes_written = tls_socket.write(test_input) + assert bytes_written == expected_output + + # Verify retry happened + assert call_count['count'] == 2, ( + f'Should have retried after {error_class.__name__}' + ) + + +def test_tlssocket_readinto_handles_syscallerror_eof(mocker): + """Test that TLSSocket.readinto() handles SysCallError with Unexpected EOF.""" + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + # SysCallError with Unexpected EOF should return empty + mock_ssl_conn.recv = mocker.MagicMock( + side_effect=SSL.SysCallError(-1, 'Unexpected EOF'), + ) + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + buffer = bytearray(100) + bytes_read = tls_socket.readinto(buffer) + + assert bytes_read == 0, 'Unexpected EOF should return 0 bytes (EOF)' + + +def test_tlssocket_with_buffered_reader(mocker): + """Test that TLSSocket works correctly with :class:`io.BufferedReader`.""" + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + call_count = {'count': 0} + + def recv_with_retry(size): + call_count['count'] += 1 + if call_count['count'] == 1: + raise SSL.WantReadError + if call_count['count'] == 2: + return b'Data from BufferedReader' + # Return empty to signal EOF for subsequent calls + return b'' + + mock_ssl_conn.recv = recv_with_retry + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Create BufferedReader with TLSSocket as the raw I/O + reader = io.BufferedReader(tls_socket, buffer_size=8192) + + # Read through BufferedReader + read_data = reader.read(100) + + assert read_data == b'Data from BufferedReader' + assert call_count['count'] >= 2, 'Should have retried after WantReadError' + + +def test_tlssocket_timeout_on_repeated_errors(mocker): + """Test that repeated SSL errors eventually timeout.""" + mock_ssl_conn = mocker.MagicMock(spec=ssl_conn_type) + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + # Always raise WantReadError + mock_ssl_conn.recv = mocker.MagicMock(side_effect=SSL.WantReadError()) + mock_raw_socket.gettimeout.return_value = 1.0 + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + tls_socket.ssl_retry_max = 0.05 # Short timeout for testing + + buffer = bytearray(100) + + with pytest.raises(socket.timeout): + tls_socket.readinto(buffer) + + +def test_tlssocket_sock_shutdown(mocker): + """Test that sock_shutdown calls the raw socket's shutdown method.""" + mock_ssl_conn = mocker.MagicMock() + mock_ssl_conn.__class__ = ssl_conn_type + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + tls_socket = TLSSocket(mock_ssl_conn, mock_raw_socket, mock_context) + + # Call sock_shutdown + tls_socket.sock_shutdown(socket.SHUT_RDWR) + + # Verify it called the raw socket's shutdown, not the SSL connection's + mock_raw_socket.shutdown.assert_called_once_with(socket.SHUT_RDWR) + mock_ssl_conn.shutdown.assert_not_called() + + +@pytest.mark.usefixtures('mocker') +def test_wrap_with_pyopenssl_ssl_connection_creation_fails( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, +): + """Test that SSL.Connection creation error raises FatalSSLAlert.""" + adapter = pyOpenSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter.get_context() + + # Use a real socket + server_sock, client_sock = socket.socketpair() + + try: + # Patch where it's imported IN THE MODULE + call_count = {'count': 0} + + def failing_connection(context, sock): + call_count['count'] += 1 + raise SSL.Error([(-1, 'SSL routines', 'initialization failed')]) + + # Patch ssl_conn_type, not SSL.Connection + mocker.patch( + 'cheroot.ssl.pyopenssl.ssl_conn_type', + side_effect=failing_connection, + ) + + with pytest.raises( + errors.FatalSSLAlert, + match='Error creating pyOpenSSL connection', + ): + adapter._wrap_with_pyopenssl(server_sock) + + finally: + server_sock.close() + client_sock.close() + + +@pytest.mark.parametrize( + ('error_class', 'error_args', 'expected_exception', 'match_text'), + ( + ( + SSL.WantReadError, + None, + None, # Should retry, not raise + None, + ), + ( + OSError, + (errno.ECONNRESET, 'Connection reset by peer'), + errors.FatalSSLAlert, + 'TCP error during handshake', + ), + ( + SSL.Error, + [(-1, 'SSL routines', 'http request')], + errors.NoSSLError, + 'Client sent plain HTTP request', + ), + ( + SSL.Error, + [(-1, 'SSL routines', 'certificate verify failed')], + errors.FatalSSLAlert, + 'Fatal SSL error during handshake', + ), + ( + SSL.ZeroReturnError, # NEW + None, + errors.NoSSLError, + 'Peer closed connection during handshake', + ), + ), + ids=[ + 'WantReadError_retry', + 'OSError_ECONNRESET', + 'SSL_http_request', + 'SSL_cert_failed', + 'ZeroReturnError', + ], +) +def test_handshake_error_handling( # pylint: disable=too-many-positional-arguments + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + mocker, + error_class, + error_args, + expected_exception, + match_text, +): + """Test various error conditions during SSL handshake.""" + adapter = pyOpenSSLAdapter( + tls_certificate_pem_path, + tls_certificate_private_key_pem_path, + ) + adapter.context = adapter.get_context() + + server_sock, client_sock = socket.socketpair() + + try: + original_connection = SSL.Connection + handshake_attempts = {'count': 0} + + def patched_connection(context, sock): + conn = original_connection(context, sock) + + def failing_handshake(): + handshake_attempts['count'] += 1 + if handshake_attempts['count'] == 1: + if error_args: + raise ( + error_class(*error_args) + if isinstance(error_args, tuple) + else error_class(error_args) + ) + raise error_class() + # Second attempt for retry cases + raise SSL.Error([(-1, 'SSL routines', 'test completed')]) + + conn.do_handshake = failing_handshake + return conn + + mocker.patch( + 'cheroot.ssl.pyopenssl.ssl_conn_type', + side_effect=patched_connection, + ) + mocker.patch( + 'select.select', + return_value=([server_sock.fileno()], [], []), + ) + + if expected_exception: + with pytest.raises(expected_exception, match=match_text): + adapter._wrap_with_pyopenssl(server_sock) + else: + with suppress(Exception): + adapter._wrap_with_pyopenssl(server_sock) + assert handshake_attempts['count'] >= 2 + + finally: + server_sock.close() + client_sock.close() + + +def test_streamreader_with_tls_socket(mocker): + """Test StreamReader works correctly with TLSSocket.""" + # Setup a TLSSocket instance + mock_ssl_socket = mocker.MagicMock() + mock_raw_socket = mocker.MagicMock() + mock_context = mocker.MagicMock() + + mock_tls_socket = TLSSocket( + ssl_socket=mock_ssl_socket, + raw_socket=mock_raw_socket, + context=mock_context, + ) + + # Create StreamReader with TLSSocket + buffered_reader = StreamReader(mock_tls_socket, bufsize=4096) + + # Assert it's a StreamReader instance + assert isinstance(buffered_reader, StreamReader) + assert buffered_reader.bytes_read == 0 + + # Verify TLSSocket was used directly (not wrapped with SocketIO) + # The _wrapped attribute would be the TLSSocket itself + assert buffered_reader.raw is mock_tls_socket + + +def test_streamwriter_with_regular_socket(mocker): + """Test StreamWriter works correctly with regular socket.""" + regular_sock = mocker.MagicMock() + + # Mock SocketIO to verify it's called for regular sockets + mock_socket_io = mocker.patch('socket.SocketIO') + mock_socket_io.return_value = mocker.MagicMock() + + writer = StreamWriter(regular_sock, bufsize=1024) + + # Verify SocketIO was called to wrap the regular socket + mock_socket_io.assert_called_once_with(regular_sock, 'wb') + + # Verify it's a StreamWriter instance + assert isinstance(writer, StreamWriter) + assert writer.bytes_written == 0 diff --git a/cheroot/test/test_makefile.py b/cheroot/test/test_makefile.py index d65d4ea268..c6fba66405 100644 --- a/cheroot/test/test_makefile.py +++ b/cheroot/test/test_makefile.py @@ -1,45 +1,46 @@ """Tests for :py:mod:`cheroot.makefile`.""" -from cheroot import makefile +import io +from cheroot.makefile import StreamReader, StreamWriter -class MockSocket: - """A mock socket.""" + +class MockSocket(io.RawIOBase): + """A mock socket for testing stream I/O.""" def __init__(self): - """Initialize :py:class:`MockSocket`.""" + """Initialize MockSocket.""" + super().__init__() self.messages = [] - def recv_into(self, buf): - """Simulate ``recv_into`` for Python 3.""" - if not self.messages: - return 0 - msg = self.messages.pop(0) - for index, byte in enumerate(msg): - buf[index] = byte - return len(msg) + def readable(self): + """Return True - supports reading.""" + return True - def recv(self, size): - """Simulate ``recv`` for Python 2.""" - try: - return self.messages.pop(0) - except IndexError: - return '' + def writable(self): + """Return True - supports writing.""" + return True - def send(self, val): - """Simulate a send.""" - return len(val) + def readinto(self, buf): + """Read data into buffer.""" + if not self.messages: + return 0 # EOF - def _decref_socketios(self): - """Emulate socket I/O reference decrement.""" - # Ref: https://github.com/cherrypy/cheroot/issues/734 + msg = self.messages.pop(0) + num_bytes = min(len(msg), len(buf)) + buf[:num_bytes] = msg[:num_bytes] # noqa: WPS362 + return num_bytes + + def write(self, data): + """Write data (returns length written).""" + return len(data) def test_bytes_read(): """Reader should capture bytes read.""" sock = MockSocket() sock.messages.append(b'foo') - rfile = makefile.MakeFile(sock, 'r') + rfile = StreamReader(sock) rfile.read() assert rfile.bytes_read == 3 @@ -47,7 +48,6 @@ def test_bytes_read(): def test_bytes_written(): """Writer should capture bytes written.""" sock = MockSocket() - sock.messages.append(b'foo') - wfile = makefile.MakeFile(sock, 'w') + wfile = StreamWriter(sock) wfile.write(b'bar') assert wfile.bytes_written == 3 diff --git a/cheroot/test/test_ssl.py b/cheroot/test/test_ssl.py index 115237f4f7..152267a5f9 100644 --- a/cheroot/test/test_ssl.py +++ b/cheroot/test/test_ssl.py @@ -32,9 +32,7 @@ IS_LINUX, IS_MACOS, IS_PYPY, - IS_SOLARIS, IS_WINDOWS, - SYS_PLATFORM, bton, ntob, ntou, @@ -723,36 +721,6 @@ def test_http_over_https_error( tls_certificate_private_key_pem_path, ): """Ensure that connecting over HTTP to HTTPS port is handled.""" - # disable some flaky tests - # https://github.com/cherrypy/cheroot/issues/225 - issue_225 = IS_MACOS and adapter_type == 'builtin' - if issue_225: - pytest.xfail('Test fails in Travis-CI') - - if IS_LINUX: - expected_error_code, expected_error_text = ( - 104, - 'Connection reset by peer', - ) - elif IS_MACOS: - expected_error_code, expected_error_text = ( - 54, - 'Connection reset by peer', - ) - elif IS_SOLARIS: - expected_error_code, expected_error_text = ( - None, - 'Remote end closed connection without response', - ) - elif IS_WINDOWS: - expected_error_code, expected_error_text = ( - 10054, - 'An existing connection was forcibly closed by the remote host', - ) - else: - expected_error_code, expected_error_text = None, None - pytest.skip(f'{SYS_PLATFORM} is unsupported') # pragma: no cover - tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) tls_adapter = tls_adapter_cls( tls_certificate_chain_pem_path, @@ -774,31 +742,15 @@ def test_http_over_https_error( if ip_addr is ANY_INTERFACE_IPV6: fqdn = '[{fqdn}]'.format(**locals()) - expect_fallback_response_over_plain_http = adapter_type == 'pyopenssl' - if expect_fallback_response_over_plain_http: - resp = requests.get( - f'http://{fqdn!s}:{port!s}/', - timeout=http_request_timeout, - ) - assert resp.status_code == 400 - assert resp.text == ( - 'The client sent a plain HTTP request, ' - 'but this server only speaks HTTPS on this port.' - ) - return - - with pytest.raises(requests.exceptions.ConnectionError) as ssl_err: - requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL - f'http://{fqdn!s}:{port!s}/', - timeout=http_request_timeout, - ) - - underlying_error = ssl_err.value.args[0].args[-1] - err_text = str(underlying_error) - assert underlying_error.errno == expected_error_code, ( - 'The underlying error is {underlying_error!r}'.format(**locals()) + resp = requests.get( + f'http://{fqdn!s}:{port!s}/', + timeout=http_request_timeout, + ) + assert resp.status_code == 400 + assert resp.text == ( + 'The client sent a plain HTTP request, ' + 'but this server only speaks HTTPS on this port.' ) - assert expected_error_text in err_text @pytest.mark.parametrize('adapter_type', ('builtin', 'pyopenssl')) @@ -851,6 +803,8 @@ def test_ssl_adapters_with_private_key_password( ).bind_addr, ) + time.sleep(2) # allow time for the server to start + resp = requests.get( f'https://{interface!s}:{port!s}/', timeout=http_request_timeout, diff --git a/cheroot/wsgi.py b/cheroot/wsgi.py index 4ac4b16764..154818f484 100644 --- a/cheroot/wsgi.py +++ b/cheroot/wsgi.py @@ -228,6 +228,8 @@ def write(self, chunk): # Dang. We have probably already sent data. Truncate the chunk # to fit (so the client doesn't hang) and raise an error later. chunk = chunk[:rbo] + # Update chunklen to match the truncated size + chunklen = len(chunk) self.req.ensure_headers_sent() diff --git a/docs/changelog-fragments.d/799.contrib.rst b/docs/changelog-fragments.d/799.contrib.rst new file mode 100644 index 0000000000..f2068194fb --- /dev/null +++ b/docs/changelog-fragments.d/799.contrib.rst @@ -0,0 +1,3 @@ +Added :py:class:`cheroot.ssl.tls_socket.TLSSocket` to provide uniform SSL/TLS handling + +-- by :user:`julianz-`. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 1e1022c82f..7d23c08f6e 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -2,6 +2,7 @@ AppVeyor Args backend backports +bufsize bugfixes builtin b'xb @@ -46,12 +47,14 @@ preconfigure py pytest pythonic +pyOpenSSL readonly rebase Refactor Refactored refactoring refactorings +seekable Sep sep signalling @@ -75,6 +78,7 @@ tuples unbuffered unclosed unfilterable +unhandled unregister unregisters uptime diff --git a/stubtest_allowlist.txt b/stubtest_allowlist.txt index cc6e94a5a5..b2ab351a6a 100644 --- a/stubtest_allowlist.txt +++ b/stubtest_allowlist.txt @@ -1,40 +1,3 @@ -# generated members by metaclass -cheroot.ssl.pyopenssl.SSLConnection.accept -cheroot.ssl.pyopenssl.SSLConnection.bind -cheroot.ssl.pyopenssl.SSLConnection.close -cheroot.ssl.pyopenssl.SSLConnection.connect -cheroot.ssl.pyopenssl.SSLConnection.connect_ex -cheroot.ssl.pyopenssl.SSLConnection.family -cheroot.ssl.pyopenssl.SSLConnection.fileno -cheroot.ssl.pyopenssl.SSLConnection.get_app_data -cheroot.ssl.pyopenssl.SSLConnection.get_cipher_list -cheroot.ssl.pyopenssl.SSLConnection.get_context -cheroot.ssl.pyopenssl.SSLConnection.get_peer_certificate -cheroot.ssl.pyopenssl.SSLConnection.getpeername -cheroot.ssl.pyopenssl.SSLConnection.getsockname -cheroot.ssl.pyopenssl.SSLConnection.getsockopt -cheroot.ssl.pyopenssl.SSLConnection.gettimeout -cheroot.ssl.pyopenssl.SSLConnection.listen -cheroot.ssl.pyopenssl.SSLConnection.makefile -cheroot.ssl.pyopenssl.SSLConnection.pending -cheroot.ssl.pyopenssl.SSLConnection.read -cheroot.ssl.pyopenssl.SSLConnection.recv -cheroot.ssl.pyopenssl.SSLConnection.renegotiate -cheroot.ssl.pyopenssl.SSLConnection.send -cheroot.ssl.pyopenssl.SSLConnection.sendall -cheroot.ssl.pyopenssl.SSLConnection.set_accept_state -cheroot.ssl.pyopenssl.SSLConnection.set_app_data -cheroot.ssl.pyopenssl.SSLConnection.set_connect_state -cheroot.ssl.pyopenssl.SSLConnection.setblocking -cheroot.ssl.pyopenssl.SSLConnection.setsockopt -cheroot.ssl.pyopenssl.SSLConnection.settimeout -cheroot.ssl.pyopenssl.SSLConnection.shutdown -cheroot.ssl.pyopenssl.SSLConnection.sock_shutdown -cheroot.ssl.pyopenssl.SSLConnection.state_string -cheroot.ssl.pyopenssl.SSLConnection.want_read -cheroot.ssl.pyopenssl.SSLConnection.want_write -cheroot.ssl.pyopenssl.SSLConnection.write - # suppress is both a function and class cheroot._compat.suppress