diff --git a/personal_python_ast_optimizer/parser/minifier.py b/personal_python_ast_optimizer/parser/minifier.py index 119a55f..ee430b0 100644 --- a/personal_python_ast_optimizer/parser/minifier.py +++ b/personal_python_ast_optimizer/parser/minifier.py @@ -1,8 +1,6 @@ import ast -from ast import _Unparser # type: ignore from typing import Iterable, Iterator, Literal -from personal_python_ast_optimizer.parser.utils import node_inlineable from personal_python_ast_optimizer.python_info import ( chars_that_dont_need_whitespace, comparison_and_conjunctions, @@ -10,7 +8,7 @@ ) -class MinifyUnparser(_Unparser): +class MinifyUnparser(ast._Unparser): # type: ignore __slots__ = ("can_write_body_in_one_line", "previous_node_in_body") @@ -37,7 +35,7 @@ def write(self, *text: str) -> None: """Write text, with some mapping replacements""" text = tuple(self._yield_updated_text(text)) - if len(text) == 0: + if not text: return first_letter_to_write: str = text[0][:1] @@ -60,10 +58,6 @@ def _yield_updated_text(self, text_iter: Iterable[str]) -> Iterator[str]: elif text: yield text - def maybe_newline(self) -> None: - if self._source and self._source[-1] != "\n": - self.write("\n") - def visit_node( self, node: ast.AST, @@ -82,7 +76,8 @@ def traverse(self, node: list[ast.stmt] | ast.AST) -> None: if isinstance(node, list): last_visited_node: ast.stmt | None = None can_write_body_in_one_line = ( - all(node_inlineable(sub_node) for sub_node in node) or len(node) == 1 + all(self._node_inlineable(sub_node) for sub_node in node) + or len(node) == 1 ) for sub_node in node: @@ -104,9 +99,17 @@ def visit_Assert(self, node: ast.Assert) -> None: self.fill("assert ", splitter=self._get_line_splitter()) self.traverse(node.test) if node.msg: - self.write(", ") + self.write(",") self.traverse(node.msg) + def visit_Global(self, node: ast.Global) -> None: + self.fill("global ", splitter=self._get_line_splitter()) + self.interleave(lambda: self.write(","), self.write, node.names) + + def visit_Nonlocal(self, node: ast.Nonlocal) -> None: + self.fill("nonlocal ", splitter=self._get_line_splitter()) + self.interleave(lambda: self.write(","), self.write, node.names) + def visit_Delete(self, node: ast.Delete) -> None: self.fill("del ", splitter=self._get_line_splitter()) self._write_comma_delimitated_body(node.targets) @@ -155,10 +158,10 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: "(", ")", not node.simple and isinstance(node.target, ast.Name) ): self.traverse(node.target) - self.write(": ") + self.write(":") self.traverse(node.annotation) if node.value: - self.write(" = ") + self.write("=") self.traverse(node.value) def visit_Assign(self, node: ast.Assign) -> None: @@ -175,14 +178,64 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None: self.write(self.binop[node.op.__class__.__name__] + "=") self.traverse(node.value) + def visit_ClassDef(self, node: ast.ClassDef) -> None: + self._write_decorators(node) + self.fill("class " + node.name) + if hasattr(node, "type_params"): + self._type_params_helper(node.type_params) + with self.delimit_if("(", ")", condition=node.bases or node.keywords): + comma = False + for base in node.bases: + if comma: + self.write(",") + else: + comma = True + self.traverse(base) + for kw in node.keywords: + if comma: + self.write(",") + else: + comma = True + self.traverse(kw) + + with self.block(): + self._write_docstring_and_traverse_body(node) + + def _function_helper( + self, + node: ast.FunctionDef | ast.AsyncFunctionDef, + fill_suffix: Literal["def", "async def"], + ) -> None: + self._write_decorators(node) + def_str = fill_suffix + " " + node.name + self.fill(def_str) + if hasattr(node, "type_params"): + self._type_params_helper(node.type_params) + with self.delimit("(", ")"): + self.traverse(node.args) + if node.returns: + self.write("->") + self.traverse(node.returns) + with self.block(extra=self.get_type_comment(node)): + self._write_docstring_and_traverse_body(node) + + def _write_decorators( + self, node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef + ) -> None: + for deco in node.decorator_list: + self.fill("@") + self.traverse(deco) + def _last_char_is(self, char_to_check: str) -> bool: - return len(self._source) > 0 and self._source[-1][-1:] == char_to_check + return bool(self._source) and self._source[-1][-1:] == char_to_check def _get_space_before_write(self) -> str: - if not self._source: - return "" - most_recent_token: str = self._source[-1] - return "" if most_recent_token[-1:] in chars_that_dont_need_whitespace else " " + return ( + "" + if not self._source + or self._source[-1][-1:] in chars_that_dont_need_whitespace + else " " + ) def _get_line_splitter(self) -> Literal["", "\n", ";"]: """Get character that starts the next line of code with the shortest @@ -197,7 +250,7 @@ def _get_line_splitter(self) -> Literal["", "\n", ";"]: if ( self._indent > 0 and self.previous_node_in_body is not None - and node_inlineable(self.previous_node_in_body) + and self._node_inlineable(self.previous_node_in_body) ): return ";" @@ -208,3 +261,23 @@ def _write_comma_delimitated_body( ) -> None: """Writes ast expr objects with comma delimitation""" self.interleave(lambda: self.write(","), self.traverse, body) + + @staticmethod + def _node_inlineable(node: ast.AST) -> bool: + return node.__class__.__name__ in [ + "Assert", + "AnnAssign", + "Assign", + "AugAssign", + "Break", + "Continue", + "Delete", + "Expr", + "Global", + "Import", + "ImportFrom", + "Nonlocal", + "Pass", + "Raise", + "Return", + ] diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index 5e86e08..0bb20e7 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -66,10 +66,10 @@ def __init__(self, config: SkipConfig) -> None: @staticmethod def _within_class_node(function): - def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None: + def wrapper(self: "AstNodeSkipper", *args) -> ast.AST | None: self._within_class = True try: - return function(self, *args, **kwargs) + return function(self, *args) finally: self._within_class = False @@ -77,10 +77,10 @@ def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None: @staticmethod def _within_function_node(function): - def wrapper(self: "AstNodeSkipper", *args, **kwargs) -> ast.AST | None: + def wrapper(self: "AstNodeSkipper", *args) -> ast.AST | None: self._within_function = True try: - return function(self, *args, **kwargs) + return function(self, *args) finally: self._within_function = False @@ -108,7 +108,7 @@ def generic_visit(self, node: ast.AST) -> ast.AST: and not new_values and field == "body" ): - new_values = [ast.Pass()] + new_values.append(ast.Pass()) old_value[:] = new_values @@ -153,14 +153,14 @@ def visit_Module(self, node: ast.Module) -> ast.AST: if self.token_types_config.skip_dangling_expressions: skip_dangling_expressions(node) - module: ast.Module = self.generic_visit(node) # type:ignore + self.generic_visit(node) if self.optimizations_config.remove_unused_imports and self._has_imports: import_filter = UnusedImportSkipper() - import_filter.visit(module) + import_filter.visit(node) self._warn_unused_skips() - return module + return node @_within_class_node def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None: @@ -181,23 +181,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None: return self.generic_visit(node) - @_within_function_node def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST | None: return self._handle_function_node(node) - @_within_function_node def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST | None: return self._handle_function_node(node) - def _should_skip_function( - self, node: ast.FunctionDef | ast.AsyncFunctionDef - ) -> bool: - """If a function node should be skipped""" - return node.name in self.tokens_config.functions_to_skip or ( - self.token_types_config.skip_overload_functions - and is_overload_function(node) - ) - + @_within_function_node def _handle_function_node( self, node: ast.FunctionDef | ast.AsyncFunctionDef ) -> ast.AST | None: @@ -218,10 +208,19 @@ def _handle_function_node( if isinstance(last_body_node, ast.Return) and ( is_return_none(last_body_node) or last_body_node.value is None ): - node.body = node.body[:-1] + node.body.pop() return self.generic_visit(node) + def _should_skip_function( + self, node: ast.FunctionDef | ast.AsyncFunctionDef + ) -> bool: + """Determines if a function node should be skipped.""" + return node.name in self.tokens_config.functions_to_skip or ( + self.token_types_config.skip_overload_functions + and is_overload_function(node) + ) + def visit_Attribute(self, node: ast.Attribute) -> ast.AST | None: if isinstance(node.value, ast.Name): if node.attr in self.optimizations_config.enums_to_fold.get( @@ -376,6 +375,9 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None: self._has_imports = True return node + def visit_alias(self, node: ast.alias) -> ast.alias: + return node + def visit_Name(self, node: ast.Name) -> ast.AST: """Extends super's implementation by adding constant folding""" if node.id in self.optimizations_config.vars_to_fold: @@ -423,6 +425,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.AST | None: def visit_Return(self, node: ast.Return) -> ast.AST: if is_return_none(node): node.value = None + return node return self.generic_visit(node) @@ -431,6 +434,12 @@ def visit_Pass(self, node: ast.Pass) -> None: are populated with a Pass node.""" return None # This could be toggleable + def visit_Break(self, node: ast.Break) -> ast.Break: + return node + + def visit_Continue(self, node: ast.Continue) -> ast.Continue: + return node + def visit_Call(self, node: ast.Call) -> ast.AST | None: if ( self.optimizations_config.assume_this_machine @@ -503,6 +512,8 @@ def visit_BinOp(self, node: ast.BinOp) -> ast.AST: def visit_arg(self, node: ast.arg) -> ast.AST: if self.token_types_config.skip_type_hints: node.annotation = None + return node + return self.generic_visit(node) def visit_arguments(self, node: ast.arguments) -> ast.AST: @@ -565,7 +576,6 @@ def _use_version_optimization(self, min_version: tuple[int, int]) -> bool: def _has_code_to_skip(self) -> bool: return ( self.target_python_version is not None - or len(self.optimizations_config.vars_to_fold) > 0 or self.optimizations_config.has_code_to_skip() or self.tokens_config.has_code_to_skip() or self.token_types_config.has_code_to_skip() @@ -665,14 +675,11 @@ def generic_visit(self, node: ast.AST) -> ast.AST: if value is None: ast_removed = True continue - elif not isinstance(value, ast.AST): - new_values.extend(value) - continue new_values.append(value) if not isinstance(node, ast.Module) and not new_values and ast_removed: - new_values = [ast.Pass()] + new_values.append(ast.Pass()) old_value[:] = reversed(new_values) diff --git a/personal_python_ast_optimizer/parser/utils.py b/personal_python_ast_optimizer/parser/utils.py index 728b708..9fa3e88 100644 --- a/personal_python_ast_optimizer/parser/utils.py +++ b/personal_python_ast_optimizer/parser/utils.py @@ -35,24 +35,6 @@ def is_return_none(node: ast.Return) -> bool: return isinstance(node.value, ast.Constant) and node.value.value is None -def node_inlineable(node: ast.AST) -> bool: - return node.__class__.__name__ in [ - "Assert", - "AnnAssign", - "Assign", - "AugAssign", - "Break", - "Continue", - "Delete", - "Expr", - "Import", - "ImportFrom", - "Pass", - "Raise", - "Return", - ] - - def skip_dangling_expressions( node: ast.Module | ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef, ) -> None: diff --git a/tests/parser/test_global.py b/tests/parser/test_global.py new file mode 100644 index 0000000..8c19e77 --- /dev/null +++ b/tests/parser/test_global.py @@ -0,0 +1,35 @@ +from tests.utils import BeforeAndAfter, run_minifier_and_assert_correct + + +def test_global_same_line(): + before_and_after = BeforeAndAfter( + """ +a = 1 +def test(): + global a + print(a) +""", + "a=1\ndef test():global a;print(a)", + ) + + run_minifier_and_assert_correct(before_and_after) + + +def test_nonlocal_same_line(): + before_and_after = BeforeAndAfter( + """ +def test(): + x = 1 + def i(): + nonlocal x + print(x) + i() +""", + """ +def test(): +\tx=1 +\tdef i():nonlocal x;print(x) +\ti()""".strip(), + ) + + run_minifier_and_assert_correct(before_and_after) diff --git a/version.txt b/version.txt index 91ff572..26d99a2 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -5.2.0 +5.2.1