diff --git a/src/unimport/refactor/__init__.py b/src/unimport/refactor/__init__.py
new file mode 100644
index 00000000..78ddccae
--- /dev/null
+++ b/src/unimport/refactor/__init__.py
@@ -0,0 +1,42 @@
+import dataclasses
+import tokenize
+from typing import List, Union
+
+import libcst as cst
+
+from unimport.refactor.tokenize_utils import comments_to_strings, prepare_tokens, string_to_comments
+from unimport.refactor.transformer import RemoveUnusedImportTransformer
+from unimport.statement import Import, ImportFrom
+
+__all__ = ("refactor_string",)
+
+
+@dataclasses.dataclass
+class _Refactor:
+    source: str
+    unused_imports: List[Union[Import, ImportFrom]]
+
+    def __post_init__(self):
+        self.tokens = prepare_tokens(self.source)
+
+    def refactor_string(self, source: str) -> str:
+        if self.unused_imports:
+            wrapper = cst.MetadataWrapper(cst.parse_module(source))
+            remove_unused_import_transformer = RemoveUnusedImportTransformer(self.unused_imports)
+            fixed_module = wrapper.visit(remove_unused_import_transformer)
+            return fixed_module.code
+
+        return source
+
+    def __call__(self, *args, **kwargs) -> str:
+        source_without_comments = tokenize.untokenize(comments_to_strings(self.tokens))
+        refactored_source_without_comments = self.refactor_string(source_without_comments)
+        if refactored_source_without_comments != self.source:
+            return tokenize.untokenize(string_to_comments(prepare_tokens(refactored_source_without_comments)))
+
+        return self.source
+
+
+def refactor_string(source: str, unused_imports: List[Union[Import, ImportFrom]]) -> str:
+    refactor = _Refactor(source, unused_imports)
+    return refactor()
diff --git a/src/unimport/refactor/tokenize_utils.py b/src/unimport/refactor/tokenize_utils.py
new file mode 100644
index 00000000..801f38c9
--- /dev/null
+++ b/src/unimport/refactor/tokenize_utils.py
@@ -0,0 +1,157 @@
+import ast
+import token
+import tokenize
+from enum import Enum
+from typing import Iterator, List
+
+__all__ = (
+    "generate_tokens",
+    "set_tokens_parent",
+    "set_tokens_child",
+    "comment_token",
+    "string_token",
+    "Position",
+    "get_child_tokens",
+    "first_child_token_match",
+    "pass_token",
+    "comments_to_strings",
+    "prepare_tokens",
+    "string_to_comments",
+)
+
+
+def generate_tokens(source: str) -> Iterator[tokenize.TokenInfo]:
+    return tokenize.generate_tokens(readline=iter(ast._splitlines_no_ff(source)).__next__)  # type: ignore # noqa
+
+
+def set_tokens_parent(tokens: List[tokenize.TokenInfo]) -> None:
+    parent = None
+    for tok in tokens:
+        setattr(tok, "parent", parent)
+        parent = tok
+
+
+def set_tokens_child(tokens: List[tokenize.TokenInfo]) -> None:
+    for index, tok in enumerate(tokens):
+        try:
+            setattr(tok, "child", tokens[index + 1])
+        except IndexError:
+            setattr(tok, "child", None)
+
+
+def get_child_tokens(t: tokenize.TokenInfo) -> Iterator[tokenize.TokenInfo]:
+    child = t
+    while child:
+        child = child.child  # type: ignore
+        if child:
+            yield child
+
+
+def first_child_token_match(t: tokenize.TokenInfo):  # TODO: rename; decide where the pass should be added
+    x_offset = t.start[1]
+
+    for child in get_child_tokens(t):
+        if all(
+            (
+                child.start[1] == x_offset,
+                child.type
+                not in [
+                    token.NL,
+                    token.NEWLINE,
+                    token.ENDMARKER,
+                    token.INDENT,
+                    token.DEDENT,
+                    # token.TYPE_COMMENT,
+                    # token.COMMENT
+                ],
+                not child.string.startswith('"#'),
+            )
+        ):
+            return child
+        if child.start[1] < x_offset:
+            break
+
+    return None
+
+
+def comment_token(tok: tokenize.TokenInfo) -> tokenize.TokenInfo:
+    assert tok.type == token.STRING
+
+    line = " ".join(tok.line.split()) + "\n" if tok.line[-1] == "\n" else " ".join(tok.line.split())
+
+    return tok._replace(
+        type=token.COMMENT,
+        string=tok.string.replace('"', ""),
+        start=(tok.start[0], tok.start[1]),
+        end=(tok.end[0], tok.end[1] - 2),
+        line=line.replace('"', ""),
+    )
+
+
+def string_token(tok: tokenize.TokenInfo) -> tokenize.TokenInfo:
+    assert tok.type == token.COMMENT
+
+    line = f'{" ".join(tok.line.split())}\n' if tok.line[-1] == "\n" else f'{" ".join(tok.line.split())}'
+
+    return tok._replace(type=token.STRING, string=f'"{tok.string}"', end=(tok.end[0], tok.end[1]), line=line)
+
+
+def pass_token(tok: tokenize.TokenInfo) -> tokenize.TokenInfo:
+    return tok._replace(
+        type=token.NAME,
+        string="pass",
+        # start=(tok.start[0] - 1 , tok.start[1]),
+        # end=(tok.end[0] , tok.end[1]),
+        # line=f'{tok.start[1] * " "}pass\n',
+    )
+
+
+class Position(int, Enum):
+    LINE = 0
+    COLUMN = 1
+
+
+def increase(t: tokenize.TokenInfo, amount: int = 1, page: Position = Position.LINE) -> tokenize.TokenInfo:
+    if amount == 0:
+        return t
+
+    start, end = list(t.start), list(t.end)
+
+    start[page] += amount
+    end[page] += amount
+
+    return t._replace(start=tuple(start), end=tuple(end))
+
+
+def prepare_tokens(source: str) -> List[tokenize.TokenInfo]:
+    tokens = list(generate_tokens(source))
+    set_tokens_parent(tokens)
+    set_tokens_child(tokens)
+
+    return tokens
+
+
+def comments_to_strings(tokens: Iterator[tokenize.TokenInfo]) -> Iterator[tokenize.TokenInfo]:
+    return (
+        string_token(tok)
+        if (
+            tok.type == tokenize.COMMENT
+            and tok.parent  # type: ignore
+            and (tok.parent.type == token.NL or tok.parent.type == token.NEWLINE)  # type: ignore
+        )
+        else tok
+        for tok in tokens
+    )
+
+
+def string_to_comments(tokens: Iterator[tokenize.TokenInfo]) -> Iterator[tokenize.TokenInfo]:
+    amount = 0
+    for t in tokens:
+        if t.type == tokenize.STRING and t.parent and t.string.startswith('"#'):  # type: ignore
+            if not first_child_token_match(t):
+                amount += 1
+                yield increase(pass_token(t), amount)
+            else:
+                yield increase(comment_token(t), amount)
+        else:
+            yield increase(t, amount)
diff --git a/src/unimport/refactor.py b/src/unimport/refactor/transformer.py
similarity index 92%
rename from src/unimport/refactor.py
rename to src/unimport/refactor/transformer.py
index 3d0293ec..6dd7b4de 100644
--- a/src/unimport/refactor.py
+++ b/src/unimport/refactor/transformer.py
@@ -7,10 +7,10 @@
 from unimport import typing as T
 from unimport.statement import Import, ImportFrom
 
-__all__ = ("refactor_string",)
+__all__ = ("RemoveUnusedImportTransformer",)
 
 
-class _RemoveUnusedImportTransformer(cst.CSTTransformer):
+class RemoveUnusedImportTransformer(cst.CSTTransformer):
     __slots__ = ("unused_imports",)
 
     METADATA_DEPENDENCIES = (PositionProvider,)
@@ -121,12 +121,3 @@ def get_star_imp() -> Optional[ImportFrom]:
             return original_node
 
         return self.leave_import_alike(original_node, updated_node)
-
-
-def refactor_string(source: str, unused_imports: List[Union[Import, ImportFrom]]) -> str:
-    if unused_imports:
-        wrapper = cst.MetadataWrapper(cst.parse_module(source))
-        fixed_module = wrapper.visit(_RemoveUnusedImportTransformer(unused_imports))
-        return fixed_module.code
-
-    return source
diff --git a/tests/cases/analyzer/comments/issue_100.py b/tests/cases/analyzer/comments/issue_100.py
new file mode 100644
index 00000000..3bea42dc
--- /dev/null
+++ b/tests/cases/analyzer/comments/issue_100.py
@@ -0,0 +1,23 @@
+from typing import List, Union
+
+from unimport.statement import Import, ImportFrom, Name
+
+__all__ = ["NAMES", "IMPORTS", "UNUSED_IMPORTS"]
+
+
+NAMES: List[Name] = [
+    Name(lineno=11, name="os", is_all=False),
+    Name(lineno=11, name="Union", is_all=False),
+    Name(lineno=11, name="driver", is_all=False),
+    Name(lineno=11, name="Grammar", is_all=False),
+]
+IMPORTS: List[Union[Import, ImportFrom]] = [
+    Import(lineno=2, column=1, name="os", package="os"),
+    ImportFrom(lineno=4, column=1, name="Union", package="typing", star=False, suggestions=[]),
+    ImportFrom(lineno=7, column=1, name="token", package=".pgen2", star=False, suggestions=[]),
+    ImportFrom(lineno=8, column=1, name="driver", package=".pgen2", star=False, suggestions=[]),
+    ImportFrom(lineno=10, column=1, name="Grammar", package=".pgen2.grammar", star=False, suggestions=[]),
+]
+UNUSED_IMPORTS: List[Union[Import, ImportFrom]] = [
+    ImportFrom(lineno=7, column=1, name="token", package=".pgen2", star=False, suggestions=[])
+]
diff --git a/tests/cases/refactor/comments/issue_100.py b/tests/cases/refactor/comments/issue_100.py
new file mode 100644
index 00000000..300a84ff
--- /dev/null
+++ b/tests/cases/refactor/comments/issue_100.py
@@ -0,0 +1,10 @@
+# Python imports
+import os
+
+from typing import Union
+
+# Local imports
+from .pgen2 import driver
+
+from .pgen2.grammar import Grammar
+os, Union, driver, Grammar
diff --git a/tests/cases/source/comments/issue_100.py b/tests/cases/source/comments/issue_100.py
new file mode 100644
index 00000000..6e00c20a
--- /dev/null
+++ b/tests/cases/source/comments/issue_100.py
@@ -0,0 +1,11 @@
+# Python imports
+import os
+
+from typing import Union
+
+# Local imports
+from .pgen2 import token
+from .pgen2 import driver
+
+from .pgen2.grammar import Grammar
+os, Union, driver, Grammar
diff --git a/tests/refactor/__init__.py b/tests/refactor/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/refactor/test_tokenize_utils.py b/tests/refactor/test_tokenize_utils.py
new file mode 100644
index 00000000..0263d747
--- /dev/null
+++ b/tests/refactor/test_tokenize_utils.py
@@ -0,0 +1,99 @@
+import textwrap
+import token
+from tokenize import TokenInfo
+
+from unimport.refactor.tokenize_utils import (
+    first_child_token_match,
+    generate_tokens,
+    get_child_tokens,
+    set_tokens_child,
+    set_tokens_parent,
+)
+
+
+def test_generate_tokens():
+    source = textwrap.dedent(
+        """
+        def foo():
+            pass
+        """
+    )
+    tokens = generate_tokens(source)
+    assert next(tokens) == TokenInfo(type=token.NL, string="\n", start=(1, 0), end=(1, 1), line="\n")
+    assert next(tokens) == TokenInfo(type=token.NAME, string="def", start=(2, 0), end=(2, 3), line="def foo():\n")
+    assert next(tokens) == TokenInfo(type=token.NAME, string="foo", start=(2, 4), end=(2, 7), line="def foo():\n")
+    assert next(tokens) == TokenInfo(type=token.OP, string="(", start=(2, 7), end=(2, 8), line="def foo():\n")
+    assert next(tokens) == TokenInfo(type=token.OP, string=")", start=(2, 8), end=(2, 9), line="def foo():\n")
+    assert next(tokens) == TokenInfo(type=token.OP, string=":", start=(2, 9), end=(2, 10), line="def foo():\n")
+    assert next(tokens) == TokenInfo(type=token.NEWLINE, string="\n", start=(2, 10), end=(2, 11), line="def foo():\n")
+    assert next(tokens) == TokenInfo(type=token.INDENT, string="    ", start=(3, 0), end=(3, 4), line="    pass\n")
+    assert next(tokens) == TokenInfo(type=token.NAME, string="pass", start=(3, 4), end=(3, 8), line="    pass\n")
+    assert next(tokens) == TokenInfo(type=token.NEWLINE, string="\n", start=(3, 8), end=(3, 9), line="    pass\n")
+    assert next(tokens) == TokenInfo(type=token.DEDENT, string="", start=(4, 0), end=(4, 0), line="")
+    assert next(tokens) == TokenInfo(type=token.ENDMARKER, string="", start=(4, 0), end=(4, 0), line="")
+
+
+def test_set_tokens_parent():
+    source = textwrap.dedent(
+        """
+        def foo():
+            pass
+        """
+    )
+    tokens = list(generate_tokens(source))
+    set_tokens_parent(tokens)
+
+    assert tokens[0].parent is None
+    for index in range(1, len(tokens)):
+        assert tokens[index].parent == tokens[index - 1]
+
+
+def test_set_tokens_child():
+    source = textwrap.dedent(
+        """
+        def foo():
+            pass
+        """
+    )
+    tokens = list(generate_tokens(source))
+    set_tokens_child(tokens)
+
+    assert tokens[-1].child is None
+    for index in range(len(tokens) - 1):
+        assert tokens[index].child == tokens[index + 1]
+
+
+def test_get_child_tokens():
+    source = textwrap.dedent(
+        """
+        def foo():
+            pass
+        """
+    )
+    tokens = list(generate_tokens(source))
+    set_tokens_child(tokens)
+
+    for index in range(len(tokens)):
+        assert list(get_child_tokens(tokens[index])) == tokens[index + 1 :]
+
+
+def test_first_child_token_match():
+    source = textwrap.dedent(
+        """
+        # comment 1
+
+        def foo():  # comment 2
+            # comment 3
+            pass  # comment 4
+        """
+    )
+    tokens = list(generate_tokens(source))
+    set_tokens_child(tokens)
+
+    assert first_child_token_match(tokens[0]) == TokenInfo(
+        type=token.COMMENT, string="# comment 1", start=(2, 0), end=(2, 11), line="# comment 1\n"
+    )
+    assert first_child_token_match(tokens[1]) == TokenInfo(
+        type=token.NAME, string="def", start=(4, 0), end=(4, 3), line="def foo():  # comment 2\n"
+    )
+    assert first_child_token_match(tokens[2]) is None
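
Usage note (not part of the diff): a minimal sketch of the new comment-preserving refactor_string entry point, assuming this branch is installed so unimport.refactor and unimport.statement resolve as shown above. The source text and the unused-import entry are copied from tests/cases/source/comments/issue_100.py and tests/cases/analyzer/comments/issue_100.py in this patch; the printed result should match tests/cases/refactor/comments/issue_100.py.

import textwrap

from unimport.refactor import refactor_string
from unimport.statement import ImportFrom

source = textwrap.dedent(
    """\
    # Python imports
    import os

    from typing import Union

    # Local imports
    from .pgen2 import token
    from .pgen2 import driver

    from .pgen2.grammar import Grammar
    os, Union, driver, Grammar
    """
)
unused_imports = [
    ImportFrom(lineno=7, column=1, name="token", package=".pgen2", star=False, suggestions=[])
]

# Comments are swapped for string tokens before libcst sees the module and
# restored afterwards, so '# Local imports' should survive the removal of
# 'from .pgen2 import token'.
print(refactor_string(source, unused_imports))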