diff --git a/main.py b/main.py
index 1d335f6..165fab2 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 import argparse
+from concurrent.futures import ThreadPoolExecutor
 from io import TextIOWrapper
 import html as html_module
 import json
@@ -6,7 +7,7 @@
 import re
 import string
 import subprocess
-from typing import List, Optional
+from typing import Dict, List, Optional, Tuple
 import unicodedata
 import os
 import nltk
@@ -19,18 +20,13 @@
 nltk.download('averaged_perceptron_tagger_eng', quiet=True)

 def is_verb_in_sentence(word, sentence):
-    # Tokenize the sentence into words
     tokens = word_tokenize(sentence)
-
-    # Get POS tags for each word in the sentence
     tagged_words = pos_tag(tokens)
-
-    # Find the POS tag of the target word in the tagged words list
+    word_lower = word.lower()
     for tagged_word, pos in tagged_words:
-        if tagged_word.lower() == word.lower():  # Case insensitive match
-            return pos.startswith('VB')  # Check if the tag starts with 'VB' (verb)
-
-    return False  # Word not found in sentence
+        if tagged_word.lower() == word_lower:
+            return pos.startswith('VB')
+    return False


 ipa_vowels = "aeiouɑɒæɛɪʊʌɔœøɐɘəɤɨɵɜɞɯɲɳɴɶʉʊʏ"
@@ -58,6 +54,14 @@ def is_verb_in_sentence(word, sentence):
 double_word_with_verb = {"could have": "cʊdə", "should have": "ʃʊdə", "would have": "wʊdə", "going to": "gɑnə"}
 # all noun+will can be reduced to x'll, but too hard for me to implement

+_double_word_lookup: Dict[str, List[Tuple[str, str, bool]]] = {}
+for _orig, _changed in double_word_reductions.items():
+    _first, _second = _orig.split(" ")
+    _double_word_lookup.setdefault(_first, []).append((_second, _changed, False))
+for _orig, _changed in double_word_with_verb.items():
+    _first, _second = _orig.split(" ")
+    _double_word_lookup.setdefault(_first, []).append((_second, _changed, True))
+
 # Most of those are not wrong, we just prefer it like this. These will be replced if appear in a word (good for plural and such):
 improved_pronounciations = {
     "fæməli":"fæmli",
@@ -140,23 +144,21 @@ def add_double_word_reductions(ipa_text: str, original_text: str):
     removed_words = 0
     for i in range(len(original_arr)):
         original_word = original_arr[i]
-        for orig, changed in list(double_word_reductions.items())+list(double_word_with_verb.items()):
-            first = orig.split(" ")[0]
-            if original_word == first:
-                second = orig.split(" ")[1]
-                if len(original_arr) > i + 1 and original_arr[i+1] == second:
-                    # validate that this is not the last word. last word in sentence don't get reduced
-                    next_char = get_next_char(original_arr, i+1, len(second)-1)
-                    if next_char != "":
-                        if second in ("will", "have", "has") and i + 2 < len(original_arr) and original_arr[i+2] == "not":
+        if original_word not in _double_word_lookup:
+            continue
+        for second, changed, needs_verb in _double_word_lookup[original_word]:
+            if len(original_arr) > i + 1 and original_arr[i+1] == second:
+                next_char = get_next_char(original_arr, i+1, len(second)-1)
+                if next_char != "":
+                    if second in ("will", "have", "has") and i + 2 < len(original_arr) and original_arr[i+2] == "not":
+                        continue
+                    if needs_verb:
+                        if not is_verb_in_sentence(original_arr[i+2], original_text):
                             continue
-                        if orig in double_word_with_verb:
-                            if not is_verb_in_sentence(original_arr[i+2], original_text):
-                                continue
-                        out_arr[i - removed_words] = changed
-                        del out_arr[i - removed_words + 1]
-                        removed_words += 1
+                    out_arr[i - removed_words] = changed
+                    del out_arr[i - removed_words + 1]
+                    removed_words += 1
     return " ".join(out_arr)


 def handle_t_d(ipa_text: str):
@@ -271,18 +273,22 @@ def fix_numbers(text_arr: List[str]):



-def run_flite(text: str):
-    fixed_text = text
-    # fixed_text = " ".join(fix_numbers(fix_nn(text.lower())))
+_flite_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'flite', 'bin', 'flite')
+
+def _call_flite(text: str) -> str:
     try:
