diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index e171532..55315d3 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -88,18 +88,10 @@ def generic_visit(self, node) -> ast.AST: for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): new_values = [] - combined_import: ast.Import | None = None + self._combine_imports(old_value) for value in old_value: if isinstance(value, ast.AST): value = self.visit(value) - if isinstance(value, ast.Import): - if combined_import is None: - combined_import = value - else: - self._ast_import_combine(combined_import, value) - continue - else: - combined_import = None if value is None: continue elif not isinstance(value, ast.AST): @@ -112,9 +104,9 @@ def generic_visit(self, node) -> ast.AST: and not new_values and not isinstance(node, ast.Module) ): - old_value[:] = [ast.Pass()] - else: - old_value[:] = new_values + new_values = [ast.Pass()] + + old_value[:] = new_values elif isinstance(old_value, ast.AST): new_node = self.visit(old_value) @@ -124,6 +116,31 @@ def generic_visit(self, node) -> ast.AST: setattr(node, field, new_node) return node + @staticmethod + def _combine_imports(body: list) -> None: + if not body: + return + + new_body = [body[0]] + + for i in range(1, len(body)): + this_node = body[i] + last_node = new_body[-1] + + if ( + isinstance(this_node, ast.Import) and isinstance(last_node, ast.Import) + ) or ( + isinstance(this_node, ast.ImportFrom) + and isinstance(last_node, ast.ImportFrom) + and this_node.module == last_node.module + and this_node.level == last_node.level + ): + last_node.names += this_node.names + else: + new_body.append(this_node) + + body[:] = new_body + def visit_Module(self, node: ast.Module) -> ast.AST: if not self._has_code_to_skip(): return node @@ -229,7 +246,7 @@ def visit_Assign(self, node: ast.Assign) -> ast.AST | None: # TODO: Currently if a.b.c.d only "c" and "d" are checked var_name: str = get_node_name(node.targets[0]) - parent_var_name: str = get_node_name(getattr(node.targets[0], "value", object)) + parent_var_name: str = get_node_name(getattr(node.targets[0], "value", None)) if ( var_name in self.tokens_config.variables_to_skip @@ -504,7 +521,7 @@ def visit_arguments(self, node: ast.arguments) -> ast.AST: def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST: parsed_node: ast.BoolOp = self.generic_visit(node) # type: ignore - if isinstance(parsed_node.op, ast.Or) or isinstance(parsed_node.op, ast.And): + if isinstance(parsed_node.op, (ast.Or, ast.And)): # For And nodes left values that are Truthy and const can be removed # and vice versa remove_if: bool = isinstance(parsed_node.op, ast.And) @@ -632,7 +649,3 @@ def _ast_constants_operation( raise ValueError(f"Invalid operation: {operation.__class__.__name__}") return ast.Constant(result) - - @staticmethod - def _ast_import_combine(target: ast.Import, to_be_combined: ast.Import) -> None: - target.names += to_be_combined.names diff --git a/personal_python_ast_optimizer/parser/utils.py b/personal_python_ast_optimizer/parser/utils.py index 83a96b9..11ec7a7 100644 --- a/personal_python_ast_optimizer/parser/utils.py +++ b/personal_python_ast_optimizer/parser/utils.py @@ -5,7 +5,7 @@ from personal_python_ast_optimizer.parser.config import TokensToSkip -def get_node_name(node: object) -> str: +def get_node_name(node: ast.AST | None) -> str: """Gets id or attr which both can represent var names""" if isinstance(node, ast.Call): node = node.func @@ -74,11 +74,7 @@ def skip_decorators( def remove_duplicate_slots( node: ast.Assign | ast.AnnAssign, warn_duplicates: bool = True ) -> None: - if ( - isinstance(node.value, ast.Tuple) - or isinstance(node.value, ast.List) - or isinstance(node.value, ast.Set) - ): + if isinstance(node.value, (ast.Tuple, ast.List, ast.Set)): found_values: set[str] = set() unique_objects: list[ast.expr] = [] for const_value in node.value.elts: @@ -98,8 +94,9 @@ def remove_duplicate_slots( node.value.elts = unique_objects -def first_occurrence_of_type(data: list, target_type) -> int: +def first_occurrence_of_type(data: list, target_type: type) -> int: for index, element in enumerate(data): if isinstance(element, target_type): return index + return -1 diff --git a/tests/parser/test_imports.py b/tests/parser/test_imports.py index 75a560c..68e5a04 100644 --- a/tests/parser/test_imports.py +++ b/tests/parser/test_imports.py @@ -10,11 +10,15 @@ from __future__ import with_statement """ +_futures_imports_inline: str = ( + "from __future__ import annotations,generator_stop,unicode_literals,with_statement" +) + @pytest.mark.parametrize( "version,skip_type_hints,after", [ - (None, False, _futures_imports.strip()), + (None, False, _futures_imports_inline), ((3, 7), False, "from __future__ import annotations"), ((3, 7), True, ""), ], @@ -37,15 +41,18 @@ def test_import_same_line(): before_and_after = BeforeAndAfter( """ import test +import test2 def i(): import a import d from b import c + from b import d as e + from .b import f print() import e """, - """import test -def i():import a,d;from b import c;print();import e""", + """import test,test2 +def i():import a,d;from b import c,d as e;from .b import f;print();import e""", ) run_minifier_and_assert_correct(before_and_after) diff --git a/version.txt b/version.txt index a1ef0ca..50e2274 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -5.0.2 +5.0.3