Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dev = [
"websocket-client>=1.7.0",
"coverage[toml]>=5.0.0",
"coveralls>=3.3",
"localstack-twisted",
"twisted>=24",
"ruff==0.1.0"
]
docs = [
Expand Down
62 changes: 47 additions & 15 deletions rolo/serving/twisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing as t
from io import BytesIO
from queue import Empty, Queue
from typing import Iterator, Sequence, Tuple, Union

from twisted.internet import reactor
from twisted.internet.protocol import Protocol
Expand Down Expand Up @@ -36,7 +37,6 @@
if t.TYPE_CHECKING:
from _typeshed.wsgi import WSGIEnvironment


LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -148,25 +148,46 @@ def to_websocket_environment(request: Request) -> WebSocketEnvironment:
return environ


class TwistedHeaderAdapter(TwistedHeaders):
"""
Custom twisted server Headers object to handle header casing. This was introduced to abstract away the refactoring
that happened in https://github.com/twisted/twisted/pull/12264.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._caseMappings = {}

def rememberHeaderCasing(self, name: Union[str, bytes]) -> None:
"""
Receives a raw header in its original casing and stores it to later restore the header casing in
``getAllRawHeaders``.
"""
self._caseMappings[name.lower()] = name

def getAllRawHeaders(self) -> Iterator[Tuple[bytes, Sequence[bytes]]]:
for k, v in self._rawHeaders.items():
yield self._caseMappings.get(k.lower(), k), v


class TwistedRequestAdapter(TwistedRequest):
"""
Custom twisted server Request object to handle header casing.
"""

rawHeaderList: list[tuple[bytes, bytes]]
requestHeaders: TwistedHeaderAdapter
responseHeaders: TwistedHeaderAdapter

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# instantiate case mappings, these are used by `getAllRawHeaders` to restore casing
# by default, they are class attributes, so we would override them globally
self.requestHeaders._caseMappings = dict(self.requestHeaders._caseMappings)
self.responseHeaders._caseMappings = dict(self.responseHeaders._caseMappings)
self.requestHeaders = TwistedHeaderAdapter()
self.responseHeaders = TwistedHeaderAdapter()


class HeaderPreservingHTTPChannel(HTTPChannel):
"""
Special HTTPChannel implementation that uses ``Headers._caseMappings`` to retain header casing both for
request headers (server -> WSGI), and response headers (WSGI -> client).
request headers (server -> WSGI), and response headers (WSGI -> client).
"""

requestFactory = TwistedRequestAdapter
Expand All @@ -178,20 +199,30 @@ def protocol_factory():
def headerReceived(self, line):
if not super().headerReceived(line):
return False
# remember casing of headers for requests
# remember casing of headers for requests, note that this will only work if TwistedRequestAdapter is used
# as the Request object type, which requires a correct setup of the `Site` object.
header, data = line.split(b":", 1)
request: TwistedRequestAdapter = self.requests[-1]
request.requestHeaders._caseMappings[header.lower()] = header
request.requestHeaders.rememberHeaderCasing(header)
return True

def writeHeaders(self, version, code, reason, headers):
def writeHeaders(
self, version: bytes, code: bytes, reason: bytes, headers: list | TwistedHeaders
):
"""Alternative implementation that writes the raw headers instead of sanitized versions."""
responseLine = version + b" " + code + b" " + reason + b"\r\n"
headerSequence = [responseLine]

for name, value in headers:
line = name + b": " + value + b"\r\n"
headerSequence.append(line)
if isinstance(headers, list):
# older twisted versions sometime before 24.10 passed a list to this method
for name, value in headers:
line = name + b": " + value + b"\r\n"
headerSequence.append(line)
else:
# newer twisted versions instead pass the headers object
for name, values in headers.getAllRawHeaders():
line = name + b": " + b",".join(values) + b"\r\n"
headerSequence.append(line)

headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)
Expand All @@ -216,7 +247,7 @@ def startResponse(self, *args, **kwargs):
# headers
for header, _ in self.headers:
header = header.encode("latin-1")
self.request.responseHeaders._caseMappings[header.lower()] = header
self.request.responseHeaders.rememberHeaderCasing(header)
return result


Expand Down Expand Up @@ -441,6 +472,7 @@ class TwistedGateway(Site):

def __init__(self, gateway: Gateway):
super().__init__(
GatewayResource(gateway, reactor, reactor.getThreadPool()), TwistedRequestAdapter
resource=GatewayResource(gateway, reactor, reactor.getThreadPool()),
requestFactory=TwistedRequestAdapter,
)
self.protocol = HeaderPreservingHTTPChannel.protocol_factory
9 changes: 7 additions & 2 deletions rolo/testing/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,11 @@ def serve_twisted_websocket_listener(twisted_reactor, serve_twisted_tcp_server):
"""
from twisted.web.server import Site

from rolo.serving.twisted import HeaderPreservingWSGIResource, WebsocketResourceDecorator
from rolo.serving.twisted import (
HeaderPreservingWSGIResource,
TwistedRequestAdapter,
WebsocketResourceDecorator,
)

def _create(websocket_listener: WebSocketListener):
site = Site(
Expand All @@ -266,7 +270,8 @@ def _create(websocket_listener: WebSocketListener):
twisted_reactor, twisted_reactor.getThreadPool(), None
),
websocketListener=websocket_listener,
)
),
requestFactory=TwistedRequestAdapter,
)
site.protocol = HeaderPreservingHTTPChannel.protocol_factory
return serve_twisted_tcp_server(site)
Expand Down
23 changes: 18 additions & 5 deletions tests/gateway/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import pytest
import requests

from rolo import Response
from rolo.gateway import Gateway, HandlerChain, RequestContext


@pytest.mark.parametrize("serve_gateway", ["asgi", "twisted"], indirect=True)
def test_raw_header_handling(serve_gateway):
def handler(chain: HandlerChain, context: RequestContext, response):
def handler(chain: HandlerChain, context: RequestContext, response: Response):
response.data = json.dumps({"headers": dict(context.request.headers)})
response.mimetype = "application/json"
response.headers["X-fOO_bar"] = "FooBar"
response.headers["content-md5"] = "af5e58f9a7c4682e1b410f2e9392a539"
return response

gateway = Gateway(request_handlers=[handler])
Expand All @@ -22,7 +24,18 @@ def handler(chain: HandlerChain, context: RequestContext, response):
srv.url,
headers={"x-mIxEd-CaSe": "myheader", "X-UPPER__CASE": "uppercase"},
)
returned_headers = response.json()["headers"]
assert "X-UPPER__CASE" in returned_headers
assert "x-mIxEd-CaSe" in returned_headers
assert "X-fOO_bar" in dict(response.headers)
request_headers = response.json()["headers"]

# test default headers
assert "User-Agent" in request_headers
assert "Connection" in request_headers
assert "Host" in request_headers

# test custom headers
assert "X-UPPER__CASE" in request_headers
assert "x-mIxEd-CaSe" in request_headers

response_headers = dict(response.headers)
assert "X-fOO_bar" in response_headers
# even though it's a standard header, it should be in the original case
assert "content-md5" in response_headers