-        flite_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'flite', 'bin', 'flite')
-        ipa_text = subprocess.check_output([flite_path, "-t", fixed_text, "-i"]).decode('utf-8')
+        return subprocess.check_output([_flite_path, "-t", text, "-i"]).decode('utf-8')
     except OSError:
         logging.warning('lex_lookup (from flite) is not installed.')
-        ipa_text = ''
+        return ''
     except subprocess.CalledProcessError:
         logging.warning('Non-zero exit status from lex_lookup.')
-        ipa_text = ''
+        return ''
+
+def run_flite(text: str):
+    fixed_text = text
+    # fixed_text = " ".join(fix_numbers(fix_nn(text.lower())))
+    ipa_text = _call_flite(fixed_text)

     ipa_text = add_reductions_with_stress(ipa_text, fixed_text)
     ipa_text = add_double_word_reductions(ipa_text, fixed_text)
@@ -357,9 +363,59 @@ def remove_checkpoint(checkpoint_path):
     if os.path.exists(checkpoint_path):
         os.remove(checkpoint_path)

+FLITE_BATCH_SIZE = 32
+FLITE_MAX_WORKERS = 8
+
+def _run_flite_batch(texts: List[str]) -> List[Tuple[str, str]]:
+    with ThreadPoolExecutor(max_workers=FLITE_MAX_WORKERS) as executor:
+        ipa_results = list(executor.map(_call_flite, texts))
+    results = []
+    for fixed_text, ipa_text in zip(texts, ipa_results):
+        ipa_text = add_reductions_with_stress(ipa_text, fixed_text)
+        ipa_text = add_double_word_reductions(ipa_text, fixed_text)
+        ipa_text = handle_t_d(ipa_text)
+        ipa_text = ipa_text.replace("ˈ", "")
+        results.append((fixed_text, ipa_text))
+    return results
+
 def print_ipa(out_file: Optional[TextIOWrapper], lines: List[str], fix_line_ends: bool = True, checkpoint_path: Optional[str] = None, start_line: int = 0):
     global cached_text
     total = len(lines)
+
+    pending_texts: List[str] = []
+    pending_indices: List[int] = []
+    newline_positions: List[Tuple[int, str]] = []
+
+    def flush_batch():
+        if not pending_texts:
+            return
+        batch_results = _run_flite_batch(pending_texts)
+        all_outputs = []
+        for pos_idx, marker in newline_positions:
+            all_outputs.append((pos_idx, marker, None))
+        for i, (orig, ipa) in enumerate(batch_results):
+            all_outputs.append((pending_indices[i], None, (orig, ipa)))
+        all_outputs.sort(key=lambda x: x[0])
+        for _, marker, result in all_outputs:
+            if marker is not None:
+                if out_file:
+                    out_file.write(marker)
+                else:
+                    print(marker, end='')
+            else:
+                orig, ipa = result
+                if out_file:
+                    out_file.write(ipa)
+                    out_file.write(orig)
+                else:
+                    print((orig, ipa))
+        if out_file:
+            out_file.flush()
+        pending_texts.clear()
+        pending_indices.clear()
+        newline_positions.clear()
+
+    order_counter = 0
     for i, line in enumerate(lines):
         if i < start_line:
             continue
@@ -368,33 +424,28 @@
             normalized_line = fix_line_ending(normalized_line)
         if normalized_line is None:
             continue
