Skip to content

Commit 5bee749

Browse files
committed
Added support for an unpacked TypedDict as a type annotation for a *kwargs parameter.
1 parent 9827e56 commit 5bee749

File tree

8 files changed

+188
-12
lines changed

8 files changed

+188
-12
lines changed

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
12451245
associateTypeVarsWithScope: true,
12461246
allowTypeVarTuple: paramCategory === ParameterCategory.VarArgList,
12471247
disallowRecursiveTypeAlias: true,
1248+
allowUnpackedTypedDict: paramCategory === ParameterCategory.VarArgDictionary,
12481249
allowUnpackedTuple: paramCategory === ParameterCategory.VarArgList,
12491250
});
12501251
}
@@ -1299,6 +1300,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
12991300
evaluatorFlags |= EvaluatorFlags.DisallowRecursiveTypeAliasPlaceholder;
13001301
}
13011302

1303+
if (options?.allowUnpackedTypedDict) {
1304+
evaluatorFlags |= EvaluatorFlags.AllowUnpackedTypedDict;
1305+
}
1306+
13021307
if (options?.allowUnpackedTuple) {
13031308
evaluatorFlags |= EvaluatorFlags.AllowUnpackedTupleOrTypeVarTuple;
13041309
}
@@ -13115,6 +13120,20 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1311513120
return UnknownType.create();
1311613121
}
1311713122

13123+
if ((flags & EvaluatorFlags.AllowUnpackedTypedDict) !== 0) {
13124+
if (isInstantiableClass(typeArgType) && ClassType.isTypedDictClass(typeArgType)) {
13125+
return ClassType.cloneForUnpacked(typeArgType);
13126+
}
13127+
13128+
addDiagnostic(
13129+
fileInfo.diagnosticRuleSet.reportGeneralTypeIssues,
13130+
DiagnosticRule.reportGeneralTypeIssues,
13131+
Localizer.Diagnostic.unpackExpectedTypedDict(),
13132+
errorNode
13133+
);
13134+
return UnknownType.create();
13135+
}
13136+
1311813137
addDiagnostic(
1311913138
fileInfo.diagnosticRuleSet.reportGeneralTypeIssues,
1312013139
DiagnosticRule.reportGeneralTypeIssues,
@@ -15254,6 +15273,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
1525415273
return type;
1525515274
}
1525615275

