Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion personal_python_ast_optimizer/parser/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ def run_minify_parser(
module: ast.Module = ast.parse(source)

if skip_config is not None:
module = AstNodeSkipper(skip_config).visit(module)
AstNodeSkipper(skip_config).visit(module)

return parser.visit(module)
72 changes: 36 additions & 36 deletions personal_python_ast_optimizer/parser/skipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from personal_python_ast_optimizer.parser.utils import (
exclude_imports,
filter_imports,
first_occurrence_of_type,
get_node_name,
is_overload_function,
Expand All @@ -25,10 +26,9 @@
)


class AstNodeSkipper(ast.NodeVisitor):
class AstNodeSkipper(ast.NodeTransformer):

__slots__ = (
"_possibly_unused_imports",
"_skippable_futures",
"_within_class",
"_within_function",
Expand Down Expand Up @@ -62,8 +62,6 @@ def __init__(self, config: SkipConfig) -> None:
if self.token_types_config.skip_type_hints:
self._skippable_futures.append("annotations")

self._possibly_unused_imports: set[str] = set()

@staticmethod
def _within_class_node(function):
def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None:
Expand Down Expand Up @@ -150,16 +148,17 @@ def visit_Module(self, node: ast.Module) -> ast.AST:
if not self._has_code_to_skip():
return node

self._possibly_unused_imports = set()

if self.token_types_config.skip_dangling_expressions:
skip_dangling_expressions(node)

module: ast.Module = self.generic_visit(node) # type:ignore

if self._possibly_unused_imports:
import_skipper = ImportSkipper(self._possibly_unused_imports)
import_skipper.generic_visit(module)
if self.optimizations_config.remove_unused_imports:
names_and_attrs_finder = NamesAndAttersDetector()
names_and_attrs_finder.visit(module)

import_filter = ImportFilter(names_and_attrs_finder.names_and_attrs)
import_filter.visit(module)

self._warn_unused_skips()
return module
Expand Down Expand Up @@ -243,9 +242,6 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST | None:
):
return self._get_enum_value_as_AST(node.value.attr, node.attr)

if node.attr in self._possibly_unused_imports:
self._possibly_unused_imports.remove(node.attr)

return self.generic_visit(node)

def _get_enum_value_as_AST(self, class_name: str, value_name: str) -> ast.Constant:
Expand Down Expand Up @@ -334,10 +330,9 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None:
if self._within_class and get_node_name(node.target) == "__slots__":
remove_duplicate_slots(node, self.warn_unusual_code)

if self.token_types_config.skip_type_hints:
node.annotation = None # type: ignore
parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore
parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore

if self.token_types_config.skip_type_hints:
if (
not parsed_node.value
and self._within_class
Expand All @@ -349,8 +344,8 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None:
return None
else:
return ast.Assign([parsed_node.target], parsed_node.value)
else:
return self.generic_visit(node)

return parsed_node

def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST | None:
if get_node_name(node.target) in self.tokens_config.variables_to_skip:
Expand All @@ -364,9 +359,6 @@ def visit_Import(self, node: ast.Import) -> ast.AST | None:
if not node.names:
return None

if self.optimizations_config.remove_unused_imports:
self._possibly_unused_imports.update(n.asname or n.name for n in node.names)

return self.generic_visit(node)

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
Expand All @@ -378,12 +370,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:

if node.module == "__future__" and self._skippable_futures:
exclude_imports(node, self._skippable_futures)
elif (
node.module != "__future__"
and node.names
and self.optimizations_config.remove_unused_imports
):
self._possibly_unused_imports.update(n.asname or n.name for n in node.names)

return self.generic_visit(node) if node.names else None

Expand All @@ -393,11 +379,8 @@ def visit_Name(self, node: ast.Name) -> ast.AST:
constant_value = self.optimizations_config.vars_to_fold[node.id]

return ast.Constant(constant_value)
else:
if node.id in self._possibly_unused_imports:
self._possibly_unused_imports.remove(node.id)

return self.generic_visit(node)
return node

def visit_Dict(self, node: ast.Dict) -> ast.AST:
if self.tokens_config.dict_keys_to_skip:
Expand Down Expand Up @@ -661,12 +644,28 @@ def _ast_constants_operation(
return ast.Constant(result)


class ImportSkipper(ast.NodeVisitor):
class NamesAndAttersDetector(ast.NodeVisitor):

__slots__ = ("names_and_attrs",)

def __init__(self) -> None:
self.names_and_attrs: set[str] = set()

def visit_Name(self, node: ast.Name) -> ast.Name:
self.names_and_attrs.add(node.id)
return node

def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
self.names_and_attrs.add(node.attr)
return self.generic_visit(node)


class ImportFilter(ast.NodeTransformer):

__slots__ = ("unused_imports",)
__slots__ = ("imports_to_keep",)

def __init__(self, unused_imports: set[str]) -> None:
self.unused_imports = unused_imports
def __init__(self, imports_to_keep: set[str]) -> None:
self.imports_to_keep: set[str] = imports_to_keep

def generic_visit(self, node):
for _, old_value in ast.iter_fields(node):
Expand All @@ -693,11 +692,12 @@ def generic_visit(self, node):
return node

def visit_Import(self, node: ast.Import) -> ast.Import | None:
exclude_imports(node, self.unused_imports)
filter_imports(node, self.imports_to_keep)

return node if node.names else None

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None:
exclude_imports(node, self.unused_imports)
if node.module != "__future__":
filter_imports(node, self.imports_to_keep)

return node if node.names else None
6 changes: 6 additions & 0 deletions personal_python_ast_optimizer/parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ def exclude_imports(node: ast.Import | ast.ImportFrom, exlcudes: Iterable[str])
]


def filter_imports(node: ast.Import | ast.ImportFrom, filter: Iterable[str]) -> None:
node.names = [
alias for alias in node.names if (alias.asname or alias.name) in filter
]


def get_node_name(node: ast.AST | None) -> str:
"""Gets id or attr which both can represent var names"""
if isinstance(node, ast.Call):
Expand Down
41 changes: 23 additions & 18 deletions tests/parser/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,17 @@ def test_import_star():
)


def test_remove_unused_import():

before_and_after = BeforeAndAfter(
_unused_import_test_cases: list[BeforeAndAfter] = [
BeforeAndAfter(
"""
if a == b:
import foo
import bar

print(a)""",
"if a==b:pass\nprint(a)",
)
run_minifier_and_assert_correct(before_and_after)


def test_remove_unused_import_type_annotation():

before_and_after = BeforeAndAfter(
),
BeforeAndAfter(
"""
import foo

Expand All @@ -103,13 +97,8 @@ def asdf(a: foo) -> foo:
a=bar()
def asdf(a):return a
""".strip(),
)
run_minifier_and_assert_correct(before_and_after)


def test_remove_unused_import_from_type_annotation():

before_and_after = BeforeAndAfter(
),
BeforeAndAfter(
"""
from .typing import foo

Expand All @@ -121,5 +110,21 @@ def asdf(a: foo) -> foo:
a=bar()
def asdf(a):return a
""".strip(),
)
),
BeforeAndAfter(
"""
from foo import bar as bar2

if False:
bar2()

bar()
""".strip(),
"bar()",
),
]


@pytest.mark.parametrize("before_and_after", _unused_import_test_cases)
def test_remove_unused_import(before_and_after: BeforeAndAfter):
run_minifier_and_assert_correct(before_and_after)
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
5.1.1
5.1.2