-        if out_file:
-            if normalized_line == "\n":
-                out_file.write(normalized_line)
-                out_file.flush()
-                continue
-            orig, ipa = run_flite(normalized_line)
-            out_file.write(ipa)
-            out_file.write(orig)
-            out_file.flush()
-        else:
-            print(run_flite(normalized_line))
-        if checkpoint_path and (i + 1) % CHECKPOINT_INTERVAL == 0:
-            save_checkpoint(checkpoint_path, {
-                "lines_processed": i + 1,
-                "output_bytes": out_file.tell() if out_file else 0,
-                "cached_text": cached_text,
-                "line_end_count": line_end_count,
-                "is_chapter": is_chapter
-            })
+        if normalized_line == "\n":
+            newline_positions.append((order_counter, normalized_line))
+            order_counter += 1
+            continue
+        pending_texts.append(normalized_line)
+        pending_indices.append(order_counter)
+        order_counter += 1
+        if len(pending_texts) >= FLITE_BATCH_SIZE:
+            flush_batch()
+            if checkpoint_path:
+                save_checkpoint(checkpoint_path, {
+                    "lines_processed": i + 1,
+                    "output_bytes": out_file.tell() if out_file else 0,
+                    "cached_text": cached_text,
+                    "line_end_count": line_end_count,
+                    "is_chapter": is_chapter
+                })
     if cached_text != "":
