diff --git a/blark/html.py b/blark/html.py
index aab7081..e44e8c8 100644
--- a/blark/html.py
+++ b/blark/html.py
@@ -1,128 +1,156 @@
"""Syntax-highlighted HTML file writer."""
from __future__ import annotations
-import collections
import dataclasses
import pathlib
-from typing import Any, DefaultDict, Dict, Generator, List, Optional
+from typing import Any, Optional, Union
import lark
+from lxml import etree
from .output import OutputBlock, register_output_handler
@dataclasses.dataclass(frozen=True)
-class HighlighterAnnotation:
- """A single HTML tag annotation which applies to a position range of source code."""
+class CodeSpan:
+ #: Start position (inclusive)
+ start: int
+ #: End position (exclusive)
+ end: int
- name: str
- terminal: bool
- is_open_tag: bool
- other_tag_pos: int
+ @staticmethod
+ def of(item: Union[lark.Tree, lark.Token]) -> Optional[CodeSpan]:
+ """Get the code span of the given Tree or Token"""
+ ref = item.meta if isinstance(item, lark.Tree) else item
+ if (start := getattr(ref, 'start_pos', None)) is not None:
+ if (end := getattr(ref, 'end_pos', None)) is not None:
+ return CodeSpan(start, end)
+ return None
+
+ def contains(self, other: CodeSpan) -> bool:
+ """Determine if the code span contains another code span"""
+ return other.start >= self.start and other.end <= self.end
+
+
+class RemoveNone(lark.Transformer):
+ """Remove all None elements from a lark tree"""
+ def __default__(
+ self,
+ data: lark.Token,
+ children: list,
+ meta: lark.tree.Meta
+ ) -> lark.Tree:
+ return lark.Tree(data, [c for c in children if c is not None], meta)
- def __str__(self) -> str:
- return self.as_string()
- def as_string(self, tag: str = "span") -> str:
- # some options here?
- if not self.is_open_tag:
- return f'{tag}>'
+@dataclasses.dataclass(frozen=True)
+class TokenPlacement:
+ #: Token's parent tree
+ tree: lark.Tree
+ #: Index of the token in tree.children
+ child_index: int
+
+
+def find_placement(
+ token: lark.Token,
+ tree: lark.Tree
+) -> Optional[TokenPlacement]:
+ """Find a placement for the token within the given tree according to token's code span."""
+ tree_span = CodeSpan.of(tree)
+ token_span = CodeSpan.of(token)
+
+ if not (tree_span and token_span and tree_span.contains(token_span)):
+ return None
+
+ for child in tree.children:
+ if isinstance(child, lark.Tree):
+ if placement := find_placement(token, child):
+ return placement
+
+ for index, child in enumerate(tree.children):
+ child_span = CodeSpan.of(child)
+ if child_span and child_span.start > token_span.start:
+ return TokenPlacement(tree, index)
+
+ return TokenPlacement(tree, len(tree.children))
+
+
+def insert_comments(
+ tree: lark.Tree,
+ comments: list[lark.Token],
+ source_code: str
+) -> lark.Tree:
+ """Combine code and comments into a single lark tree"""
+ if not comments:
+ return tree
+
+ all_spans = [CodeSpan.of(tree)] + [CodeSpan.of(token) for token in comments]
+ valid_spans = [span for span in all_spans if span]
+
+ meta = lark.tree.Meta()
+ meta.start_pos = min(span.start for span in valid_spans)
+ meta.end_pos = max(span.end for span in valid_spans)
+
+ # Use no-op transformer to create a copy of the tree.
+ new_tree: lark.Tree = lark.Transformer().transform(tree)
+ # Assign metadata to the copied tree.
+ new_tree = lark.Tree(new_tree.data, new_tree.children, meta)
+
+ for comment in comments:
+ if placement := find_placement(comment, new_tree):
+ placement.tree.children.insert(placement.child_index, comment)
+
+ return new_tree
+
+
+def html_element(*, cls: str, text: str = None) -> etree.Element:
+ """Create a HTML element with the given class"""
+ element = etree.Element("span")
+ element.set("class", cls)
+ if text is not None:
+ element.text = text
+ return element
+
+
+def lark_item_to_element(
+ item: Union[lark.Tree, lark.Token],
+ source_code: str,
+ output_pos: int
+) -> tuple[etree.Element, int]:
+ """
+ Convert a lark item to an lxml element. Return the element and the position in the source code
+ after this item.
+ """
+ code_span = CodeSpan.of(item) or CodeSpan(output_pos, output_pos)
- if self.terminal:
- classes = " ".join(("term", self.name))
- else:
- classes = " ".join(("rule", self.name))
+ if isinstance(item, lark.Tree):
+ element = html_element(cls=item.data)
+ running_pos = code_span.start
- return f'<{tag} class="{classes}">'
+ def append_text_up_to(pos: int):
+ nonlocal running_pos
+ if text := source_code[running_pos:pos]:
+ if len(element) == 0:
+ element.text = text
+ else:
+ element[-1].tail = text
+ running_pos = pos
+ for child in item.children:
+ if child_span := CodeSpan.of(child):
+ append_text_up_to(child_span.start)
-def _add_annotation_pair(
- annotations: DefaultDict[int, List[HighlighterAnnotation]],
- name: str,
- start_pos: int,
- end_pos: int,
- terminal: bool,
-) -> None:
- """
- Add a pair of HTML tag annotations to the position-indexed list.
-
- Parameters
- ----------
- annotations : DefaultDict[int, List[HighlighterAnnotation]]
- Annotations keyed by 0-indexed string position.
- name : str
- Name of the annotation.
- start_pos : int
- String index position which the annotation applies to.
- end_pos : int
- String index position which the annotation ends at.
- terminal : bool
- Whether this is a TERMINAL (True) or a rule (false).
- """
- annotations[start_pos].append(
- HighlighterAnnotation(
- name=name,
- terminal=terminal,
- is_open_tag=True,
- other_tag_pos=end_pos,
- )
- )
- annotations[end_pos].append(
- HighlighterAnnotation(
- name=name,
- terminal=terminal,
- is_open_tag=False,
- other_tag_pos=start_pos,
- )
- )
-
-
-def get_annotations(tree: lark.Tree) -> DefaultDict[int, List[HighlighterAnnotation]]:
- """Get annotations for syntax elements in the given parse tree."""
- annotations: DefaultDict[int, List[HighlighterAnnotation]] = collections.defaultdict(
- list
- )
-
- for subtree in tree.iter_subtrees():
- if hasattr(subtree.meta, "start_pos"):
- _add_annotation_pair(
- annotations,
- name=subtree.data,
- terminal=False,
- start_pos=subtree.meta.start_pos,
- end_pos=subtree.meta.end_pos,
- )
- for child in subtree.children:
- if isinstance(child, lark.Token):
- if child.start_pos is not None and child.end_pos is not None:
- _add_annotation_pair(
- annotations,
- name=child.type,
- terminal=True,
- start_pos=child.start_pos,
- end_pos=child.end_pos,
- )
- return annotations
-
-
-def apply_annotations_to_code(
- code: str,
- annotations: Dict[int, List[HighlighterAnnotation]]
-) -> str:
- def annotate() -> Generator[str, None, None]:
- pos = 0
- for pos, ch in enumerate(code):
- for ann in reversed(annotations.get(pos, [])):
- yield str(ann)
- if ch == " ":
- yield " "
- else:
- yield ch
-
- for ann in annotations.get(pos + 1, []):
- yield str(ann)
-
- return "".join(annotate())
+ child_element, running_pos = lark_item_to_element(child, source_code, running_pos)
+ element.append(child_element)
+
+ append_text_up_to(code_span.end)
+ else:
+ assert isinstance(item, lark.Token)
+ element = html_element(
+ cls=item.type,
+ text=source_code[code_span.start:code_span.end])
+
+ return element, code_span.end
@dataclasses.dataclass
@@ -131,35 +159,44 @@ class HtmlWriter:
source_filename: Optional[pathlib.Path]
block: OutputBlock
- @property
- def source_code(self) -> str:
- """The source code associated with the block."""
- assert self.block.origin is not None
- return self.block.origin.source_code
-
def to_html(self) -> str:
- """HTML tag-annotated source code."""
- assert self.block.origin is not None
- assert self.block.origin.tree is not None
- annotations = get_annotations(self.block.origin.tree)
-
- for comment in self.block.origin.comments:
- if comment.start_pos is not None and comment.end_pos is not None:
- _add_annotation_pair(
- annotations,
- name=comment.type,
- start_pos=comment.start_pos,
- end_pos=comment.end_pos,
- terminal=True,
- )
-
- return apply_annotations_to_code(self.source_code, annotations)
+ """Format source code as HTML"""
+ origin = self.block.origin
+ assert origin is not None
+ assert origin.tree is not None
+
+ cleaned_tree = RemoveNone().transform(origin.tree)
+ tree_with_comments = insert_comments(cleaned_tree, origin.comments, origin.source_code)
+
+ root = html_element(cls="blark")
+
+ origin_section = html_element(
+ cls="blark_origin",
+ text=origin.identifier or '')
+ root.append(origin_section)
+
+ code_section = html_element(cls="blark_code")
+
+ # Handle any leading text before the tree
+ tree_span = CodeSpan.of(tree_with_comments) or CodeSpan(0, 0)
+ if leading := origin.source_code[0:tree_span.start]:
+ code_section.text = leading
+
+ # Convert code tree to lxml tree.
+ element, _ = lark_item_to_element(
+ tree_with_comments, origin.source_code, tree_span.start
+ )
+ code_section.append(element)
+ root.append(code_section)
+
+ # Serialize to HTML.
+ return etree.tostring(root, encoding='unicode', method='html')
@staticmethod
def save(
user: Any,
source_filename: Optional[pathlib.Path],
- parts: List[OutputBlock],
+ parts: list[OutputBlock],
) -> str:
"""Convert the source code block to HTML and return it."""
result = []