Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions personal_python_ast_optimizer/parser/skipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
class AstNodeSkipper(ast.NodeTransformer):

__slots__ = (
"_has_imports",
"_skippable_futures",
"_within_class",
"_within_function",
Expand All @@ -50,6 +51,7 @@ def __init__(self, config: SkipConfig) -> None:
self.token_types_config: TokenTypesConfig = config.token_types_config
self.tokens_config: TokensConfig = config.tokens_config

self._has_imports: bool = False
self._within_class: bool = False
self._within_function: bool = False

Expand Down Expand Up @@ -84,7 +86,7 @@ def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None:

return wrapper

def generic_visit(self, node) -> ast.AST:
def generic_visit(self, node: ast.AST) -> ast.AST:
"""Modified version of super class's generic_visit
to extend functionality"""
for field, old_value in ast.iter_fields(node):
Expand Down Expand Up @@ -153,11 +155,8 @@ def visit_Module(self, node: ast.Module) -> ast.AST:

module: ast.Module = self.generic_visit(node) # type:ignore

if self.optimizations_config.remove_unused_imports:
names_and_attrs_finder = NamesAndAttersDetector()
names_and_attrs_finder.visit(module)

import_filter = ImportFilter(names_and_attrs_finder.names_and_attrs)
if self.optimizations_config.remove_unused_imports and self._has_imports:
import_filter = UnusedImportSkipper()
import_filter.visit(module)

self._warn_unused_skips()
Expand Down Expand Up @@ -352,15 +351,16 @@ def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST | None:

return self.generic_visit(node)

def visit_Import(self, node: ast.Import) -> ast.AST | None:
def visit_Import(self, node: ast.Import) -> ast.Import | None:
exclude_imports(node, self.tokens_config.module_imports_to_skip)

if not node.names:
return None

return self.generic_visit(node)
self._has_imports = True
return node

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None:
normalized_module_name: str = node.module or ""
if normalized_module_name in self.tokens_config.module_imports_to_skip:
return None
Expand All @@ -370,7 +370,11 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
if node.module == "__future__" and self._skippable_futures:
exclude_imports(node, self._skippable_futures)

return self.generic_visit(node) if node.names else None
if not node.names:
return None

self._has_imports = True
return node

def visit_Name(self, node: ast.Name) -> ast.AST:
"""Extends super's implementation by adding constant folding"""
Expand Down Expand Up @@ -643,35 +647,19 @@ def _ast_constants_operation(
return ast.Constant(result)


class NamesAndAttersDetector(ast.NodeVisitor):
class UnusedImportSkipper(ast.NodeTransformer):

__slots__ = ("names_and_attrs",)

def __init__(self) -> None:
self.names_and_attrs: set[str] = set()

def visit_Name(self, node: ast.Name) -> ast.Name:
self.names_and_attrs.add(node.id)
return node

def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
self.names_and_attrs.add(node.attr)
return self.generic_visit(node)


class ImportFilter(ast.NodeTransformer):

__slots__ = ("imports_to_keep",)

def __init__(self, imports_to_keep: set[str]) -> None:
self.imports_to_keep: set[str] = imports_to_keep

def generic_visit(self, node):
for _, old_value in ast.iter_fields(node):
def generic_visit(self, node: ast.AST) -> ast.AST:
for field, old_value in ast.iter_fields(node):
if isinstance(old_value, list):
new_values = []
ast_removed: bool = False
for value in old_value:
for value in reversed(old_value):
if isinstance(value, ast.AST):
value = self.visit(value)
if value is None:
Expand All @@ -686,17 +674,32 @@ def generic_visit(self, node):
if not isinstance(node, ast.Module) and not new_values and ast_removed:
new_values = [ast.Pass()]

old_value[:] = new_values
old_value[:] = reversed(new_values)

elif isinstance(old_value, ast.AST):
new_node = self.visit(old_value)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)

return node

def visit_Import(self, node: ast.Import) -> ast.Import | None:
filter_imports(node, self.imports_to_keep)
filter_imports(node, self.names_and_attrs)

return node if node.names else None

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None:
if node.module != "__future__":
filter_imports(node, self.imports_to_keep)
filter_imports(node, self.names_and_attrs)

return node if node.names else None

def visit_Name(self, node: ast.Name) -> ast.Name:
self.names_and_attrs.add(node.id)
return node

def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
self.names_and_attrs.add(node.attr)
return self.generic_visit(node)
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
5.1.2
5.2.0