-        if out_file:
-            orig, ipa = run_flite(cached_text)
-            out_file.write(ipa)
-            out_file.write(orig)
-            out_file.flush()
-        else:
-            print(run_flite(cached_text))
+        pending_texts.append(cached_text)
+        pending_indices.append(order_counter)
+        order_counter += 1
+        flush_batch()
     if checkpoint_path:
         save_checkpoint(checkpoint_path, {
             "lines_processed": total,
@@ -424,38 +475,54 @@ def _decode_text_nodes(html_str: str) -> str:
         parts[i] = _decode_html_text(part)
     return ''.join(parts)

-def _process_single_paragraph(match: re.Match, paragraph_count: int, counter: int) -> str:
+def _prepare_paragraph_texts(match: re.Match):
     open_tag = match.group(1)
     inner = match.group(2)
     close_tag = match.group(3)
-
     plain_text = TAG_PATTERN.sub('', inner)
     decoded_text = _decode_html_text(plain_text)
     stripped = decoded_text.strip()
-
     decoded_inner = _decode_text_nodes(inner)

+    if not (stripped and any(c.isalpha() for c in stripped)):
+        return (open_tag, close_tag, decoded_inner, None, []), []
+    parts = re.split(r'(<[^>]*>)', inner)
+    flite_needed = []
+    for i, part in enumerate(parts):
+        if not part.startswith('<'):
+            decoded_part = _decode_html_text(part)
+            if decoded_part.strip() and any(c.isalpha() for c in decoded_part):
+                flite_needed.append((i, normalize(decoded_part), decoded_part))
+    return (open_tag, close_tag, decoded_inner, parts, flite_needed), [n for _, n, _ in flite_needed]
+
+def _assemble_paragraph(prep_data, flite_results, paragraph_count, counter):
+    open_tag, close_tag, decoded_inner, parts, flite_needed = prep_data
+    if parts is None:
+        return open_tag + decoded_inner + close_tag
+    result_map = {}
+    for idx, (part_i, _, decoded_part) in enumerate(flite_needed):
+        result_map[part_i] = (decoded_part, flite_results[idx])
+    ipa_parts = []
+    for i, part in enumerate(parts):
+        if part.startswith('<'):
+            ipa_parts.append(part)
+        elif i in result_map:
+            decoded_part, ipa = result_map[i]
+            leading = decoded_part[:len(decoded_part) - len(decoded_part.lstrip())]
+            trailing = decoded_part[len(decoded_part.rstrip()):]
+            ipa_parts.append(leading + ipa.strip() + trailing)
+        else:
+            ipa_parts.append(_decode_html_text(part))
+    ipa_inner = ''.join(ipa_parts)
+    print(f"paragraph {counter} / {paragraph_count}")
+    return open_tag + ipa_inner + close_tag + '\n' + open_tag + decoded_inner + close_tag
-    if stripped and any(c.isalpha() for c in stripped):
-        parts = re.split(r'(<[^>]*>)', inner)
-        ipa_parts = []
-        for part in parts:
-            if part.startswith('<'):
-                ipa_parts.append(part)
-            else:
-                decoded_part = _decode_html_text(part)
-                if decoded_part.strip() and any(c.isalpha() for c in decoded_part):
-                    normalized = normalize(decoded_part)
-                    _, ipa = run_flite(normalized)
-                    leading = decoded_part[:len(decoded_part) - len(decoded_part.lstrip())]
-                    trailing = decoded_part[len(decoded_part.rstrip()):]
-                    ipa_parts.append(leading + ipa.strip() + trailing)
-                else:
-                    ipa_parts.append(decoded_part)
-        ipa_inner = ''.join(ipa_parts)
-        print(f"paragraph {counter} / {paragraph_count}")
-        return open_tag + ipa_inner + close_tag + '\n' + open_tag + decoded_inner + close_tag
-
-    return open_tag + decoded_inner + close_tag
+def _process_single_paragraph(match: re.Match, paragraph_count: int, counter: int) -> str:
+    prep_data, normalized_texts = _prepare_paragraph_texts(match)
+    flite_results = []
+    for text in normalized_texts:
+        _, ipa = run_flite(text)
+        flite_results.append(ipa)
+    return _assemble_paragraph(prep_data, flite_results, paragraph_count, counter)

 def process_html_file(input_path: str, output_path: Optional[str], resume: bool = False):
     with open(input_path, 'r', encoding='utf-8') as f:
@@ -487,19 +554,39 @@ def replace_paragraph(match):

     mode = "a" if start_paragraph > 0 else "w"
     out_file = open(output_path, mode, encoding='utf-8')
-    prev_end = 0
-
-    for idx, match in enumerate(matches):
-        if idx < start_paragraph:
+    prev_end = matches[start_paragraph - 1].end() if start_paragraph > 0 else 0
+
+    for batch_start in range(start_paragraph, len(matches), FLITE_BATCH_SIZE):
+        batch_end = min(batch_start + FLITE_BATCH_SIZE, len(matches))
+        batch_prep = []
+        all_normalized = []
+        text_counts = []
+        for idx in range(batch_start, batch_end):
+            prep_data, normalized_texts = _prepare_paragraph_texts(matches[idx])
+            batch_prep.append((idx, prep_data))
+            all_normalized.extend(normalized_texts)
+            text_counts.append(len(normalized_texts))
+        with ThreadPoolExecutor(max_workers=FLITE_MAX_WORKERS) as executor:
+            all_ipa_raw = list(executor.map(_call_flite, all_normalized))
+        all_flite_results = []
+        for raw_ipa, normalized in zip(all_ipa_raw, all_normalized):
+            ipa = add_reductions_with_stress(raw_ipa, normalized)
+            ipa = add_double_word_reductions(ipa, normalized)
+            ipa = handle_t_d(ipa)
+            ipa = ipa.replace("ˈ", "")
+            all_flite_results.append(ipa)
+        result_offset = 0
+        for (idx, prep_data), count in zip(batch_prep, text_counts):
+            flite_results = all_flite_results[result_offset:result_offset + count]
+            result_offset += count
+            match = matches[idx]
+            out_file.write(content[prev_end:match.start()])
+            out_file.write(_assemble_paragraph(prep_data, flite_results, paragraph_count, idx + 1))
+            out_file.flush()
+            prev_end = match.end()
-            continue
-        out_file.write(content[prev_end:match.start()])
-        out_file.write(_process_single_paragraph(match, paragraph_count, idx + 1))
-        out_file.flush()
-        prev_end = match.end()
         if checkpoint_path:
             save_checkpoint(checkpoint_path, {
-                "paragraphs_processed": idx + 1,
+                "paragraphs_processed": batch_end,
                 "output_bytes": out_file.tell()
             })