Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 91 additions & 18 deletions personal_python_ast_optimizer/parser/minifier.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import ast
from ast import _Unparser # type: ignore
from typing import Iterable, Iterator, Literal

from personal_python_ast_optimizer.parser.utils import node_inlineable
from personal_python_ast_optimizer.python_info import (
chars_that_dont_need_whitespace,
comparison_and_conjunctions,
operators_and_separators,
)


class MinifyUnparser(_Unparser):
class MinifyUnparser(ast._Unparser): # type: ignore

__slots__ = ("can_write_body_in_one_line", "previous_node_in_body")

Expand All @@ -37,7 +35,7 @@ def write(self, *text: str) -> None:
"""Write text, with some mapping replacements"""
text = tuple(self._yield_updated_text(text))

if len(text) == 0:
if not text:
return

first_letter_to_write: str = text[0][:1]
Expand All @@ -60,10 +58,6 @@ def _yield_updated_text(self, text_iter: Iterable[str]) -> Iterator[str]:
elif text:
yield text

def maybe_newline(self) -> None:
if self._source and self._source[-1] != "\n":
self.write("\n")

def visit_node(
self,
node: ast.AST,
Expand All @@ -82,7 +76,8 @@ def traverse(self, node: list[ast.stmt] | ast.AST) -> None:
if isinstance(node, list):
last_visited_node: ast.stmt | None = None
can_write_body_in_one_line = (
all(node_inlineable(sub_node) for sub_node in node) or len(node) == 1
all(self._node_inlineable(sub_node) for sub_node in node)
or len(node) == 1
)

for sub_node in node:
Expand All @@ -104,9 +99,17 @@ def visit_Assert(self, node: ast.Assert) -> None:
self.fill("assert ", splitter=self._get_line_splitter())
self.traverse(node.test)
if node.msg:
self.write(", ")
self.write(",")
self.traverse(node.msg)

def visit_Global(self, node: ast.Global) -> None:
self.fill("global ", splitter=self._get_line_splitter())
self.interleave(lambda: self.write(","), self.write, node.names)

def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
self.fill("nonlocal ", splitter=self._get_line_splitter())
self.interleave(lambda: self.write(","), self.write, node.names)

def visit_Delete(self, node: ast.Delete) -> None:
self.fill("del ", splitter=self._get_line_splitter())
self._write_comma_delimitated_body(node.targets)
Expand Down Expand Up @@ -155,10 +158,10 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"(", ")", not node.simple and isinstance(node.target, ast.Name)
):
self.traverse(node.target)
self.write(": ")
self.write(":")
self.traverse(node.annotation)
if node.value:
self.write(" = ")
self.write("=")
self.traverse(node.value)

def visit_Assign(self, node: ast.Assign) -> None:
Expand All @@ -175,14 +178,64 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None:
self.write(self.binop[node.op.__class__.__name__] + "=")
self.traverse(node.value)

def visit_ClassDef(self, node: ast.ClassDef) -> None:
self._write_decorators(node)
self.fill("class " + node.name)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit_if("(", ")", condition=node.bases or node.keywords):
comma = False
for base in node.bases:
if comma:
self.write(",")
else:
comma = True
self.traverse(base)
for kw in node.keywords:
if comma:
self.write(",")
else:
comma = True
self.traverse(kw)

with self.block():
self._write_docstring_and_traverse_body(node)

def _function_helper(
self,
node: ast.FunctionDef | ast.AsyncFunctionDef,
fill_suffix: Literal["def", "async def"],
) -> None:
self._write_decorators(node)
def_str = fill_suffix + " " + node.name
self.fill(def_str)
if hasattr(node, "type_params"):
self._type_params_helper(node.type_params)
with self.delimit("(", ")"):
self.traverse(node.args)
if node.returns:
self.write("->")
self.traverse(node.returns)
with self.block(extra=self.get_type_comment(node)):
self._write_docstring_and_traverse_body(node)

def _write_decorators(
self, node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef
) -> None:
for deco in node.decorator_list:
self.fill("@")
self.traverse(deco)

def _last_char_is(self, char_to_check: str) -> bool:
return len(self._source) > 0 and self._source[-1][-1:] == char_to_check
return bool(self._source) and self._source[-1][-1:] == char_to_check

def _get_space_before_write(self) -> str:
if not self._source:
return ""
most_recent_token: str = self._source[-1]
return "" if most_recent_token[-1:] in chars_that_dont_need_whitespace else " "
return (
""
if not self._source
or self._source[-1][-1:] in chars_that_dont_need_whitespace
else " "
)

