Skip to content

Commit b087246

Browse files
jekirltiangolo
authored andcommitted
✨ Add support for WebSockets with dependencies, params, etc fastapi#166 (fastapi#178)
1 parent 219d299 commit b087246

File tree

12 files changed

+305
-17
lines changed

12 files changed

+305
-17
lines changed

docs/src/websockets/__init__.py

Whitespace-only changes.

docs/src/websockets/tutorial001.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,9 @@ async def get():
4444
return HTMLResponse(html)
4545

4646

47-
@app.websocket_route("/ws")
47+
@app.websocket("/ws")
4848
async def websocket_endpoint(websocket: WebSocket):
4949
await websocket.accept()
5050
while True:
5151
data = await websocket.receive_text()
5252
await websocket.send_text(f"Message text was: {data}")
53-
await websocket.close()

docs/src/websockets/tutorial002.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from fastapi import Cookie, Depends, FastAPI, Header
2+
from starlette.responses import HTMLResponse
3+
from starlette.status import WS_1008_POLICY_VIOLATION
4+
from starlette.websockets import WebSocket
5+
6+
app = FastAPI()
7+
8+
html = """
9+
<!DOCTYPE html>
10+
<html>
11+
<head>
12+
<title>Chat</title>
13+
</head>
14+
<body>
15+
<h1>WebSocket Chat</h1>
16+
<form action="" onsubmit="sendMessage(event)">
17+
<label>Item ID: <input type="text" id="itemId" autocomplete="off" value="foo"/></label>
18+
<button onclick="connect(event)">Connect</button>
19+
<br>
20+
<label>Message: <input type="text" id="messageText" autocomplete="off"/></label>
21+
<button>Send</button>
22+
</form>
23+
<ul id='messages'>
24+
</ul>
25+
<script>
26+
var ws = null;
27+
function connect(event) {
28+
var input = document.getElementById("itemId")
29+
ws = new WebSocket("ws://localhost:8000/items/" + input.value + "/ws");
30+
ws.onmessage = function(event) {
31+
var messages = document.getElementById('messages')
32+
var message = document.createElement('li')
33+
var content = document.createTextNode(event.data)
34+
message.appendChild(content)
35+
messages.appendChild(message)
36+
};
37+
}
38+
function sendMessage(event) {
39+
var input = document.getElementById("messageText")
40+
ws.send(input.value)
41+
input.value = ''
42+
event.preventDefault()
43+
}
44+
</script>
45+
</body>
46+
</html>
47+
"""
48+
49+
50+
@app.get("/")
51+
async def get():
52+
return HTMLResponse(html)
53+
54+
55+
async def get_cookie_or_client(
56+
websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
57+
):
58+
if session is None and x_client is None:
59+
await websocket.close(code=WS_1008_POLICY_VIOLATION)
60+
return session or x_client
61+
62+
63+
@app.websocket("/items/{item_id}/ws")
64+
async def websocket_endpoint(
65+
websocket: WebSocket,
66+
item_id: int,
67+
q: str = None,
68+
cookie_or_client: str = Depends(get_cookie_or_client),
69+
):
70+
await websocket.accept()
71+
while True:
72+
data = await websocket.receive_text()
73+
await websocket.send_text(
74+
f"Session Cookie or X-Client Header value is: {cookie_or_client}"
75+
)
76+
if q is not None:
77+
await websocket.send_text(f"Query parameter q is: {q}")
78+
await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")

