@@ -453,6 +453,7 @@ def __init__(
453
453
self .analyzed = analyzed
454
454
# Short names of methods defined in the body of the current class
455
455
self .method_names : set [str ] = set ()
456
+ self .processing_enum = False
456
457
self .processing_dataclass = False
457
458
458
459
def visit_mypy_file (self , o : MypyFile ) -> None :
@@ -727,6 +728,8 @@ def visit_class_def(self, o: ClassDef) -> None:
727
728
if base_types :
728
729
for base in base_types :
729
730
self .import_tracker .require_name (base )
731
+ if self .analyzed and o .info .is_enum :
732
+ self .processing_enum = True
730
733
if isinstance (o .metaclass , (NameExpr , MemberExpr )):
731
734
meta = o .metaclass .accept (AliasPrinter (self ))
732
735
base_types .append ("metaclass=" + meta )
@@ -756,6 +759,7 @@ def visit_class_def(self, o: ClassDef) -> None:
756
759
self ._state = CLASS
757
760
self .method_names = set ()
758
761
self .processing_dataclass = False
762
+ self .processing_enum = False
759
763
self ._current_class = None
760
764
761
765
def get_base_types (self , cdef : ClassDef ) -> list [str ]:
@@ -1153,6 +1157,9 @@ def get_init(
1153
1157
# Final without type argument is invalid in stubs.
1154
1158
final_arg = self .get_str_type_of_node (rvalue )
1155
1159
typename += f"[{ final_arg } ]"
1160
+ elif self .processing_enum :
1161
+ initializer , _ = self .get_str_default_of_node (rvalue )
1162
+ return f"{ self ._indent } { lvalue } = { initializer } \n "
1156
1163
elif self .processing_dataclass :
1157
1164
# attribute without annotation is not a dataclass field, don't add annotation.
1158
1165
return f"{ self ._indent } { lvalue } = ...\n "
0 commit comments