15276+
// Is this an unpacked TypedDict? If so, return it unmodified.
15277+
if (isClassInstance(type) && ClassType.isTypedDictClass(type) && type.isUnpacked) {
15278+
return type;
15279+
}
15280+
1525715281
// Wrap the type in a dict with str keys.
1525815282
const dictType = getBuiltInType(node, 'dict');
1525915283
const strType = getBuiltInObject(node, 'str');
@@ -22154,7 +22178,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2215422178
if (srcParamInfo.param.name && srcParamInfo.param.category === ParameterCategory.Simple) {
2215522179
const destParamInfo = destParamMap.get(srcParamInfo.param.name);
2215622180
const paramDiag = diag?.createAddendum();
22157-
const srcParamType = FunctionType.getEffectiveParameterType(srcType, srcParamInfo.index);
22181+
const srcParamType = srcParamInfo.type;
2215822182

2215922183
if (!destParamInfo) {
2216022184
if (destParamDetails.kwargsIndex === undefined && !srcParamInfo.param.hasDefault) {
@@ -22184,10 +22208,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2218422208
}
2218522209
}
2218622210
} else {
22187-
const destParamType = FunctionType.getEffectiveParameterType(
22188-
destType,
22189-
destParamInfo.index
22190-
);
22211+
const destParamType = destParamInfo.type;
2219122212
const specializedDestParamType = destTypeVarMap
2219222213
? applySolvedTypeVars(destParamType, destTypeVarMap)
2219322214
: destParamType;

packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ export const enum EvaluatorFlags {
128128
// the interpreter (within a source file, not a stub) still
129129
// parses the expression and generates parse errors.
130130
InterpreterParsesStringLiteral = 1 << 22,
131+
132+
// Allow Unpack annotation for TypedDict.
133+
AllowUnpackedTypedDict = 1 << 23,
131134
}
132135

133136
export interface TypeArgumentResult {
@@ -257,6 +260,7 @@ export interface AnnotationTypeOptions {
257260
allowTypeVarTuple?: boolean;
258261
allowParamSpec?: boolean;
259262
disallowRecursiveTypeAlias?: boolean;
263+
allowUnpackedTypedDict?: boolean;
260264
allowUnpackedTuple?: boolean;
261265
notParsedByInterpreter?: boolean;
262266
}

packages/pyright-internal/src/analyzer/typePrinter.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,11 +694,17 @@ export function printFunctionParts(
694694
recursionTypes.length < maxTypeRecursionCount
695695
? printType(paramType, printTypeFlags, returnTypeCallback, recursionTypes)
696696
: '';
697+
697698
if (!param.isNameSynthesized) {
698699
paramString += ': ';
699700
} else if (param.category === ParameterCategory.VarArgList && !isUnpacked(paramType)) {
700701
paramString += '*';
701702
}
703+
704+
if (param.category === ParameterCategory.VarArgDictionary && isUnpacked(paramType)) {
705+
paramString += '**';
706+
}
707+
702708
paramString += paramTypeString;
703709

704710
if (isParamSpec(paramType)) {

packages/pyright-internal/src/analyzer/typeUtils.ts

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,37 @@ export function getParameterListDetails(type: FunctionType): ParameterListDetail
343343
}
344344
} else if (param.category === ParameterCategory.VarArgDictionary) {
345345
sawKeywordOnlySeparator = true;
346-
if (result.kwargsIndex === undefined) {
347-
result.kwargsIndex = result.params.length;
348-
}
349-
if (result.firstKeywordOnlyIndex === undefined) {
350-
result.firstKeywordOnlyIndex = result.params.length;
351-
}
352346

353-
addVirtualParameter(param, index);
347+
// Is this an unpacked TypedDict? If so, expand the entries.
348+
if (isClassInstance(param.type) && isUnpackedClass(param.type) && param.type.details.typedDictEntries) {
349+
if (result.firstKeywordOnlyIndex === undefined) {
350+
result.firstKeywordOnlyIndex = result.params.length;
351+
}
352+
353+
param.type.details.typedDictEntries.forEach((entry, name) => {
354+
addVirtualParameter(
355+
{
356+
category: ParameterCategory.Simple,
357+
name,
358+
type: entry.valueType,
359+
hasDeclaredType: true,
360+
hasDefault: !entry.isRequired,
361+
},
362+
index,
363+
entry.valueType
364+
);
365+
});
366+
} else {
367+
if (result.kwargsIndex === undefined) {
368+
result.kwargsIndex = result.params.length;
369+
}
370+
371+
if (result.firstKeywordOnlyIndex === undefined) {
372+
result.firstKeywordOnlyIndex = result.params.length;
373+
}
374+
375+
addVirtualParameter(param, index);
376+
}
354377
} else if (param.category === ParameterCategory.Simple) {
355378
if (param.name && !sawKeywordOnlySeparator) {
356379
result.positionParamCount++;

packages/pyright-internal/src/localization/localize.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,7 @@ export namespace Localizer {
892892
new ParameterizedString<{ name1: string; name2: string }>(
893893
getRawString('Diagnostic.unpackedTypeVarTupleExpected')
894894
);
895+
export const unpackExpectedTypedDict = () => getRawString('Diagnostic.unpackExpectedTypedDict');
895896
export const unpackExpectedTypeVarTuple = () => getRawString('Diagnostic.unpackExpectedTypeVarTuple');
896897
export const unpackIllegalInComprehension = () => getRawString('Diagnostic.unpackIllegalInComprehension');
897898
export const unpackInAnnotation = () => getRawString('Diagnostic.unpackInAnnotation');

packages/pyright-internal/src/localization/package.nls.en-us.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@
456456
"unpackedSubscriptIllegal": "Unpack operator in subscript requires Python 3.11 or newer",
457457
"unpackedTypedDictArgument": "Unable to match unpacked TypedDict argument to parameters",
458458
"unpackedTypeVarTupleExpected": "Expected unpacked TypeVarTuple; use Unpack[{name1}] or *{name2}",
459+
"unpackExpectedTypedDict": "Expected TypedDict type argument for Unpack",
459460
"unpackExpectedTypeVarTuple": "Expected TypeVarTuple or Tuple as type argument for Unpack",
460461
"unpackIllegalInComprehension": "Unpack operation not allowed in comprehension",
461462
"unpackInAnnotation": "Unpack operator not allowed in type annotation",
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# This sample tests the handling of Unpack[TypedDict] when used with
2+
# a **kwargs parameter in a function signature.
3+
4+
from typing import Protocol, TypedDict
5+
from typing_extensions import NotRequired, Required, Unpack
6+
7+
8+
class TD1(TypedDict):
9+
v1: Required[int]
10+
v2: NotRequired[str]
11+
12+
13+
class TD2(TD1):
14+
v3: Required[str]
15+
16+
17+
def func1(**kwargs: Unpack[TD2]) -> None:
18+
v1 = kwargs["v1"]
19+
reveal_type(v1, expected_text="int")
20+
21+
# This should generate an error because v2 might not be present.
22+
kwargs["v2"]
23+
24+
if "v2" in kwargs:
25+
v2 = kwargs["v2"]
26+
reveal_type(v2, expected_text="str")
27+
28+
v3 = kwargs["v3"]
29+
reveal_type(v3, expected_text="str")
30+
31+
32+
reveal_type(func1, expected_text="(**kwargs: **TD2) -> None")
33+
34+
35+
def func2(v1: int, **kwargs: Unpack[TD1]) -> None:
36+
pass
37+
38+
39+
def func3():
40+
# This should generate an error because it is
41+
# missing required keyword arguments.
42+
func1()
43+
44+
func1(v1=1, v2="", v3="5")
45+
46+
td2 = TD2(v1=2, v3="4")
47+
func1(**td2)
48+
49+
# This should generate an error because v4 is not in TD2.
50+
func1(v1=1, v2="", v3="5", v4=5)
51+
52+
# This should generate an error because args are passed by position.
53+
func1(1, "", "5")
54+
55+
my_dict: dict[str, str] = {}
56+
# This should generate an error because it's an untyped dict.
57+
func1(**my_dict)
58+
59+
func1(**{"v1": 2, "v3": "4", "v4": 4})
60+
61+
# This should generate an error because v1 is already specified.
62+
func1(v1=2, **td2)
63+
64+
# This should generate an error because v1 is already specified.
65+
func2(1, **td2)
66+
67+
# This should generate an error because v1 is matched to a
68+
# named parameter and is not available for kwargs.
69+
func2(v1=1, **td2)
70+
71+
72+
class TDProtocol1(Protocol):
73+
def __call__(self, *, v1: int, v3: str) -> None:
74+
...
75+
76+
77+
class TDProtocol2(Protocol):
78+
def __call__(self, *, v1: int, v3: str, v2: str = "") -> None:
79+
...
80+
81+
82+
class TDProtocol3(Protocol):
83+
def __call__(self, *, v1: int, v2: int, v3: str) -> None:
84+
...
85+
86+
87+
class TDProtocol4(Protocol):
88+
def __call__(self, *, v1: int) -> None:
89+
...
90+
91+
92+
class TDProtocol5(Protocol):
93+
def __call__(self, v1: int, v3: str) -> None:
94+
...
95+
96+
97+
class TDProtocol6(Protocol):
98+
def __call__(self, **kwargs: Unpack[TD2]) -> None:
99+
...
100+
101+
102+
v1: TDProtocol1 = func1
103+
v2: TDProtocol2 = func1
104+
105+
# This should generate an error because v2 is the wrong type.
106+
v3: TDProtocol3 = func1
107+
108+
# This should generate an error because v3 is missing.
109+
v4: TDProtocol4 = func1
110+
111+
# This should generate an error because parameters are positional.
112+
v5: TDProtocol5 = func1
113+
114+
v6: TDProtocol6 = func1

packages/pyright-internal/src/tests/typeEvaluator1.test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,12 @@ test('Function10', () => {
673673
TestUtils.validateResults(analysisResults, 0);
674674
});
675675

676+
test('KwargsUnpack1', () => {
677+
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['kwargsUnpack1.py']);
678+
679+
TestUtils.validateResults(analysisResults, 11);
680+
});
681+
676682
test('Unreachable1', () => {
677683
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['unreachable1.py']);
678684

0 commit comments

Comments
 (0)