Skip to content

Commit 830ca94

Browse files
committed
WIP mask
1 parent f9074e3 commit 830ca94

File tree

5 files changed

+55
-22
lines changed

5 files changed

+55
-22
lines changed

strawberry/extensions/mask_errors.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections.abc import Iterator
2-
from typing import Callable
2+
from typing import Any, Callable
33

44
from graphql.error import GraphQLError
5+
from graphql.execution import ExecutionResult
56

67
from strawberry.extensions.base_extension import SchemaExtension
78

@@ -33,18 +34,30 @@ def anonymise_error(self, error: GraphQLError) -> GraphQLError:
3334
original_error=None,
3435
)
3536

37+
# TODO: proper typing
38+
def _process_result(self, result: Any) -> None:
39+
if not result.errors:
40+
return
41+
42+
processed_errors: list[GraphQLError] = []
43+
44+
for error in result.errors:
45+
if self.should_mask_error(error):
46+
processed_errors.append(self.anonymise_error(error))
47+
else:
48+
processed_errors.append(error)
49+
50+
result.errors = processed_errors
51+
3652
def on_operation(self) -> Iterator[None]:
3753
yield
54+
3855
result = self.execution_context.result
39-
if result and result.errors:
40-
processed_errors: list[GraphQLError] = []
41-
for error in result.errors:
42-
if self.should_mask_error(error):
43-
processed_errors.append(self.anonymise_error(error))
44-
else:
45-
processed_errors.append(error)
46-
47-
result.errors = processed_errors
56+
57+
if isinstance(result, ExecutionResult):
58+
self._process_result(result)
59+
else:
60+
self._process_result(result.initial_result)
4861

4962

5063
__all__ = ["MaskErrors"]

strawberry/schema/_graphql_core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ExperimentalIncrementalExecutionResults as GraphQLIncrementalExecutionResults,
1212
)
1313
from graphql.execution import ( # type: ignore[attr-defined]
14+
InitialIncrementalExecutionResult,
1415
experimental_execute_incrementally,
1516
)
1617
from graphql.type.directives import ( # type: ignore[attr-defined]
@@ -23,6 +24,10 @@
2324
GraphQLStreamDirective,
2425
)
2526

27+
GraphQLExecutionResult = Union[
28+
GraphQLExecutionResult, InitialIncrementalExecutionResult
29+
]
30+
2631
except ImportError:
2732
GraphQLIncrementalExecutionResults = type(None)
2833

strawberry/types/execution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from typing_extensions import NotRequired
1919

2020
from graphql import ASTValidationRule
21-
from graphql import ExecutionResult as GraphQLExecutionResult
2221
from graphql.error.graphql_error import GraphQLError
2322
from graphql.language import DocumentNode, OperationDefinitionNode
2423

2524
from strawberry.schema import Schema
25+
from strawberry.schema._graphql_core import GraphQLExecutionResult
2626

2727
from .graphql import OperationType
2828

tests/http/incremental/test_defer.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@ async def test_defer_with_mask_error_extension(
7777
)
7878

7979
http_client = incremental_http_client_class(schema=schema)
80+
8081
response = await http_client.query(
8182
method=method,
8283
query="""
8384
query HeroNameQuery {
85+
someError
8486
character {
8587
id
8688
...NameFragment @defer
@@ -97,29 +99,38 @@ async def test_defer_with_mask_error_extension(
9799

98100
assert initial == snapshot(
99101
{
100-
"data": {"character": {"id": "1"}},
102+
"data": {"someError": None, "character": {"id": "1"}},
103+
"errors": [
104+
{
105+
"message": "Unexpected error.",
106+
"locations": [{"line": 3, "column": 13}],
107+
"path": ["someError"],
108+
}
109+
],
101110
"hasNext": True,
102-
"pending": [{"path": ["character"], "id": "0"}],
103-
# TODO: check if we need this and how to handle it
111+
"pending": [{"id": "0", "path": ["character"]}],
104112
"extensions": None,
105113
}
106114
)
107115

108116
subsequent = await stream.__anext__()
109117

118+
# TODO: not yet supported properly (the error is not masked)
110119
assert subsequent == snapshot(
111120
{
112-
"incremental": [
121+
"hasNext": False,
122+
"extensions": None,
123+
"completed": [
113124
{
114-
"data": {"name": "Thiago Bellini"},
115125
"id": "0",
116-
"path": ["character"],
117-
"label": None,
126+
"errors": [
127+
{
128+
"message": "Failed to get name",
129+
"locations": [{"line": 10, "column": 13}],
130+
"path": ["character", "name"],
131+
}
132+
],
118133
}
119134
],
120-
"completed": [{"id": "0"}],
121-
"hasNext": False,
122-
# TODO: same as above
123-
"extensions": None,
124135
}
125136
)

tests/views/schema.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ async def error(self, message: str) -> AsyncGenerator[str, None]:
110110
async def exception(self, message: str) -> str:
111111
raise ValueError(message)
112112

113+
@strawberry.field
114+
async def some_error(self) -> Optional[str]:
115+
raise ValueError("Some error")
116+
113117
@strawberry.field
114118
def teapot(self, info: strawberry.Info[Any, None]) -> str:
115119
info.context["response"].status_code = 418

0 commit comments

Comments
 (0)