diff --git a/personal_python_ast_optimizer/parser/config.py b/personal_python_ast_optimizer/parser/config.py index 5a90ce3..4386119 100644 --- a/personal_python_ast_optimizer/parser/config.py +++ b/personal_python_ast_optimizer/parser/config.py @@ -4,6 +4,17 @@ from types import EllipsisType +class TypeHintsToSkip(Enum): + NONE = 0 + # ALL might be unsafe, NamedTuple or TypedDict for example + ALL = 1 + # Should be safe in most cases + ALL_BUT_CLASS_VARS = 2 + + def __bool__(self) -> bool: + return self != TypeHintsToSkip.NONE + + class TokensToSkip(dict[str, int]): __slots__ = ("token_type",) @@ -94,6 +105,7 @@ def get_missing_tokens_iter(self) -> Iterator[tuple[str, str]]: class TokenTypesConfig(_Config): __slots__ = ( + "simplify_named_tuples", "skip_dangling_expressions", "skip_type_hints", "skip_overload_functions", @@ -103,12 +115,14 @@ def __init__( self, *, skip_dangling_expressions: bool = True, - skip_type_hints: bool = True, + skip_type_hints: TypeHintsToSkip = TypeHintsToSkip.ALL_BUT_CLASS_VARS, skip_overload_functions: bool = False, + simplify_named_tuples: bool = False, ) -> None: self.skip_dangling_expressions: bool = skip_dangling_expressions - self.skip_type_hints: bool = skip_type_hints + self.skip_type_hints: TypeHintsToSkip = skip_type_hints self.skip_overload_functions: bool = skip_overload_functions + self.simplify_named_tuples: bool = simplify_named_tuples class OptimizationsConfig(_Config): diff --git a/personal_python_ast_optimizer/parser/run.py b/personal_python_ast_optimizer/parser/run.py index fe39cd1..9364371 100644 --- a/personal_python_ast_optimizer/parser/run.py +++ b/personal_python_ast_optimizer/parser/run.py @@ -5,12 +5,17 @@ from personal_python_ast_optimizer.parser.skipper import AstNodeSkipper -def run_minify_parser( - parser: MinifyUnparser, source: str, skip_config: SkipConfig | None = None +def run_unparser( + source: str, + unparser: ast._Unparser | None = None, # type: ignore[name-defined] + skip_config: SkipConfig | None = None, ) -> str: module: ast.Module = ast.parse(source) if skip_config is not None: AstNodeSkipper(skip_config).visit(module) - return parser.visit(module) + if unparser is None: + unparser = MinifyUnparser() + + return unparser.visit(module) diff --git a/personal_python_ast_optimizer/parser/skipper.py b/personal_python_ast_optimizer/parser/skipper.py index 83d2361..627424f 100644 --- a/personal_python_ast_optimizer/parser/skipper.py +++ b/personal_python_ast_optimizer/parser/skipper.py @@ -8,6 +8,7 @@ SkipConfig, TokensConfig, TokenTypesConfig, + TypeHintsToSkip, ) from personal_python_ast_optimizer.parser.machine_info import ( machine_dependent_attributes, @@ -37,6 +38,7 @@ class AstNodeSkipper(ast.NodeTransformer): __slots__ = ( "_has_imports", "_node_context_skippable_futures", + "_simplified_named_tuple", "module_name", "target_python_version", "optimizations_config", @@ -54,6 +56,7 @@ def __init__(self, config: SkipConfig) -> None: self.tokens_config: TokensConfig = config.tokens_config self._has_imports: bool = False + self._simplified_named_tuple: bool = False self._node_context: _NodeContext = _NodeContext.NONE self._skippable_futures: list[str] = ( @@ -156,6 +159,21 @@ def visit_Module(self, node: ast.Module) -> ast.AST: self.generic_visit(node) + if self._simplified_named_tuple: + import_to_update: ast.ImportFrom | None = None + for n in node.body: + if isinstance(n, ast.ImportFrom) and n.module == "collections": + if any(alias.name == "namedtuple" for alias in n.names): + break + if import_to_update is None: + import_to_update = n + else: # namedtuple was not already imported + alias = ast.alias("namedtuple") + if import_to_update is None: + node.body.insert(0, ast.ImportFrom("collections", [alias], 0)) + else: + import_to_update.names.append(alias) + if self.optimizations_config.remove_unused_imports and self._has_imports: import_filter = UnusedImportSkipper() import_filter.visit(node) @@ -174,11 +192,42 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST | None: if self.token_types_config.skip_dangling_expressions: skip_dangling_expressions(node) + if ( + self.token_types_config.simplify_named_tuples + and self._is_simple_named_tuple(node) + ): + self._simplified_named_tuple = True + named_tuple = ast.Call( + ast.Name("namedtuple"), + [ + ast.Constant(node.name), + ast.List([ast.Constant(n.target.id) for n in node.body]), # type: ignore + ], + [], + ) + return ast.Assign([ast.Name(node.name)], named_tuple) + skip_base_classes(node, self.tokens_config.classes_to_skip) skip_decorators(node, self.tokens_config.decorators_to_skip) return self.generic_visit(node) + @staticmethod + def _is_simple_named_tuple(node: ast.ClassDef) -> bool: + return ( + len(node.bases) == 1 + and isinstance(node.bases[0], ast.Name) + and node.bases[0].id == "NamedTuple" + and not node.keywords + and not node.decorator_list + and all( + isinstance(n, ast.AnnAssign) + and isinstance(n.target, ast.Name) + and n.value is None + for n in node.body + ) + ) + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST | None: return self._handle_function_node(node) @@ -347,7 +396,12 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST | None: parsed_node: ast.AnnAssign = self.generic_visit(node) # type: ignore if self.token_types_config.skip_type_hints: - if not parsed_node.value and self._node_context == _NodeContext.CLASS: + if ( + not parsed_node.value + and self._node_context == _NodeContext.CLASS + and self.token_types_config.skip_type_hints + == TypeHintsToSkip.ALL_BUT_CLASS_VARS + ): parsed_node.annotation = ast.Name("int") elif parsed_node.value is None: return None diff --git a/tests/parser/test_class.py b/tests/parser/test_class.py index 98c3b4f..368063d 100644 --- a/tests/parser/test_class.py +++ b/tests/parser/test_class.py @@ -1,4 +1,7 @@ -from personal_python_ast_optimizer.parser.config import TokenTypesConfig +from personal_python_ast_optimizer.parser.config import ( + TokenTypesConfig, + TypeHintsToSkip, +) from tests.utils import BeforeAndAfter, run_minifier_and_assert_correct @@ -13,7 +16,7 @@ class Foo(): run_minifier_and_assert_correct(before_and_after) -def test_tuple_class(): +def test_class_preserves_type_hints(): before_and_after = BeforeAndAfter( """ class SomeTuple(): @@ -32,5 +35,31 @@ class B: \t\treturn B""", ) run_minifier_and_assert_correct( - before_and_after, token_types_config=TokenTypesConfig(skip_type_hints=True) + before_and_after, + token_types_config=TokenTypesConfig( + skip_type_hints=TypeHintsToSkip.ALL_BUT_CLASS_VARS + ), + ) + + +def test_class_skip_type_hints(): + before_and_after = BeforeAndAfter( + """ +class SomeTuple(): + '''A tuple, wow!''' + thing1: str + thing2: int + + def a(): + class B: + thing3: None + return B""", + """class SomeTuple: +\tdef a(): +\t\tclass B:pass +\t\treturn B""", + ) + run_minifier_and_assert_correct( + before_and_after, + token_types_config=TokenTypesConfig(skip_type_hints=TypeHintsToSkip.ALL), ) diff --git a/tests/parser/test_imports.py b/tests/parser/test_imports.py index bcc914e..e80cae4 100644 --- a/tests/parser/test_imports.py +++ b/tests/parser/test_imports.py @@ -3,6 +3,7 @@ from personal_python_ast_optimizer.parser.config import ( OptimizationsConfig, TokenTypesConfig, + TypeHintsToSkip, ) from tests.utils import BeforeAndAfter, run_minifier_and_assert_correct @@ -21,13 +22,13 @@ @pytest.mark.parametrize( "version,skip_type_hints,after", [ - (None, False, _futures_imports_inline), - ((3, 7), False, "from __future__ import annotations"), - ((3, 7), True, ""), + (None, TypeHintsToSkip.NONE, _futures_imports_inline), + ((3, 7), TypeHintsToSkip.NONE, "from __future__ import annotations"), + ((3, 7), TypeHintsToSkip.ALL, ""), ], ) def test_futures_imports( - version: tuple[int, int] | None, skip_type_hints: bool, after: str + version: tuple[int, int] | None, skip_type_hints: TypeHintsToSkip, after: str ): before_and_after = BeforeAndAfter(_futures_imports, after) diff --git a/tests/parser/test_tuple.py b/tests/parser/test_tuple.py index 35605ca..0a3171f 100644 --- a/tests/parser/test_tuple.py +++ b/tests/parser/test_tuple.py @@ -1,3 +1,6 @@ +import pytest + +from personal_python_ast_optimizer.parser.config import TokenTypesConfig from tests.utils import BeforeAndAfter, run_minifier_and_assert_correct @@ -11,3 +14,50 @@ def test_tuple_whitespace(): ) run_minifier_and_assert_correct(before_and_after) + + +_simplify_named_tuple_test_cases: list[tuple[str, str]] = [ + ( + """ +from typing import NamedTuple + +class A(NamedTuple): + foo: int + bar: str +""", + "from collections import namedtuple\nA=namedtuple('A',['foo','bar'])", + ), + ( + """ +from collections import namedtuple +from typing import NamedTuple + +class A(NamedTuple): + foo: int + bar: str +""", + "from collections import namedtuple\nA=namedtuple('A',['foo','bar'])", + ), + ( + """ +from collections import OrderedDict +from typing import NamedTuple + +class A(NamedTuple): + foo: int + bar: str +b=OrderedDict() +""", + "from collections import OrderedDict,namedtuple\nA=namedtuple('A',['foo','bar'])\nb=OrderedDict()", # noqa: E501 + ), +] + + +@pytest.mark.parametrize(("before", "after"), _simplify_named_tuple_test_cases) +def test_simplify_named_tuple(before: str, after: str): + before_and_after = BeforeAndAfter(before, after) + + run_minifier_and_assert_correct( + before_and_after, + token_types_config=TokenTypesConfig(simplify_named_tuples=True), + ) diff --git a/tests/utils.py b/tests/utils.py index 6ee8bd4..86a7fd5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,8 +6,7 @@ TokensConfig, TokenTypesConfig, ) -from personal_python_ast_optimizer.parser.minifier import MinifyUnparser -from personal_python_ast_optimizer.parser.run import run_minify_parser +from personal_python_ast_optimizer.parser.run import run_unparser class BeforeAndAfter: @@ -27,12 +26,9 @@ def run_minifier_and_assert_correct( tokens_config: TokensConfig | None = None, optimizations_config: OptimizationsConfig | None = None, ): - unparser: MinifyUnparser = MinifyUnparser() - - minified_code: str = run_minify_parser( - unparser, + minified_code: str = run_unparser( before_and_after.before, - SkipConfig( + skip_config=SkipConfig( "", target_python_version=target_python_version, tokens_config=tokens_config or TokensConfig(), diff --git a/version.txt b/version.txt index 74664af..09b254e 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -5.3.3 +6.0.0