diff --git a/pyproject.toml b/pyproject.toml index 6f5636b..5f5b606 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dev = [ "websocket-client>=1.7.0", "coverage[toml]>=5.0.0", "coveralls>=3.3", - "localstack-twisted", + "twisted>=24", "ruff==0.1.0" ] docs = [ diff --git a/rolo/serving/twisted.py b/rolo/serving/twisted.py index 9dc3ded..2ae1f25 100644 --- a/rolo/serving/twisted.py +++ b/rolo/serving/twisted.py @@ -5,6 +5,7 @@ import typing as t from io import BytesIO from queue import Empty, Queue +from typing import Iterator, Sequence, Tuple, Union from twisted.internet import reactor from twisted.internet.protocol import Protocol @@ -36,7 +37,6 @@ if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment - LOG = logging.getLogger(__name__) @@ -148,25 +148,46 @@ def to_websocket_environment(request: Request) -> WebSocketEnvironment: return environ +class TwistedHeaderAdapter(TwistedHeaders): + """ + Custom twisted server Headers object to handle header casing. This was introduced to abstract away the refactoring + that happened in https://github.com/twisted/twisted/pull/12264. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._caseMappings = {} + + def rememberHeaderCasing(self, name: Union[str, bytes]) -> None: + """ + Receives a raw header in its original casing and stores it to later restore the header casing in + ``getAllRawHeaders``. + """ + self._caseMappings[name.lower()] = name + + def getAllRawHeaders(self) -> Iterator[Tuple[bytes, Sequence[bytes]]]: + for k, v in self._rawHeaders.items(): + yield self._caseMappings.get(k.lower(), k), v + + class TwistedRequestAdapter(TwistedRequest): """ Custom twisted server Request object to handle header casing. 
""" - rawHeaderList: list[tuple[bytes, bytes]] + requestHeaders: TwistedHeaderAdapter + responseHeaders: TwistedHeaderAdapter def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # instantiate case mappings, these are used by `getAllRawHeaders` to restore casing - # by default, they are class attributes, so we would override them globally - self.requestHeaders._caseMappings = dict(self.requestHeaders._caseMappings) - self.responseHeaders._caseMappings = dict(self.responseHeaders._caseMappings) + self.requestHeaders = TwistedHeaderAdapter() + self.responseHeaders = TwistedHeaderAdapter() class HeaderPreservingHTTPChannel(HTTPChannel): """ Special HTTPChannel implementation that uses ``Headers._caseMappings`` to retain header casing both for - request headers (server -> WSGI), and response headers (WSGI -> client). + request headers (server -> WSGI), and response headers (WSGI -> client). """ requestFactory = TwistedRequestAdapter @@ -178,20 +199,30 @@ def protocol_factory(): def headerReceived(self, line): if not super().headerReceived(line): return False - # remember casing of headers for requests + # remember casing of headers for requests, note that this will only work if TwistedRequestAdapter is used + # as the Request object type, which requires a correct setup of the `Site` object. 
header, data = line.split(b":", 1) request: TwistedRequestAdapter = self.requests[-1] - request.requestHeaders._caseMappings[header.lower()] = header + request.requestHeaders.rememberHeaderCasing(header) return True - def writeHeaders(self, version, code, reason, headers): + def writeHeaders( + self, version: bytes, code: bytes, reason: bytes, headers: list | TwistedHeaders + ): """Alternative implementation that writes the raw headers instead of sanitized versions.""" responseLine = version + b" " + code + b" " + reason + b"\r\n" headerSequence = [responseLine] - for name, value in headers: - line = name + b": " + value + b"\r\n" - headerSequence.append(line) + if isinstance(headers, list): + # older twisted versions sometime before 24.10 passed a list to this method + for name, value in headers: + line = name + b": " + value + b"\r\n" + headerSequence.append(line) + else: + # newer twisted versions pass the headers object; write one line per value + for name, values in headers.getAllRawHeaders(): + for value in values: + headerSequence.append(name + b": " + value + b"\r\n") headerSequence.append(b"\r\n") self.transport.writeSequence(headerSequence) @@ -216,7 +247,7 @@ def startResponse(self, *args, **kwargs): # headers for header, _ in self.headers: header = header.encode("latin-1") - self.request.responseHeaders._caseMappings[header.lower()] = header + self.request.responseHeaders.rememberHeaderCasing(header) return result @@ -441,6 +472,7 @@ class TwistedGateway(Site): def __init__(self, gateway: Gateway): super().__init__( - GatewayResource(gateway, reactor, reactor.getThreadPool()), TwistedRequestAdapter + resource=GatewayResource(gateway, reactor, reactor.getThreadPool()), + requestFactory=TwistedRequestAdapter, ) self.protocol = HeaderPreservingHTTPChannel.protocol_factory diff --git a/rolo/testing/pytest.py b/rolo/testing/pytest.py index 3ea7a4c..5ca2f76 100644 --- a/rolo/testing/pytest.py +++ b/rolo/testing/pytest.py @@ -257,7 +257,11 @@ def 
serve_twisted_websocket_listener(twisted_reactor, serve_twisted_tcp_server): """ from twisted.web.server import Site - from rolo.serving.twisted import HeaderPreservingWSGIResource, WebsocketResourceDecorator + from rolo.serving.twisted import ( + HeaderPreservingWSGIResource, + TwistedRequestAdapter, + WebsocketResourceDecorator, + ) def _create(websocket_listener: WebSocketListener): site = Site( @@ -266,7 +270,8 @@ def _create(websocket_listener: WebSocketListener): twisted_reactor, twisted_reactor.getThreadPool(), None ), websocketListener=websocket_listener, - ) + ), + requestFactory=TwistedRequestAdapter, ) site.protocol = HeaderPreservingHTTPChannel.protocol_factory return serve_twisted_tcp_server(site) diff --git a/tests/gateway/test_headers.py b/tests/gateway/test_headers.py index 864cbde..8e62d69 100644 --- a/tests/gateway/test_headers.py +++ b/tests/gateway/test_headers.py @@ -3,15 +3,17 @@ import pytest import requests +from rolo import Response from rolo.gateway import Gateway, HandlerChain, RequestContext @pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True) def test_raw_header_handling(serve_gateway): - def handler(chain: HandlerChain, context: RequestContext, response): + def handler(chain: HandlerChain, context: RequestContext, response: Response): response.data = json.dumps({"headers": dict(context.request.headers)}) response.mimetype = "application/json" response.headers["X-fOO_bar"] = "FooBar" + response.headers["content-md5"] = "af5e58f9a7c4682e1b410f2e9392a539" return response gateway = Gateway(request_handlers=[handler]) @@ -22,7 +24,18 @@ def handler(chain: HandlerChain, context: RequestContext, response): srv.url, headers={"x-mIxEd-CaSe": "myheader", "X-UPPER__CASE": "uppercase"}, ) - returned_headers = response.json()["headers"] - assert "X-UPPER__CASE" in returned_headers - assert "x-mIxEd-CaSe" in returned_headers - assert "X-fOO_bar" in dict(response.headers) + request_headers = response.json()["headers"] + + # 
test default headers + assert "User-Agent" in request_headers + assert "Connection" in request_headers + assert "Host" in request_headers + + # test custom headers + assert "X-UPPER__CASE" in request_headers + assert "x-mIxEd-CaSe" in request_headers + + response_headers = dict(response.headers) + assert "X-fOO_bar" in response_headers + # even though it's a standard header, it should be in the original case + assert "content-md5" in response_headers