From 73d0016629a4f658e721d558f1c1560a6356e775 Mon Sep 17 00:00:00 2001 From: jbjd Date: Wed, 17 Dec 2025 19:52:26 -0600 Subject: [PATCH] Better and more accurate unused import detection --- personal_python_ast_optimizer/parser/run.py | 2 +- .../parser/skipper.py | 72 +++++++++---------- personal_python_ast_optimizer/parser/utils.py | 6 ++ tests/parser/test_imports.py | 41 ++++++----- version.txt | 2 +- 5 files changed, 67 insertions(+), 56 deletions(-) diff --git a/personal_python_ast_optimizer/parser/run.py b/personal_python_ast_optimizer/parser/run.py index 047e62b..fe39cd1 100644 --- a/personal_python_ast_optimizer/parser/run.py +++ b/personal_python_ast_optimizer/parser/run.py @@ -11,6 +11,6 @@ def run_minify_parser( module: ast.Module = ast.parse(source) if skip_config is not None: - module = AstNodeSkipper(skip_config).visit(module) + AstNodeSkipper(skip_config).visit(module) return parser.visit(module) diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index d510f31..1b29b8b 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -14,6 +14,7 @@ ) from personal_python_ast_optimizer.parser.utils import ( exclude_imports, + filter_imports, first_occurrence_of_type, get_node_name, is_overload_function, @@ -25,10 +26,9 @@ ) -class AstNodeSkipper(ast.NodeVisitor): +class AstNodeSkipper(ast.NodeTransformer): __slots__ = ( - "_possibly_unused_imports", "_skippable_futures", "_within_class", "_within_function", @@ -62,8 +62,6 @@ def __init__(self, config: SkipConfig) -> None: if self.token_types_config.skip_type_hints: self._skippable_futures.append("annotations") - self._possibly_unused_imports: set[str] = set() - @staticmethod def _within_class_node(function): def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None: @@ -150,16 +148,17 @@ def visit_Module(self, node: ast.Module) -> ast.AST: if not self._has_code_to_skip(): return node - self._possibly_unused_imports = set() - if self.token_types_config.skip_dangling_expressions: skip_dangling_expressions(node) module: ast.Module = self.generic_visit(node) # type:ignore - if self._possibly_unused_imports: - import_skipper = ImportSkipper(self._possibly_unused_imports) - import_skipper.generic_visit(module) + if self.optimizations_config.remove_unused_imports: + names_and_attrs_finder = NamesAndAttersDetector() + names_and_attrs_finder.visit(module) + + import_filter = ImportFilter(names_and_attrs_finder.names_and_attrs) + import_filter.visit(module) self._warn_unused_skips() return module @@ -243,9 +242,6 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST | None: ): return self._get_enum_value_as_AST(node.value.attr, node.attr) - if node.attr in self._possibly_unused_imports: - self._possibly_unused_imports.remove(node.attr) - return self.generic_visit(node) def _get_enum_value_as_AST(self, class_name: str, value_name: str) -> ast.Constant: @@ -334,10 +330,9 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None: if self._within_class and get_node_name(node.target) == "__slots__": remove_duplicate_slots(node, self.warn_unusual_code) - if self.token_types_config.skip_type_hints: - node.annotation = None # type: ignore - parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore + parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore + if self.token_types_config.skip_type_hints: if ( not parsed_node.value and self._within_class @@ -349,8 +344,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None: return None else: return ast.Assign([parsed_node.target], parsed_node.value) - else: - return self.generic_visit(node) + + return parsed_node def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST | None: if get_node_name(node.target) in self.tokens_config.variables_to_skip: @@ -364,9 +359,6 @@ def visit_Import(self, node: ast.Import) -> ast.AST | None: if not node.names: return None - if self.optimizations_config.remove_unused_imports: - self._possibly_unused_imports.update(n.asname or n.name for n in node.names) - return self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None: @@ -378,12 +370,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None: if node.module == "__future__" and self._skippable_futures: exclude_imports(node, self._skippable_futures) - elif ( - node.module != "__future__" - and node.names - and self.optimizations_config.remove_unused_imports - ): - self._possibly_unused_imports.update(n.asname or n.name for n in node.names) return self.generic_visit(node) if node.names else None @@ -393,11 +379,8 @@ def visit_Name(self, node: ast.Name) -> ast.AST: constant_value = self.optimizations_config.vars_to_fold[node.id] return ast.Constant(constant_value) - else: - if node.id in self._possibly_unused_imports: - self._possibly_unused_imports.remove(node.id) - return self.generic_visit(node) + return node def visit_Dict(self, node: ast.Dict) -> ast.AST: if self.tokens_config.dict_keys_to_skip: @@ -661,12 +644,28 @@ def _ast_constants_operation( return ast.Constant(result) -class ImportSkipper(ast.NodeVisitor): +class NamesAndAttersDetector(ast.NodeVisitor): + + __slots__ = ("names_and_attrs",) + + def __init__(self) -> None: + self.names_and_attrs: set[str] = set() + + def visit_Name(self, node: ast.Name) -> ast.Name: + self.names_and_attrs.add(node.id) + return node + + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: + self.names_and_attrs.add(node.attr) + return self.generic_visit(node) + + +class ImportFilter(ast.NodeTransformer): - __slots__ = ("unused_imports",) + __slots__ = ("imports_to_keep",) - def __init__(self, unused_imports: set[str]) -> None: - self.unused_imports = unused_imports + def __init__(self, imports_to_keep: set[str]) -> None: + self.imports_to_keep: set[str] = imports_to_keep def generic_visit(self, node): for _, old_value in ast.iter_fields(node): @@ -693,11 +692,12 @@ def generic_visit(self, node): return node def visit_Import(self, node: ast.Import) -> ast.Import | None: - exclude_imports(node, self.unused_imports) + filter_imports(node, self.imports_to_keep) return node if node.names else None def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None: - exclude_imports(node, self.unused_imports) + if node.module != "__future__": + filter_imports(node, self.imports_to_keep) return node if node.names else None diff --git a/personal_python_ast_optimizer/parser/utils.py b/personal_python_ast_optimizer/parser/utils.py index 9b1a5b8..728b708 100644 --- a/personal_python_ast_optimizer/parser/utils.py +++ b/personal_python_ast_optimizer/parser/utils.py @@ -11,6 +11,12 @@ def exclude_imports(node: ast.Import | ast.ImportFrom, exlcudes: Iterable[str]) ] +def filter_imports(node: ast.Import | ast.ImportFrom, filter: Iterable[str]) -> None: + node.names = [ + alias for alias in node.names if (alias.asname or alias.name) in filter + ] + + def get_node_name(node: ast.AST | None) -> str: """Gets id or attr which both can represent var names""" if isinstance(node, ast.Call): diff --git a/tests/parser/test_imports.py b/tests/parser/test_imports.py index 26d1a0c..9e572dd 100644 --- a/tests/parser/test_imports.py +++ b/tests/parser/test_imports.py @@ -75,9 +75,8 @@ def test_import_star(): ) -def test_remove_unused_import(): - - before_and_after = BeforeAndAfter( +_unused_import_test_cases: list[BeforeAndAfter] = [ + BeforeAndAfter( """ if a == b: import foo @@ -85,13 +84,8 @@ def test_remove_unused_import(): print(a)""", "if a==b:pass\nprint(a)", - ) - run_minifier_and_assert_correct(before_and_after) - - -def test_remove_unused_import_type_annotation(): - - before_and_after = BeforeAndAfter( + ), + BeforeAndAfter( """ import foo @@ -103,13 +97,8 @@ def asdf(a: foo) -> foo: a=bar() def asdf(a):return a """.strip(), - ) - run_minifier_and_assert_correct(before_and_after) - - -def test_remove_unused_import_from_type_annotation(): - - before_and_after = BeforeAndAfter( + ), + BeforeAndAfter( """ from .typing import foo @@ -121,5 +110,21 @@ def asdf(a: foo) -> foo: a=bar() def asdf(a):return a """.strip(), - ) + ), + BeforeAndAfter( + """ +from foo import bar as bar2 + +if False: + bar2() + +bar() +""".strip(), + "bar()", + ), +] + + +@pytest.mark.parametrize("before_and_after", _unused_import_test_cases) +def test_remove_unused_import(before_and_after: BeforeAndAfter): run_minifier_and_assert_correct(before_and_after) diff --git a/version.txt b/version.txt index ac14c3d..61fcc87 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -5.1.1 +5.1.2