Skip to content

Commit 8c3ef76

Browse files
dmontagutiangolo
authored andcommitted
✨ Add better support for request body access/manipulation with custom classes (fastapi#589)
1 parent 7a504a7 commit 8c3ef76

File tree

10 files changed

+304
-14
lines changed

10 files changed

+304
-14
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import gzip
2+
from typing import Callable, List
3+
4+
from fastapi import Body, FastAPI
5+
from fastapi.routing import APIRoute
6+
from starlette.requests import Request
7+
from starlette.responses import Response
8+
9+
10+
class GzipRequest(Request):
11+
async def body(self) -> bytes:
12+
if not hasattr(self, "_body"):
13+
body = await super().body()
14+
if "gzip" in self.headers.getlist("Content-Encoding"):
15+
body = gzip.decompress(body)
16+
self._body = body
17+
return self._body
18+
19+
20+
class GzipRoute(APIRoute):
21+
def get_route_handler(self) -> Callable:
22+
original_route_handler = super().get_route_handler()
23+
24+
async def custom_route_handler(request: Request) -> Response:
25+
request = GzipRequest(request.scope, request.receive)
26+
return await original_route_handler(request)
27+
28+
return custom_route_handler
29+
30+
31+
app = FastAPI()
32+
app.router.route_class = GzipRoute
33+
34+
35+
@app.post("/sum")
36+
async def sum_numbers(numbers: List[int] = Body(...)):
37+
return {"sum": sum(numbers)}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Callable, List
2+
3+
from fastapi import Body, FastAPI, HTTPException
4+
from fastapi.exceptions import RequestValidationError
5+
from fastapi.routing import APIRoute
6+
from starlette.requests import Request
7+
from starlette.responses import Response
8+
9+
10+
class ValidationErrorLoggingRoute(APIRoute):
11+
def get_route_handler(self) -> Callable:
12+
original_route_handler = super().get_route_handler()
13+
14+
async def custom_route_handler(request: Request) -> Response:
15+
try:
16+
return await original_route_handler(request)
17+
except RequestValidationError as exc:
18+
body = await request.body()
19+
detail = {"errors": exc.errors(), "body": body.decode()}
20+
raise HTTPException(status_code=422, detail=detail)
21+
22+
return custom_route_handler
23+
24+
25+
app = FastAPI()
26+
app.router.route_class = ValidationErrorLoggingRoute
27+
28+
29+
@app.post("/")
30+
async def sum_numbers(numbers: List[int] = Body(...)):
31+
return sum(numbers)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import time
2+
from typing import Callable
3+
4+
from fastapi import APIRouter, FastAPI
5+
from fastapi.routing import APIRoute
6+
from starlette.requests import Request
7+
from starlette.responses import Response
8+
9+
10+
class TimedRoute(APIRoute):
11+
def get_route_handler(self) -> Callable:
12+
original_route_handler = super().get_route_handler()
13+
14+
async def custom_route_handler(request: Request) -> Response:
15+
before = time.time()
16+
response: Response = await original_route_handler(request)
17+
duration = time.time() - before
18+
response.headers["X-Response-Time"] = str(duration)
19+
print(f"route duration: {duration}")
20+
print(f"route response: {response}")
21+
print(f"route response headers: {response.headers}")
22+
return response
23+
24+
return custom_route_handler
25+
26+
27+
app = FastAPI()
28+
router = APIRouter(route_class=TimedRoute)
29+
30+
31+
@app.get("/")
32+
async def not_timed():
33+
return {"message": "Not timed"}
34+
35+
36+
@router.get("/timed")
37+
async def timed():
38+
return {"message": "It's the time of my life"}
39+
40+
41+
app.include_router(router)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
In some cases, you may want to override the logic used by the `Request` and `APIRoute` classes.
2+
3+
In particular, this may be a good alternative to logic in a middleware.
4+
5+
For example, if you want to read or manipulate the request body before it is processed by your application.
6+
7+
!!! danger
8+
This is an "advanced" feature.
9+
10+
If you are just starting with **FastAPI** you might want to skip this section.
11+
12+
## Use cases
13+
14+
Some use cases include:
15+
16+
* Converting non-JSON request bodies to JSON (e.g. [`msgpack`](https://msgpack.org/index.html)).
17+
* Decompressing gzip-compressed request bodies.
18+
* Automatically logging all request bodies.
19+
* Accessing the request body in an exception handler.
20+
21+
## Handling custom request body encodings
22+
23+
Let's see how to make use of a custom `Request` subclass to decompress gzip requests.
24+
25+
And an `APIRoute` subclass to use that custom request class.
26+
27+
### Create a custom `GzipRequest` class
28+
29+
First, we create a `GzipRequest` class, which will overwrite the `Request.body()` method to decompress the body in the presence of an appropriate header.
30+
31+
If there's no `gzip` in the header, it will not try to decompress the body.
32+
33+
That way, the same route class can handle gzip compressed or uncompressed requests.
34+
35+
```Python hl_lines="10 11 12 13 14 15 16 17"
36+
{!./src/custom_request_and_route/tutorial001.py!}
37+
```
38+
39+
### Create a custom `GzipRoute` class
40+
41+
Next, we create a custom subclass of `fastapi.routing.APIRoute` that will make use of the `GzipRequest`.
42+
43+
This time, it will overwrite the method `APIRoute.get_route_handler()`.
44+
45+
This method returns a function. And that function is what will receive a request and return a response.
46+
47+
Here we use it to create a `GzipRequest` from the original request.
48+
49+
```Python hl_lines="20 21 22 23 24 25 26 27 28"
50+
{!./src/custom_request_and_route/tutorial001.py!}
51+
```
52+
53+
!!! note "Technical Details"
54+
A `Request` has a `request.scope` attribute, that's just a Python `dict` containing the metadata related to the request.
55+
56+
A `Request` also has a `request.receive`, that's a function to "receive" the body of the request.
57+
58+
The `scope` `dict` and `receive` function are both part of the ASGI specification.
59+
60+
And those two things, `scope` and `receive`, are what is needed to create a new `Request` instance.
61+
62+
To learn more about the `Request` check <a href="https://www.starlette.io/requests/" target="_blank">Starlette's docs about Requests</a>.
63+
64+
The only thing the function returned by `GzipRequest.get_route_handler` does differently is convert the `Request` to a `GzipRequest`.
65+
66+
Doing this, our `GzipRequest` will take care of decompressing the data (if necessary) before passing it to our *path operations*.
67+
68+
After that, all of the processing logic is the same.
69+
70+
But because of our changes in `GzipRequest.body`, the request body will be automatically decompressed when it is loaded by **FastAPI** when needed.
71+
72+
## Accessing the request body in an exception handler
73+
74+
We can also use this same approach to access the request body in an exception handler.
75+
76+
All we need to do is handle the request inside a `try`/`except` block:
77+
78+
```Python hl_lines="15 17"
79+
{!./src/custom_request_and_route/tutorial002.py!}
80+
```
81+
82+
If an exception occurs, the`Request` instance will still be in scope, so we can read and make use of the request body when handling the error:
83+
84+
```Python hl_lines="18 19 20"
85+
{!./src/custom_request_and_route/tutorial002.py!}
86+
```
87+
88+
## Custom `APIRoute` class in a router
89+
90+
You can also set the `route_class` parameter of an `APIRouter`:
91+
92+
```Python hl_lines="25"
93+
{!./src/custom_request_and_route/tutorial003.py!}
94+
```
95+
96+
In this example, the *path operations* under the `router` will use the custom `TimedRoute` class, and will have an extra `X-Response-Time` header in the response with the time it took to generate the response:
97+
98+
```Python hl_lines="15 16 17 18 19"
99+
{!./src/custom_request_and_route/tutorial003.py!}
100+
```

fastapi/routing.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def serialize_response(
6565
return jsonable_encoder(response)
6666

6767

68-
def get_app(
68+
def get_request_handler(
6969
dependant: Dependant,
7070
body_field: Field = None,
7171
status_code: int = 200,
@@ -294,19 +294,20 @@ def __init__(
294294
)
295295
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
296296
self.dependency_overrides_provider = dependency_overrides_provider
297-
self.app = request_response(
298-
get_app(
299-
dependant=self.dependant,
300-
body_field=self.body_field,
301-
status_code=self.status_code,
302-
response_class=self.response_class or JSONResponse,
303-
response_field=self.secure_cloned_response_field,
304-
response_model_include=self.response_model_include,
305-
response_model_exclude=self.response_model_exclude,
306-
response_model_by_alias=self.response_model_by_alias,
307-
response_model_skip_defaults=self.response_model_skip_defaults,
308-
dependency_overrides_provider=self.dependency_overrides_provider,
309-
)
297+
self.app = request_response(self.get_route_handler())
298+
299+
def get_route_handler(self) -> Callable:
300+
return get_request_handler(
301+
dependant=self.dependant,
302+
body_field=self.body_field,
303+
status_code=self.status_code,
304+
response_class=self.response_class or JSONResponse,
305+
response_field=self.secure_cloned_response_field,
306+
response_model_include=self.response_model_include,
307+
response_model_exclude=self.response_model_exclude,
308+
response_model_by_alias=self.response_model_by_alias,
309+
response_model_skip_defaults=self.response_model_skip_defaults,
310+
dependency_overrides_provider=self.dependency_overrides_provider,
310311
)
311312

312313

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ nav:
8181
- GraphQL: 'tutorial/graphql.md'
8282
- WebSockets: 'tutorial/websockets.md'
8383
- 'Events: startup - shutdown': 'tutorial/events.md'
84+
- Custom Request and APIRoute class: 'tutorial/custom-request-and-route.md'
8485
- Testing: 'tutorial/testing.md'
8586
- Testing Dependencies with Overrides: 'tutorial/testing-dependencies.md'
8687
- Debugging: 'tutorial/debugging.md'

tests/test_tutorial/test_custom_request_and_route/__init__.py

Whitespace-only changes.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import gzip
2+
import json
3+
4+
import pytest
5+
from starlette.requests import Request
6+
from starlette.testclient import TestClient
7+
8+
from custom_request_and_route.tutorial001 import app
9+
10+
11+
@app.get("/check-class")
12+
async def check_gzip_request(request: Request):
13+
return {"request_class": type(request).__name__}
14+
15+
16+
client = TestClient(app)
17+
18+
19+
@pytest.mark.parametrize("compress", [True, False])
20+
def test_gzip_request(compress):
21+
n = 1000
22+
headers = {}
23+
body = [1] * n
24+
data = json.dumps(body).encode()
25+
if compress:
26+
data = gzip.compress(data)
27+
headers["Content-Encoding"] = "gzip"
28+
response = client.post("/sum", data=data, headers=headers)
29+
assert response.json() == {"sum": n}
30+
31+
32+
def test_request_class():
33+
response = client.get("/check-class")
34+
assert response.json() == {"request_class": "GzipRequest"}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from starlette.testclient import TestClient
2+
3+
from custom_request_and_route.tutorial002 import app
4+
5+
client = TestClient(app)
6+
7+
8+
def test_endpoint_works():
9+
response = client.post("/", json=[1, 2, 3])
10+
assert response.json() == 6
11+
12+
13+
def test_exception_handler_body_access():
14+
response = client.post("/", json={"numbers": [1, 2, 3]})
15+
16+
assert response.json() == {
17+
"detail": {
18+
"body": '{"numbers": [1, 2, 3]}',
19+
"errors": [
20+
{
21+
"loc": ["body", "numbers"],
22+
"msg": "value is not a valid list",
23+
"type": "type_error.list",
24+
}
25+
],
26+
}
27+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from starlette.testclient import TestClient
2+
3+
from custom_request_and_route.tutorial003 import app
4+
5+
client = TestClient(app)
6+
7+
8+
def test_get():
9+
response = client.get("/")
10+
assert response.json() == {"message": "Not timed"}
11+
assert "X-Response-Time" not in response.headers
12+
13+
14+
def test_get_timed():
15+
response = client.get("/timed")
16+
assert response.json() == {"message": "It's the time of my life"}
17+
assert "X-Response-Time" in response.headers
18+
assert float(response.headers["X-Response-Time"]) > 0

0 commit comments

Comments
 (0)