Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 131 additions & 27 deletions rich/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,84 @@
_is_single_cell_widths: Callable[[str], bool] = _SINGLE_CELLS.issuperset


def _get_codepoint_cell_size(codepoint: int) -> int:
    """Look up how many cells a single codepoint occupies.

    This is the raw table lookup used internally; it performs no grapheme
    cluster handling.

    Args:
        codepoint: Unicode codepoint (integer).

    Returns:
        Number of cells (0, 1 or 2) occupied by that codepoint.
    """
    ranges = CELL_WIDTHS
    low = 0
    high = len(ranges) - 1
    # Halving search over (start, end, width) entries; this presumes the
    # table is ordered by codepoint range.
    while True:
        middle = (low + high) // 2
        first, last, cells = ranges[middle]
        if first <= codepoint <= last:
            # A table width of -1 is reported as zero cells.
            return 0 if cells == -1 else cells
        if codepoint < first:
            high = middle - 1
        else:
            low = middle + 1
        if high < low:
            # No range contains the codepoint: fall back to a single cell.
            return 1


# Zero-Width Joiner (U+200D).
_ZWJ = 0x200D
# Variation Selector 16 (U+FE0F).
_VS16 = 0xFE0F


def _cell_len_grapheme(text: str) -> int:
    """Calculate cell length of text, handling multi-codepoint grapheme clusters.

    Handles:
        - Variation Selector 16 (U+FE0F): widens an immediately preceding
          1-cell character to 2 cells.
        - Zero-Width Joiner (U+200D): when it follows a 2-cell cluster, the
          next character is merged into that cluster and adds no width.

    A ZWJ that does not follow a 2-cell cluster (at the very start of the
    text, or after narrow text) is treated as an ordinary zero-width
    character, so the character after it keeps its own width. Using
    "2 cells wide" as the test for a joinable base is a heuristic, not a
    full grapheme cluster segmentation.

    Args:
        text: The text to measure.

    Returns:
        The cell width of the text.
    """
    total = 0
    # True when the next character must be merged into the current cluster.
    join_pending = False
    # Width of the character immediately before the current one (for VS16).
    prev_char_width = 0
    # Width of the most recent cluster that occupies cells; zero-width
    # characters such as combining marks do not reset it.
    cluster_width = 0

    for char in text:
        codepoint = ord(char)

        if codepoint == _VS16:
            # Variation Selector 16 requests emoji presentation.
            # If the previous character was 1 cell wide, it becomes 2 cells.
            if prev_char_width == 1:
                total += 1
                prev_char_width = 2
                cluster_width = 2
            continue

        if codepoint == _ZWJ:
            # Only join when there is a wide cluster to join onto; otherwise
            # the following character would wrongly lose its width.
            join_pending = cluster_width == 2
            continue

        if join_pending:
            # Character after a joining ZWJ belongs to the same cluster and
            # contributes no extra cells. cluster_width stays at 2 so that a
            # chained ZWJ sequence keeps collapsing into one cluster.
            join_pending = False
            prev_char_width = 0
            continue

        char_width = _get_codepoint_cell_size(codepoint)
        total += char_width
        prev_char_width = char_width
        if char_width:
            cluster_width = char_width

    return total


@lru_cache(4096)
def cached_cell_len(text: str) -> int:
"""Get the number of cells required to display text.
Expand All @@ -45,7 +123,7 @@ def cached_cell_len(text: str) -> int:
"""
if _is_single_cell_widths(text):
return len(text)
return sum(map(get_character_cell_size, text))
return _cell_len_grapheme(text)


def cell_len(text: str, _cell_len: Callable[[str], int] = cached_cell_len) -> int:
Expand All @@ -61,7 +139,7 @@ def cell_len(text: str, _cell_len: Callable[[str], int] = cached_cell_len) -> in
return _cell_len(text)
if _is_single_cell_widths(text):
return len(text)
return sum(map(get_character_cell_size, text))
return _cell_len_grapheme(text)


@lru_cache(maxsize=4096)
Expand All @@ -74,23 +152,7 @@ def get_character_cell_size(character: str) -> int:
Returns:
int: Number of cells (0, 1 or 2) occupied by that character.
"""
codepoint = ord(character)
_table = CELL_WIDTHS
lower_bound = 0
upper_bound = len(_table) - 1
index = (lower_bound + upper_bound) // 2
while True:
start, end, width = _table[index]
if codepoint < start:
upper_bound = index - 1
elif codepoint > end:
lower_bound = index + 1
else:
return 0 if width == -1 else width
if upper_bound < lower_bound:
break
index = (lower_bound + upper_bound) // 2
return 1
return _get_codepoint_cell_size(ord(character))


def set_cell_size(text: str, total: int) -> str:
Expand Down Expand Up @@ -142,25 +204,67 @@ def chop_cells(
A list of strings such that each string in the list has cell width
less than or equal to the available width.
"""
_get_character_cell_size = get_character_cell_size
_get_codepoint_cell_size = globals()["_get_codepoint_cell_size"]
lines: list[list[str]] = [[]]

append_new_line = lines.append
append_to_last_line = lines[-1].append

total_width = 0

for character in text:
cell_width = _get_character_cell_size(character)
char_doesnt_fit = total_width + cell_width > width
# We need to handle grapheme clusters as atomic units.
# Collect characters into clusters, then process each cluster.
i = 0
text_len = len(text)
while i < text_len:
# Collect a grapheme cluster starting at position i
cluster_start = i
char = text[i]
codepoint = ord(char)
i += 1

# Get the base character width
cluster_width = _get_codepoint_cell_size(codepoint)

# Look ahead for VS16, ZWJ sequences, skin tone modifiers, etc.
while i < text_len:
next_codepoint = ord(text[i])

if next_codepoint == _VS16:
# Variation Selector 16
if cluster_width == 1:
cluster_width = 2
i += 1
elif next_codepoint == _ZWJ:
# Zero-width joiner - consume ZWJ and the next character
i += 1 # skip ZWJ
if i < text_len:
i += 1 # skip the character after ZWJ
# Look for VS16 after the ZWJ target
while i < text_len and ord(text[i]) == _VS16:
i += 1
elif 0x1F3FB <= next_codepoint <= 0x1F3FF:
# Skin tone modifier - part of the cluster
i += 1
elif 0x20E3 == next_codepoint:
# Combining enclosing keycap
i += 1
elif 0xE0020 <= next_codepoint <= 0xE007F:
# Tags (used in flag sequences like England flag)
i += 1
else:
break

cluster = text[cluster_start:i]
char_doesnt_fit = total_width + cluster_width > width

if char_doesnt_fit:
append_new_line([character])
append_new_line([cluster])
append_to_last_line = lines[-1].append
total_width = cell_width
total_width = cluster_width
else:
append_to_last_line(character)
total_width += cell_width
append_to_last_line(cluster)
total_width += cluster_width

return ["".join(line) for line in lines]

Expand Down