Skip to content

Commit bbebfb0

Browse files
committed
restore original ASGITransport class code and split streaming-enabled version into separate ASGIStreamingTransport class
1 parent 04bd45b commit bbebfb0

File tree

3 files changed

+194
-65
lines changed

3 files changed

+194
-65
lines changed

httpx/_transports/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
__all__ = [
88
"ASGITransport",
9+
"ASGIStreamingTransport",
910
"AsyncBaseTransport",
1011
"BaseTransport",
1112
"AsyncHTTPTransport",

httpx/_transports/asgi.py

Lines changed: 151 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
[typing.MutableMapping[str, typing.Any], _Receive, _Send], typing.Awaitable[None]
3434
]
3535

36-
__all__ = ["ASGITransport"]
36+
__all__ = ["ASGITransport", "ASGIStreamingTransport"]
3737

3838

3939
def is_running_trio() -> bool:
@@ -98,6 +98,141 @@ def get_end_of_stream_error_type() -> type[anyio.EndOfStream | trio.EndOfChannel
9898

9999

100100
class ASGIResponseStream(AsyncByteStream):
101+
def __init__(self, body: list[bytes]) -> None:
102+
self._body = body
103+
104+
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
105+
yield b"".join(self._body)
106+
107+
108+
class ASGITransport(AsyncBaseTransport):
109+
"""
110+
A custom AsyncTransport that handles sending requests directly to an ASGI app.
111+
112+
```python
113+
transport = httpx.ASGITransport(
114+
app=app,
115+
root_path="/submount",
116+
client=("1.2.3.4", 123)
117+
)
118+
client = httpx.AsyncClient(transport=transport)
119+
```
120+
121+
Arguments:
122+
123+
* `app` - The ASGI application.
124+
* `raise_app_exceptions` - Boolean indicating if exceptions in the application
125+
should be raised. Default to `True`. Can be set to `False` for use cases
126+
such as testing the content of a client 500 response.
127+
* `root_path` - The root path on which the ASGI application should be mounted.
128+
* `client` - A two-tuple indicating the client IP and port of incoming requests.
129+
```
130+
"""
131+
132+
def __init__(
133+
self,
134+
app: _ASGIApp,
135+
raise_app_exceptions: bool = True,
136+
root_path: str = "",
137+
client: tuple[str, int] = ("127.0.0.1", 123),
138+
) -> None:
139+
self.app = app
140+
self.raise_app_exceptions = raise_app_exceptions
141+
self.root_path = root_path
142+
self.client = client
143+
144+
async def handle_async_request(
145+
self,
146+
request: Request,
147+
) -> Response:
148+
assert isinstance(request.stream, AsyncByteStream)
149+
150+
# ASGI scope.
151+
scope = {
152+
"type": "http",
153+
"asgi": {"version": "3.0"},
154+
"http_version": "1.1",
155+
"method": request.method,
156+
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
157+
"scheme": request.url.scheme,
158+
"path": request.url.path,
159+
"raw_path": request.url.raw_path.split(b"?")[0],
160+
"query_string": request.url.query,
161+
"server": (request.url.host, request.url.port),
162+
"client": self.client,
163+
"root_path": self.root_path,
164+
}
165+
166+
# Request.
167+
request_body_chunks = request.stream.__aiter__()
168+
request_complete = False
169+
170+
# Response.
171+
status_code = None
172+
response_headers = None
173+
body_parts = []
174+
response_started = False
175+
response_complete = create_event()
176+
177+
# ASGI callables.
178+
179+
async def receive() -> dict[str, typing.Any]:
180+
nonlocal request_complete
181+
182+
if request_complete:
183+
await response_complete.wait()
184+
return {"type": "http.disconnect"}
185+
186+
try:
187+
body = await request_body_chunks.__anext__()
188+
except StopAsyncIteration:
189+
request_complete = True
190+
return {"type": "http.request", "body": b"", "more_body": False}
191+
return {"type": "http.request", "body": body, "more_body": True}
192+
193+
async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
194+
nonlocal status_code, response_headers, response_started
195+
196+
if message["type"] == "http.response.start":
197+
assert not response_started
198+
199+
status_code = message["status"]
200+
response_headers = message.get("headers", [])
201+
response_started = True
202+
203+
elif message["type"] == "http.response.body":
204+
assert not response_complete.is_set()
205+
body = message.get("body", b"")
206+
more_body = message.get("more_body", False)
207+
208+
if body and request.method != "HEAD":
209+
body_parts.append(body)
210+
211+
if not more_body:
212+
response_complete.set()
213+
214+
try:
215+
await self.app(scope, receive, send)
216+
except Exception: # noqa: PIE-786
217+
if self.raise_app_exceptions:
218+
raise
219+
220+
response_complete.set()
221+
if status_code is None:
222+
status_code = 500
223+
if response_headers is None:
224+
response_headers = {}
225+
226+
assert response_complete.is_set()
227+
assert status_code is not None
228+
assert response_headers is not None
229+
230+
stream = ASGIResponseStream(body_parts)
231+
232+
return Response(status_code, headers=response_headers, stream=stream)
233+
234+
235+
class ASGIStreamingResponseStream(AsyncByteStream):
101236
def __init__(
102237
self,
103238
ignore_body: bool,
@@ -124,18 +259,19 @@ async def aclose(self) -> None:
124259
await self._asgi_generator.aclose()
125260

126261

127-
class ASGITransport(AsyncBaseTransport):
262+
class ASGIStreamingTransport(AsyncBaseTransport):
128263
"""
129-
A custom AsyncTransport that handles sending requests directly to an ASGI app.
264+
An equivalent of ASGITransport that operates by running app in a sub-task and
265+
streaming response events as soon as they arrive.
130266
131-
```python
132-
transport = httpx.ASGITransport(
133-
app=app,
134-
root_path="/submount",
135-
client=("1.2.3.4", 123)
136-
)
137-
client = httpx.AsyncClient(transport=transport)
138-
```
267+
It is used in the same way, with the same arguments having the same signification,
268+
as ASGITransport.
269+
270+
The main observable differences between the two implementations will be as follows:
271+
* As the application callable is invoked in a sub-task, any context variables that
272+
are set by the app will not propagate to the caller;
273+
* The streaming mode of operation means that a response will generally be returned
274+
to the AsyncClient caller before the application has fully run;
139275
140276
Arguments:
141277
@@ -145,9 +281,6 @@ class ASGITransport(AsyncBaseTransport):
145281
such as testing the content of a client 500 response.
146282
* `root_path` - The root path on which the ASGI application should be mounted.
147283
* `client` - A two-tuple indicating the client IP and port of incoming requests.
148-
* `streaming` - Set to `True` to enable streaming of response content. Default to
149-
`False`, as activating this feature means that the ASGI `app` will run in a
150-
sub-task, which has observable side effects for context variables.
151284
```
152285
"""
153286

@@ -157,14 +290,11 @@ def __init__(
157290
raise_app_exceptions: bool = True,
158291
root_path: str = "",
159292
client: tuple[str, int] = ("127.0.0.1", 123),
160-
*,
161-
streaming: bool = False,
162293
) -> None:
163294
self.app = app
164295
self.raise_app_exceptions = raise_app_exceptions
165296
self.root_path = root_path
166297
self.client = client
167-
self.streaming = streaming
168298

169299
async def handle_async_request(
170300
self,
@@ -177,7 +307,7 @@ async def handle_async_request(
177307
return Response(
178308
status_code=message["status"],
179309
headers=message.get("headers", []),
180-
stream=ASGIResponseStream(
310+
stream=ASGIStreamingResponseStream(
181311
ignore_body=request.method == "HEAD",
182312
asgi_generator=asgi_generator,
183313
),
@@ -214,9 +344,8 @@ async def _stream_asgi_messages(
214344
response_complete = create_event()
215345

216346
# ASGI response messages stream
217-
stream_size = 0 if self.streaming else float("inf")
218347
response_message_send_stream, response_message_recv_stream = (
219-
create_memory_object_stream(stream_size)
348+
create_memory_object_stream(0)
220349
)
221350

222351
# ASGI app exception
@@ -256,11 +385,8 @@ async def run_app() -> None:
256385

257386
async with contextlib.AsyncExitStack() as exit_stack:
258387
exit_stack.callback(response_complete.set)
259-
if self.streaming:
260-
task_group = await exit_stack.enter_async_context(create_task_group())
261-
task_group.start_soon(run_app)
262-
else:
263-
await run_app()
388+
task_group = await exit_stack.enter_async_context(create_task_group())
389+
task_group.start_soon(run_app)
264390

265391
async with response_message_recv_stream:
266392
try:

0 commit comments

Comments
 (0)