docs/tutorial/websockets.md

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w
2727
{!./src/websockets/tutorial001.py!}
2828
```
2929

30-
## Create a `websocket_route`
30+
## Create a `websocket`
3131

32-
In your **FastAPI** application, create a `websocket_route`:
32+
In your **FastAPI** application, create a `websocket`:
3333

3434
```Python hl_lines="3 47 48"
3535
{!./src/websockets/tutorial001.py!}
@@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`:
3838
!!! tip
3939
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
4040

41-
That is not required, but it's recommended as it will provide you completion and checks inside the function.
42-
43-
44-
!!! info
45-
This `websocket_route` we are using comes directly from <a href="https://www.starlette.io/applications/" target="_blank">Starlette</a>.
46-
47-
That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
48-
49-
5041
## Await for messages and send messages
5142

5243
In your WebSocket route you can `await` for messages and send messages.
@@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages.
5748

5849
You can receive and send binary, text, and JSON data.
5950

51+
## Using `Depends` and others
52+
53+
In WebSocket endpoints you can import from `fastapi` and use:
54+
55+
* `Depends`
56+
* `Security`
57+
* `Cookie`
58+
* `Header`
59+
* `Path`
60+
* `Query`
61+
62+
They work the same way as for other FastAPI endpoints/*path operations*:
63+
64+
```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
65+
{!./src/websockets/tutorial002.py!}
66+
```
67+
68+
!!! info
69+
In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
70+
71+
You can use a closing code from the <a href="https://tools.ietf.org/html/rfc6455#section-7.4.1" target="_blank">valid codes defined in the specification</a>.
72+
73+
In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the <a href="https://github.com/encode/starlette/pull/527" target="_blank">PR #527</a> in Starlette.
74+
75+
## More info
76+
6077
To learn more about the options, check Starlette's documentation for:
6178

6279
* <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.

fastapi/applications.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ def decorator(func: Callable) -> Callable:
203203

204204
return decorator
205205

206+
def add_api_websocket_route(
207+
self, path: str, endpoint: Callable, name: str = None
208+
) -> None:
209+
self.router.add_api_websocket_route(path, endpoint, name=name)
210+
211+
def websocket(self, path: str, name: str = None) -> Callable:
212+
def decorator(func: Callable) -> Callable:
213+
self.add_api_websocket_route(path, func, name=name)
214+
return func
215+
216+
return decorator
217+
206218
def include_router(
207219
self,
208220
router: routing.APIRouter,

fastapi/dependencies/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
name: str = None,
2727
call: Callable = None,
2828
request_param_name: str = None,
29+
websocket_param_name: str = None,
2930
background_tasks_param_name: str = None,
3031
security_scopes_param_name: str = None,
3132
security_scopes: List[str] = None,
@@ -38,6 +39,7 @@ def __init__(
3839
self.dependencies = dependencies or []
3940
self.security_requirements = security_schemes or []
4041
self.request_param_name = request_param_name
42+
self.websocket_param_name = websocket_param_name
4143
self.background_tasks_param_name = background_tasks_param_name
4244
self.security_scopes = security_scopes
4345
self.security_scopes_param_name = security_scopes_param_name

fastapi/dependencies/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from starlette.concurrency import run_in_threadpool
3434
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
3535
from starlette.requests import Request
36+
from starlette.websockets import WebSocket
3637

3738
param_supported_types = (
3839
str,
@@ -184,6 +185,8 @@ def get_dependant(
184185
)
185186
elif lenient_issubclass(param.annotation, Request):
186187
dependant.request_param_name = param_name
188+
elif lenient_issubclass(param.annotation, WebSocket):
189+
dependant.websocket_param_name = param_name
187190
elif lenient_issubclass(param.annotation, BackgroundTasks):
188191
dependant.background_tasks_param_name = param_name
189192
elif lenient_issubclass(param.annotation, SecurityScopes):
@@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool:
279282

280283
async def solve_dependencies(
281284
*,
282-
request: Request,
285+
request: Union[Request, WebSocket],
283286
dependant: Dependant,
284287
body: Dict[str, Any] = None,
285288
background_tasks: BackgroundTasks = None,
@@ -326,8 +329,10 @@ async def solve_dependencies(
326329
)
327330
values.update(body_values)
328331
errors.extend(body_errors)
329-
if dependant.request_param_name:
332+
if dependant.request_param_name and isinstance(request, Request):
330333
values[dependant.request_param_name] = request
334+
elif dependant.websocket_param_name and isinstance(request, WebSocket):
335+
values[dependant.websocket_param_name] = request
331336
if dependant.background_tasks_param_name:
332337
if background_tasks is None:
333338
background_tasks = BackgroundTasks()

fastapi/routing.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import inspect
33
import logging
4+
import re
45
from typing import Any, Callable, Dict, List, Optional, Type, Union
56

67
from fastapi import params
@@ -21,8 +22,14 @@
2122
from starlette.exceptions import HTTPException
2223
from starlette.requests import Request
2324
from starlette.responses import JSONResponse, Response
24-
from starlette.routing import compile_path, get_name, request_response
25-
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
25+
from starlette.routing import (
26+
compile_path,
27+
get_name,
28+
request_response,
29+
websocket_session,
30+
)
31+
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
32+
from starlette.websockets import WebSocket
2633

2734

2835
def serialize_response(*, field: Field = None, response: Response) -> Any:
@@ -97,6 +104,35 @@ async def app(request: Request) -> Response:
97104
return app
98105

99106

107+
def get_websocket_app(dependant: Dependant) -> Callable:
108+
async def app(websocket: WebSocket) -> None:
109+
values, errors, _ = await solve_dependencies(
110+
request=websocket, dependant=dependant
111+
)
112+
if errors:
113+
await websocket.close(code=WS_1008_POLICY_VIOLATION)
114+
errors_out = ValidationError(errors)
115+
raise HTTPException(
116+
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
117+
)
118+
assert dependant.call is not None, "dependant.call must me a function"
119+
await dependant.call(**values)
120+
121+
return app
122+
123+
124+
class APIWebSocketRoute(routing.WebSocketRoute):
125+
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
126+
self.path = path
127+
self.endpoint = endpoint
128+
self.name = get_name(endpoint) if name is None else name
129+
self.dependant = get_dependant(path=path, call=self.endpoint)
130+
self.app = websocket_session(get_websocket_app(dependant=self.dependant))
131+
regex = "^" + path + "$"
132+
regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
133+
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
134+
135+
100136
class APIRoute(routing.Route):
101137
def __init__(
102138
self,
@@ -281,6 +317,19 @@ def decorator(func: Callable) -> Callable:
281317

282318
return decorator
283319

320+
def add_api_websocket_route(
321+
self, path: str, endpoint: Callable, name: str = None
322+
) -> None:
323+
route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
324+
self.routes.append(route)
325+
326+
def websocket(self, path: str, name: str = None) -> Callable:
327+
def decorator(func: Callable) -> Callable:
328+
self.add_api_websocket_route(path, func, name=name)
329+
return func
330+
331+
return decorator
332+
284333
def include_router(
285334
self,
286335
router: "APIRouter",
@@ -326,6 +375,10 @@ def include_router(
326375
include_in_schema=route.include_in_schema,
327376
name=route.name,
328377
)
378+
elif isinstance(route, APIWebSocketRoute):
379+
self.add_api_websocket_route(
380+
prefix + route.path, route.endpoint, name=route.name
381+
)
329382
elif isinstance(route, routing.WebSocketRoute):
330383
self.add_websocket_route(
331384
prefix + route.path, route.endpoint, name=route.name

tests/test_tutorial/test_websockets/__init__.py

Whitespace-only changes.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
from starlette.testclient import TestClient
3+
from starlette.websockets import WebSocketDisconnect
4+
from websockets.tutorial001 import app
5+
6+
client = TestClient(app)
7+
8+
9+
def test_main():
10+
response = client.get("/")
11+
assert response.status_code == 200
12+
assert b"<!DOCTYPE html>" in response.content
13+
14+
15+
def test_websocket():
16+
with pytest.raises(WebSocketDisconnect):
17+
with client.websocket_connect("/ws") as websocket:
18+
message = "Message one"
19+
websocket.send_text(message)
20+
data = websocket.receive_text()
21+
assert data == f"Message text was: {message}"
22+
message = "Message two"
23+
websocket.send_text(message)
24+
data = websocket.receive_text()
25+
assert data == f"Message text was: {message}"

0 commit comments

Comments
 (0)