main.py: 271 changes (179 additions & 92 deletions)
@@ -1,12 +1,13 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from io import TextIOWrapper
import html as html_module
import json
import logging
import re
import string
import subprocess
from typing import List, Optional
from typing import Dict, List, Optional, Tuple
import unicodedata
import os
import nltk
@@ -19,18 +20,13 @@
nltk.download('averaged_perceptron_tagger_eng', quiet=True)

def is_verb_in_sentence(word, sentence):
# Tokenize the sentence into words
tokens = word_tokenize(sentence)

# Get POS tags for each word in the sentence
tagged_words = pos_tag(tokens)

# Find the POS tag of the target word in the tagged words list
word_lower = word.lower()
for tagged_word, pos in tagged_words:
if tagged_word.lower() == word.lower(): # Case insensitive match
return pos.startswith('VB') # Check if the tag starts with 'VB' (verb)

return False # Word not found in sentence
if tagged_word.lower() == word_lower:
return pos.startswith('VB')
return False
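
A quick note on how this verb gate behaves, since the with-verb reductions below rely on it. The outputs are what the NLTK perceptron tagger is expected to produce for these sentences, not guaranteed values; a minimal sketch, not part of this diff:

# Illustrative behaviour of is_verb_in_sentence (depends on the NLTK tagger):
is_verb_in_sentence("run", "I am going to run home")   # expected True:  "run" tagged VB
is_verb_in_sentence("run", "that was a long run")      # expected False: "run" tagged NN
# Only in the first case is "going to" allowed to reduce to "gɑnə" below.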


ipa_vowels = "aeiouɑɒæɛɪʊʌɔœøɐɘəɤɨɵɜɞɯɲɳɴɶʉʊʏ"
@@ -58,6 +54,14 @@ def is_verb_in_sentence(word, sentence):
double_word_with_verb = {"could have": "cʊdə", "should have": "ʃʊdə", "would have": "wʊdə", "going to": "gɑnə"}
# any noun + will could be reduced to x'll, but that is too hard for me to implement

_double_word_lookup: Dict[str, List[Tuple[str, str, bool]]] = {}
for _orig, _changed in double_word_reductions.items():
_first, _second = _orig.split(" ")
_double_word_lookup.setdefault(_first, []).append((_second, _changed, False))
for _orig, _changed in double_word_with_verb.items():
_first, _second = _orig.split(" ")
_double_word_lookup.setdefault(_first, []).append((_second, _changed, True))
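
For reference, the lookup table built above keys each two-word reduction by its first word; a minimal sketch of its shape and use, assuming only the entries visible in this diff (double_word_reductions itself is not shown in this hunk):

# Illustrative shape: first word -> [(second word, reduced IPA, needs_verb), ...]
#   _double_word_lookup["going"] == [("to", "gɑnə", True)]
#   _double_word_lookup["could"] == [("have", "cʊdə", True)]
# add_double_word_reductions below then needs one dict lookup per word instead of
# scanning every reduction pair:
for second, changed, needs_verb in _double_word_lookup.get("going", []):
    print(second, changed, needs_verb)   # -> to gɑnə True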

# Most of these are not wrong; we just prefer them this way. They will be replaced if they appear inside a word (good for plurals and such):
improved_pronounciations = {
"fæməli":"fæmli",
@@ -140,23 +144,21 @@ def add_double_word_reductions(ipa_text: str, original_text: str):
removed_words = 0
for i in range(len(original_arr)):
original_word = original_arr[i]
for orig, changed in list(double_word_reductions.items())+list(double_word_with_verb.items()):
first = orig.split(" ")[0]
if original_word == first:
second = orig.split(" ")[1]
if len(original_arr) > i + 1 and original_arr[i+1] == second:
# validate that this is not the last word; the last word in a sentence doesn't get reduced
next_char = get_next_char(original_arr, i+1, len(second)-1)
if next_char != "":
if second in ("will", "have", "has") and i + 2 < len(original_arr) and original_arr[i+2] == "not":
if original_word not in _double_word_lookup:
continue
for second, changed, needs_verb in _double_word_lookup[original_word]:
if len(original_arr) > i + 1 and original_arr[i+1] == second:
next_char = get_next_char(original_arr, i+1, len(second)-1)
if next_char != "":
if second in ("will", "have", "has") and i + 2 < len(original_arr) and original_arr[i+2] == "not":
continue
if needs_verb:
if not is_verb_in_sentence(original_arr[i+2], original_text):
continue
if orig in double_word_with_verb:
if not is_verb_in_sentence(original_arr[i+2], original_text):
continue

out_arr[i - removed_words] = changed
del out_arr[i - removed_words + 1]
removed_words += 1
out_arr[i - removed_words] = changed
del out_arr[i - removed_words + 1]
removed_words += 1
return " ".join(out_arr)

def handle_t_d(ipa_text: str):
@@ -271,18 +273,22 @@ def fix_numbers(text_arr: List[str]):



def run_flite(text: str):
fixed_text = text
# fixed_text = " ".join(fix_numbers(fix_nn(text.lower())))
_flite_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'flite', 'bin', 'flite')

def _call_flite(text: str) -> str:
try:
flite_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'flite', 'bin', 'flite')
ipa_text = subprocess.check_output([flite_path, "-t", fixed_text, "-i"]).decode('utf-8')
return subprocess.check_output([_flite_path, "-t", text, "-i"]).decode('utf-8')
except OSError:
logging.warning('lex_lookup (from flite) is not installed.')
ipa_text = ''
return ''
except subprocess.CalledProcessError:
logging.warning('Non-zero exit status from lex_lookup.')
ipa_text = ''
return ''
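
Splitting the subprocess call out of run_flite keeps the error handling in one place and lets the call be mapped over a thread pool (see _run_flite_batch further down). A minimal usage sketch, not part of this diff:

# An empty string from _call_flite means "no pronunciation available"
# (missing binary or non-zero exit; the warning was already logged above).
raw_ipa = _call_flite("hello world")
if not raw_ipa:
    print("flite unavailable, skipping line")
else:
    print(raw_ipa.strip())   # raw flite output, before the reductions applied in run_flite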

def run_flite(text: str):
fixed_text = text
# fixed_text = " ".join(fix_numbers(fix_nn(text.lower())))
ipa_text = _call_flite(fixed_text)

ipa_text = add_reductions_with_stress(ipa_text, fixed_text)
ipa_text = add_double_word_reductions(ipa_text, fixed_text)
@@ -357,9 +363,59 @@ def remove_checkpoint(checkpoint_path):
if os.path.exists(checkpoint_path):
os.remove(checkpoint_path)

FLITE_BATCH_SIZE = 32
FLITE_MAX_WORKERS = 8

def _run_flite_batch(texts: List[str]) -> List[Tuple[str, str]]:
with ThreadPoolExecutor(max_workers=FLITE_MAX_WORKERS) as executor:
ipa_results = list(executor.map(_call_flite, texts))
results = []
for fixed_text, ipa_text in zip(texts, ipa_results):
ipa_text = add_reductions_with_stress(ipa_text, fixed_text)
ipa_text = add_double_word_reductions(ipa_text, fixed_text)
ipa_text = handle_t_d(ipa_text)
ipa_text = ipa_text.replace("ˈ", "")
results.append((fixed_text, ipa_text))
return results
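
The zip(texts, ipa_results) above is safe because ThreadPoolExecutor.map returns results in input order regardless of which worker finishes first; only the subprocess calls run in threads, while the pure-Python post-processing stays on the calling thread. A self-contained sketch of that ordering guarantee:

from concurrent.futures import ThreadPoolExecutor
import time

def slow_upper(s: str) -> str:
    time.sleep(len(s) * 0.01)   # deliberately uneven delays
    return s.upper()

with ThreadPoolExecutor(max_workers=4) as ex:
    out = list(ex.map(slow_upper, ["dddd", "ccc", "bb", "a"]))

print(out)   # ['DDDD', 'CCC', 'BB', 'A'] -- same order as the inputs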

def print_ipa(out_file: Optional[TextIOWrapper], lines: List[str], fix_line_ends: bool = True, checkpoint_path: Optional[str] = None, start_line: int = 0):
global cached_text
total = len(lines)

pending_texts: List[str] = []
pending_indices: List[int] = []
newline_positions: List[Tuple[int, str]] = []

def flush_batch():
if not pending_texts:
return
batch_results = _run_flite_batch(pending_texts)
all_outputs = []
for pos_idx, marker in newline_positions:
all_outputs.append((pos_idx, marker, None))
for i, (orig, ipa) in enumerate(batch_results):
all_outputs.append((pending_indices[i], None, (orig, ipa)))
all_outputs.sort(key=lambda x: x[0])
for _, marker, result in all_outputs:
if marker is not None:
if out_file:
out_file.write(marker)
else:
print(marker, end='')
else:
orig, ipa = result
if out_file:
out_file.write(ipa)
out_file.write(orig)
else:
print((orig, ipa))
if out_file:
out_file.flush()
pending_texts.clear()
pending_indices.clear()
newline_positions.clear()

order_counter = 0
for i, line in enumerate(lines):
if i < start_line:
continue
@@ -368,33 +424,28 @@ def print_ipa(out_file: Optional[TextIOWrapper], lines: List[str], fix_line_ends
normalized_line = fix_line_ending(normalized_line)
if normalized_line is None:
continue
if out_file:
if normalized_line == "\n":
out_file.write(normalized_line)
out_file.flush()
continue
orig, ipa = run_flite(normalized_line)
out_file.write(ipa)
out_file.write(orig)
out_file.flush()
else:
print(run_flite(normalized_line))
if checkpoint_path and (i + 1) % CHECKPOINT_INTERVAL == 0:
save_checkpoint(checkpoint_path, {
"lines_processed": i + 1,
"output_bytes": out_file.tell() if out_file else 0,
"cached_text": cached_text,
"line_end_count": line_end_count,
"is_chapter": is_chapter
})
if normalized_line == "\n":
newline_positions.append((order_counter, normalized_line))
order_counter += 1
continue
pending_texts.append(normalized_line)
pending_indices.append(order_counter)
order_counter += 1
if len(pending_texts) >= FLITE_BATCH_SIZE:
flush_batch()
if checkpoint_path:
save_checkpoint(checkpoint_path, {
"lines_processed": i + 1,
"output_bytes": out_file.tell() if out_file else 0,
"cached_text": cached_text,
"line_end_count": line_end_count,
"is_chapter": is_chapter
})
if cached_text != "":
if out_file:
orig, ipa = run_flite(cached_text)
out_file.write(ipa)
out_file.write(orig)
out_file.flush()
else:
print(run_flite(cached_text))
pending_texts.append(cached_text)
pending_indices.append(order_counter)
order_counter += 1
flush_batch()
if checkpoint_path:
save_checkpoint(checkpoint_path, {
"lines_processed": total,
@@ -424,38 +475,54 @@ def _decode_text_nodes(html_str: str) -> str:
parts[i] = _decode_html_text(part)
return ''.join(parts)

def _process_single_paragraph(match: re.Match, paragraph_count: int, counter: int) -> str:
def _prepare_paragraph_texts(match: re.Match):
open_tag = match.group(1)
inner = match.group(2)
close_tag = match.group(3)

plain_text = TAG_PATTERN.sub('', inner)
decoded_text = _decode_html_text(plain_text)
stripped = decoded_text.strip()

decoded_inner = _decode_text_nodes(inner)
if not (stripped and any(c.isalpha() for c in stripped)):
return (open_tag, close_tag, decoded_inner, None, []), []
parts = re.split(r'(<[^>]*>)', inner)
flite_needed = []
for i, part in enumerate(parts):
if not part.startswith('<'):
decoded_part = _decode_html_text(part)
if decoded_part.strip() and any(c.isalpha() for c in decoded_part):
flite_needed.append((i, normalize(decoded_part), decoded_part))
return (open_tag, close_tag, decoded_inner, parts, flite_needed), [n for _, n, _ in flite_needed]

def _assemble_paragraph(prep_data, flite_results, paragraph_count, counter):
open_tag, close_tag, decoded_inner, parts, flite_needed = prep_data
if parts is None:
return open_tag + decoded_inner + close_tag
result_map = {}
for idx, (part_i, _, decoded_part) in enumerate(flite_needed):
result_map[part_i] = (decoded_part, flite_results[idx])
ipa_parts = []
for i, part in enumerate(parts):
if part.startswith('<'):
ipa_parts.append(part)
elif i in result_map:
decoded_part, ipa = result_map[i]
leading = decoded_part[:len(decoded_part) - len(decoded_part.lstrip())]
trailing = decoded_part[len(decoded_part.rstrip()):]
ipa_parts.append(leading + ipa.strip() + trailing)
else:
ipa_parts.append(_decode_html_text(part))
ipa_inner = ''.join(ipa_parts)
print(f"paragraph {counter} / {paragraph_count}")
return open_tag + ipa_inner + close_tag + '\n' + open_tag + decoded_inner + close_tag

if stripped and any(c.isalpha() for c in stripped):
parts = re.split(r'(<[^>]*>)', inner)
ipa_parts = []
for part in parts:
if part.startswith('<'):
ipa_parts.append(part)
else:
decoded_part = _decode_html_text(part)
if decoded_part.strip() and any(c.isalpha() for c in decoded_part):
normalized = normalize(decoded_part)
_, ipa = run_flite(normalized)
leading = decoded_part[:len(decoded_part) - len(decoded_part.lstrip())]
trailing = decoded_part[len(decoded_part.rstrip()):]
ipa_parts.append(leading + ipa.strip() + trailing)
else:
ipa_parts.append(decoded_part)
ipa_inner = ''.join(ipa_parts)
print(f"paragraph {counter} / {paragraph_count}")
return open_tag + ipa_inner + close_tag + '\n' + open_tag + decoded_inner + close_tag

return open_tag + decoded_inner + close_tag
def _process_single_paragraph(match: re.Match, paragraph_count: int, counter: int) -> str:
prep_data, normalized_texts = _prepare_paragraph_texts(match)
flite_results = []
for text in normalized_texts:
_, ipa = run_flite(text)
flite_results.append(ipa)
return _assemble_paragraph(prep_data, flite_results, paragraph_count, counter)
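
Splitting the old paragraph handler into _prepare_paragraph_texts and _assemble_paragraph lets process_html_file below batch the flite inputs of many paragraphs through one thread pool and hand each paragraph back exactly its own slice of results, while _process_single_paragraph keeps the unbatched path. A sketch of the contract between the two halves (match, paragraph_count and idx come from the surrounding loop):

prep_data, texts = _prepare_paragraph_texts(match)
# _assemble_paragraph must receive exactly len(texts) IPA strings, in the same order:
ipa_results = [run_flite(t)[1] for t in texts]       # unbatched fallback, as above
html_out = _assemble_paragraph(prep_data, ipa_results, paragraph_count, idx + 1)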

def process_html_file(input_path: str, output_path: Optional[str], resume: bool = False):
with open(input_path, 'r', encoding='utf-8') as f:
@@ -487,19 +554,39 @@ def replace_paragraph(match):

mode = "a" if start_paragraph > 0 else "w"
out_file = open(output_path, mode, encoding='utf-8')
prev_end = 0

for idx, match in enumerate(matches):
if idx < start_paragraph:
prev_end = matches[start_paragraph - 1].end() if start_paragraph > 0 else 0

for batch_start in range(start_paragraph, len(matches), FLITE_BATCH_SIZE):
batch_end = min(batch_start + FLITE_BATCH_SIZE, len(matches))
batch_prep = []
all_normalized = []
text_counts = []
for idx in range(batch_start, batch_end):
prep_data, normalized_texts = _prepare_paragraph_texts(matches[idx])
batch_prep.append((idx, prep_data))
all_normalized.extend(normalized_texts)
text_counts.append(len(normalized_texts))
with ThreadPoolExecutor(max_workers=FLITE_MAX_WORKERS) as executor:
all_ipa_raw = list(executor.map(_call_flite, all_normalized))
all_flite_results = []
for raw_ipa, normalized in zip(all_ipa_raw, all_normalized):
ipa = add_reductions_with_stress(raw_ipa, normalized)
ipa = add_double_word_reductions(ipa, normalized)
ipa = handle_t_d(ipa)
ipa = ipa.replace("ˈ", "")
all_flite_results.append(ipa)
result_offset = 0
for (idx, prep_data), count in zip(batch_prep, text_counts):
flite_results = all_flite_results[result_offset:result_offset + count]
result_offset += count
match = matches[idx]
out_file.write(content[prev_end:match.start()])
out_file.write(_assemble_paragraph(prep_data, flite_results, paragraph_count, idx + 1))
out_file.flush()
prev_end = match.end()
continue
out_file.write(content[prev_end:match.start()])
out_file.write(_process_single_paragraph(match, paragraph_count, idx + 1))
out_file.flush()
prev_end = match.end()
if checkpoint_path:
save_checkpoint(checkpoint_path, {
"paragraphs_processed": idx + 1,
"paragraphs_processed": batch_end,
"output_bytes": out_file.tell()
})
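
Checkpoints are now written once per batch (batch_end) rather than once per paragraph, so a resume repeats at most one batch of work. A hedged sketch of the resume arithmetic; load_checkpoint is assumed to exist in the part of main.py not shown in this diff:

state = load_checkpoint(checkpoint_path) if resume else None      # assumed helper
start_paragraph = (state or {}).get("paragraphs_processed", 0)    # last batch_end saved
# The batched loop above restarts at that boundary:
#   for batch_start in range(start_paragraph, len(matches), FLITE_BATCH_SIZE): ...
# so with FLITE_BATCH_SIZE == 32, a mid-batch failure re-does at most 32 paragraphs.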
