Skip to content

Commit dd96351

Browse files
dmontagutiangolo
authored andcommitted
🐛 Fix preserving route_class when calling include_router (fastapi#538)
1 parent fdb6d43 commit dd96351

File tree

2 files changed

+118
-1
lines changed

2 files changed

+118
-1
lines changed

fastapi/routing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,10 @@ def add_api_route(
348348
include_in_schema: bool = True,
349349
response_class: Type[Response] = None,
350350
name: str = None,
351+
route_class_override: Optional[Type[APIRoute]] = None,
351352
) -> None:
352-
route = self.route_class(
353+
route_class = route_class_override or self.route_class
354+
route = route_class(
353355
path,
354356
endpoint=endpoint,
355357
response_model=response_model,
@@ -487,6 +489,7 @@ def include_router(
487489
include_in_schema=route.include_in_schema,
488490
response_class=route.response_class or default_response_class,
489491
name=route.name,
492+
route_class_override=type(route),
490493
)
491494
elif isinstance(route, routing.Route):
492495
self.add_route(

tests/test_custom_route_class.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import pytest
2+
from fastapi import APIRouter, FastAPI
3+
from fastapi.routing import APIRoute
4+
from starlette.testclient import TestClient
5+
6+
app = FastAPI()
7+
8+
9+
class APIRouteA(APIRoute):
10+
x_type = "A"
11+
12+
13+
class APIRouteB(APIRoute):
14+
x_type = "B"
15+
16+
17+
class APIRouteC(APIRoute):
18+
x_type = "C"
19+
20+
21+
router_a = APIRouter(route_class=APIRouteA)
22+
router_b = APIRouter(route_class=APIRouteB)
23+
router_c = APIRouter(route_class=APIRouteC)
24+
25+
26+
@router_a.get("/")
27+
def get_a():
28+
return {"msg": "A"}
29+
30+
31+
@router_b.get("/")
32+
def get_b():
33+
return {"msg": "B"}
34+
35+
36+
@router_c.get("/")
37+
def get_c():
38+
return {"msg": "C"}
39+
40+
41+
router_b.include_router(router=router_c, prefix="/c")
42+
router_a.include_router(router=router_b, prefix="/b")
43+
app.include_router(router=router_a, prefix="/a")
44+
45+
46+
client = TestClient(app)
47+
48+
openapi_schema = {
49+
"openapi": "3.0.2",
50+
"info": {"title": "Fast API", "version": "0.1.0"},
51+
"paths": {
52+
"/a/": {
53+
"get": {
54+
"responses": {
55+
"200": {
56+
"description": "Successful Response",
57+
"content": {"application/json": {"schema": {}}},
58+
}
59+
},
60+
"summary": "Get A",
61+
"operationId": "get_a_a__get",
62+
}
63+
},
64+
"/a/b/": {
65+
"get": {
66+
"responses": {
67+
"200": {
68+
"description": "Successful Response",
69+
"content": {"application/json": {"schema": {}}},
70+
}
71+
},
72+
"summary": "Get B",
73+
"operationId": "get_b_a_b__get",
74+
}
75+
},
76+
"/a/b/c/": {
77+
"get": {
78+
"responses": {
79+
"200": {
80+
"description": "Successful Response",
81+
"content": {"application/json": {"schema": {}}},
82+
}
83+
},
84+
"summary": "Get C",
85+
"operationId": "get_c_a_b_c__get",
86+
}
87+
},
88+
},
89+
}
90+
91+
92+
@pytest.mark.parametrize(
93+
"path,expected_status,expected_response",
94+
[
95+
("/a", 200, {"msg": "A"}),
96+
("/a/b", 200, {"msg": "B"}),
97+
("/a/b/c", 200, {"msg": "C"}),
98+
("/openapi.json", 200, openapi_schema),
99+
],
100+
)
101+
def test_get_path(path, expected_status, expected_response):
102+
response = client.get(path)
103+
assert response.status_code == expected_status
104+
assert response.json() == expected_response
105+
106+
107+
def test_route_classes():
108+
routes = {}
109+
r: APIRoute
110+
for r in app.router.routes:
111+
routes[r.path] = r
112+
assert routes["/a/"].x_type == "A"
113+
assert routes["/a/b/"].x_type == "B"
114+
assert routes["/a/b/c/"].x_type == "C"

0 commit comments

Comments
 (0)