def _get_line_splitter(self) -> Literal["", "\n", ";"]:
"""Get character that starts the next line of code with the shortest
Expand All @@ -197,7 +250,7 @@ def _get_line_splitter(self) -> Literal["", "\n", ";"]:
if (
self._indent > 0
and self.previous_node_in_body is not None
and node_inlineable(self.previous_node_in_body)
and self._node_inlineable(self.previous_node_in_body)
):
return ";"

Expand All @@ -208,3 +261,23 @@ def _write_comma_delimitated_body(
) -> None:
"""Writes ast expr objects with comma delimitation"""
self.interleave(lambda: self.write(","), self.traverse, body)

@staticmethod
def _node_inlineable(node: ast.AST) -> bool:
return node.__class__.__name__ in [
"Assert",
"AnnAssign",
"Assign",
"AugAssign",
"Break",
"Continue",
"Delete",
"Expr",
"Global",
"Import",
"ImportFrom",
"Nonlocal",
"Pass",
"Raise",
"Return",
]
57 changes: 32 additions & 25 deletions personal_python_ast_optimizer/parser/skipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,21 @@ def __init__(self, config: SkipConfig) -> None:

@staticmethod
def _within_class_node(function):
def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None:
def wrapper(self: "AstNodeSkipper", *args) -> ast.AST | None:
self._within_class = True
try:
return function(self, *args, **kwargs)
return function(self, *args)
finally:
self._within_class = False

return wrapper

@staticmethod
def _within_function_node(function):
def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None:
def wrapper(self: "AstNodeSkipper", *args) -> ast.AST | None:
self._within_function = True
try:
return function(self, *args, **kwargs)
return function(self, *args)
finally:
self._within_function = False

Expand Down Expand Up @@ -108,7 +108,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
and not new_values
and field == "body"
):
new_values = [ast.Pass()]
new_values.append(ast.Pass())

old_value[:] = new_values

Expand Down Expand Up @@ -153,14 +153,14 @@ def visit_Module(self, node: ast.Module) -> ast.AST:
if self.token_types_config.skip_dangling_expressions:
skip_dangling_expressions(node)

module: ast.Module = self.generic_visit(node) # type:ignore
self.generic_visit(node)

if self.optimizations_config.remove_unused_imports and self._has_imports:
import_filter = UnusedImportSkipper()
import_filter.visit(module)
import_filter.visit(node)

self._warn_unused_skips()
return module
return node

@_within_class_node
def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None:
Expand All @@ -181,23 +181,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None:

return self.generic_visit(node)

@_within_function_node
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST | None:
return self._handle_function_node(node)

@_within_function_node
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST | None:
return self._handle_function_node(node)

def _should_skip_function(
self, node: ast.FunctionDef | ast.AsyncFunctionDef
) -> bool:
"""If a function node should be skipped"""
return node.name in self.tokens_config.functions_to_skip or (
self.token_types_config.skip_overload_functions
and is_overload_function(node)
)

@_within_function_node
def _handle_function_node(
self, node: ast.FunctionDef | ast.AsyncFunctionDef
) -> ast.AST | None:
Expand All @@ -218,10 +208,19 @@ def _handle_function_node(
if isinstance(last_body_node, ast.Return) and (
is_return_none(last_body_node) or last_body_node.value is None
):
node.body = node.body[:-1]
node.body.pop()

return self.generic_visit(node)

def _should_skip_function(
self, node: ast.FunctionDef | ast.AsyncFunctionDef
) -> bool:
"""Determines if a function node should be skipped."""
return node.name in self.tokens_config.functions_to_skip or (
self.token_types_config.skip_overload_functions
and is_overload_function(node)
)

def visit_Attribute(self, node: ast.Attribute) -> ast.AST | None:
if isinstance(node.value, ast.Name):
if node.attr in self.optimizations_config.enums_to_fold.get(
Expand Down Expand Up @@ -376,6 +375,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None:
self._has_imports = True
return node

def visit_alias(self, node: ast.alias) -> ast.alias:
return node

def visit_Name(self, node: ast.Name) -> ast.AST:
"""Extends super's implementation by adding constant folding"""
if node.id in self.optimizations_config.vars_to_fold:
Expand Down Expand Up @@ -423,6 +425,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.AST | None:
def visit_Return(self, node: ast.Return) -> ast.AST:
if is_return_none(node):
node.value = None
return node

return self.generic_visit(node)

Expand All @@ -431,6 +434,12 @@ def visit_Pass(self, node: ast.Pass) -> None:
are populated with a Pass node."""
return None # This could be toggleable

def visit_Break(self, node: ast.Break) -> ast.Break:
return node

def visit_Continue(self, node: ast.Continue) -> ast.Continue:
return node

def visit_Call(self, node: ast.Call) -> ast.AST | None:
if (
self.optimizations_config.assume_this_machine
Expand Down Expand Up @@ -503,6 +512,8 @@ def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
def visit_arg(self, node: ast.arg) -> ast.AST:
if self.token_types_config.skip_type_hints:
node.annotation = None
return node

return self.generic_visit(node)

def visit_arguments(self, node: ast.arguments) -> ast.AST:
Expand Down Expand Up @@ -565,7 +576,6 @@ def _use_version_optimization(self, min_version: tuple[int, int]) -> bool:
def _has_code_to_skip(self) -> bool:
return (
self.target_python_version is not None
or len(self.optimizations_config.vars_to_fold) > 0
or self.optimizations_config.has_code_to_skip()
or self.tokens_config.has_code_to_skip()
or self.token_types_config.has_code_to_skip()
Expand Down Expand Up @@ -665,14 +675,11 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
if value is None:
ast_removed = True
continue
elif not isinstance(value, ast.AST):
new_values.extend(value)
continue

new_values.append(value)

if not isinstance(node, ast.Module) and not new_values and ast_removed:
new_values = [ast.Pass()]
new_values.append(ast.Pass())

old_value[:] = reversed(new_values)

Expand Down
18 changes: 0 additions & 18 deletions personal_python_ast_optimizer/parser/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,6 @@ def is_return_none(node: ast.Return) -> bool:
return isinstance(node.value, ast.Constant) and node.value.value is None


def node_inlineable(node: ast.AST) -> bool:
return node.__class__.__name__ in [
"Assert",
"AnnAssign",
"Assign",
"AugAssign",
"Break",
"Continue",
"Delete",
"Expr",
"Import",
"ImportFrom",
"Pass",
"Raise",
"Return",
]


def skip_dangling_expressions(
node: ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
) -> None:
Expand Down
Loading