diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index bc641ea..d510f31 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -334,21 +334,23 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None: if self._within_class and get_node_name(node.target) == "__slots__": remove_duplicate_slots(node, self.warn_unusual_code) - parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore - if self.token_types_config.skip_type_hints: + node.annotation = None # type: ignore + parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore + if ( not parsed_node.value and self._within_class and not self._within_function ): parsed_node.annotation = ast.Name("int") + return parsed_node elif parsed_node.value is None: return None else: return ast.Assign([parsed_node.target], parsed_node.value) - - return parsed_node + else: + return self.generic_visit(node) def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST | None: if get_node_name(node.target) in self.tokens_config.variables_to_skip: diff --git a/tests/parser/test_imports.py b/tests/parser/test_imports.py index fa2690b..26d1a0c 100644 --- a/tests/parser/test_imports.py +++ b/tests/parser/test_imports.py @@ -75,7 +75,7 @@ def test_import_star(): ) -def test_remove_unused_imports(): +def test_remove_unused_import(): before_and_after = BeforeAndAfter( """ @@ -87,3 +87,39 @@ def test_remove_unused_imports(): "if a==b:pass\nprint(a)", ) run_minifier_and_assert_correct(before_and_after) + + +def test_remove_unused_import_type_annotation(): + + before_and_after = BeforeAndAfter( + """ +import foo + +a: foo = bar() + +def asdf(a: foo) -> foo: + return a""", + """ +a=bar() +def asdf(a):return a +""".strip(), + ) + run_minifier_and_assert_correct(before_and_after) + + +def test_remove_unused_import_from_type_annotation(): + + before_and_after = BeforeAndAfter( + """ +from .typing import foo + +a: foo | None = bar() + +def asdf(a: foo) -> foo: + return a""", + """ +a=bar() +def asdf(a):return a +""".strip(), + ) + run_minifier_and_assert_correct(before_and_after) diff --git a/version.txt b/version.txt index 831446c..ac14c3d 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -5.1.0 +5.1.1