Skip to content

Commit 42157ba

Browse files
authored
stubgen: preserve enum value initialisers (#17125)
See python/typing-council#11
1 parent 2892ed4 commit 42157ba

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mypy/stubgen.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ def __init__(
453453
self.analyzed = analyzed
454454
# Short names of methods defined in the body of the current class
455455
self.method_names: set[str] = set()
456+
self.processing_enum = False
456457
self.processing_dataclass = False
457458

458459
def visit_mypy_file(self, o: MypyFile) -> None:
@@ -727,6 +728,8 @@ def visit_class_def(self, o: ClassDef) -> None:
727728
if base_types:
728729
for base in base_types:
729730
self.import_tracker.require_name(base)
731+
if self.analyzed and o.info.is_enum:
732+
self.processing_enum = True
730733
if isinstance(o.metaclass, (NameExpr, MemberExpr)):
731734
meta = o.metaclass.accept(AliasPrinter(self))
732735
base_types.append("metaclass=" + meta)
@@ -756,6 +759,7 @@ def visit_class_def(self, o: ClassDef) -> None:
756759
self._state = CLASS
757760
self.method_names = set()
758761
self.processing_dataclass = False
762+
self.processing_enum = False
759763
self._current_class = None
760764

761765
def get_base_types(self, cdef: ClassDef) -> list[str]:
@@ -1153,6 +1157,9 @@ def get_init(
11531157
# Final without type argument is invalid in stubs.
11541158
final_arg = self.get_str_type_of_node(rvalue)
11551159
typename += f"[{final_arg}]"
1160+
elif self.processing_enum:
1161+
initializer, _ = self.get_str_default_of_node(rvalue)
1162+
return f"{self._indent}{lvalue} = {initializer}\n"
11561163
elif self.processing_dataclass:
11571164
# attribute without annotation is not a dataclass field, don't add annotation.
11581165
return f"{self._indent}{lvalue} = ...\n"

test-data/unit/stubgen.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4342,3 +4342,27 @@ alias = tuple[()]
43424342
def f(x: tuple[()]): ...
43434343

43444344
class C(tuple[()]): ...
4345+
4346+
[case testPreserveEnumValue_semanal]
4347+
from enum import Enum
4348+
4349+
class Foo(Enum):
4350+
A = 1
4351+
B = 2
4352+
C = 3
4353+
4354+
class Bar(Enum):
4355+
A = object()
4356+
B = "a" + "b"
4357+
4358+
[out]
4359+
from enum import Enum
4360+
4361+
class Foo(Enum):
4362+
A = 1
4363+
B = 2
4364+
C = 3
4365+
4366+
class Bar(Enum):
4367+
A = ...
4368+
B = ...

0 commit comments

Comments
 (0)