diff --git a/main.py b/main.py
index c4867aa..52ca3c1 100644
--- a/main.py
+++ b/main.py
@@ -9,6 +9,7 @@
import re
import string
import subprocess
+import sys
from typing import Dict, List, Optional, Tuple
import unicodedata
import os
@@ -455,6 +456,20 @@ def flush_batch():
PARAGRAPH_PATTERN = re.compile(r'(
]*>)(.*?)(
)', re.DOTALL | re.IGNORECASE)
TAG_PATTERN = re.compile(r'<[^>]*>')
+SKIP_TAGS = {'script', 'style', 'head', 'noscript', 'svg', 'nav', 'footer'}
+SKIP_TAG_PATTERN = re.compile(
+ r'<(?P' + '|'.join(SKIP_TAGS) + r')\b[^>]*>.*?(?P=tag)>',
+ re.DOTALL | re.IGNORECASE
+)
+
+def _get_skip_ranges(content: str) -> List[Tuple[int, int]]:
+ return [(m.start(), m.end()) for m in SKIP_TAG_PATTERN.finditer(content)]
+
+def _in_skip_range(pos: int, skip_ranges: List[Tuple[int, int]]) -> bool:
+ for start, end in skip_ranges:
+ if start <= pos < end:
+ return True
+ return False
def _decode_html_text(text: str) -> str:
decoded = html_module.unescape(text)
@@ -524,7 +539,8 @@ def process_html_file(input_path: str, output_path: Optional[str], resume: bool
with open(input_path, 'r', encoding='utf-8') as f:
content = f.read()
- matches = list(PARAGRAPH_PATTERN.finditer(content))
+ skip_ranges = _get_skip_ranges(content)
+ matches = [m for m in PARAGRAPH_PATTERN.finditer(content) if not _in_skip_range(m.start(), skip_ranges)]
paragraph_count = len(matches)
checkpoint_path = get_checkpoint_path(output_path) if output_path else None
@@ -540,16 +556,11 @@ def process_html_file(input_path: str, output_path: Optional[str], resume: bool
with open(output_path, "r+b") as f:
f.truncate(output_bytes)
- if not output_path:
- counter = [0]
- def replace_paragraph(match):
- counter[0] += 1
- return _process_single_paragraph(match, paragraph_count, counter[0])
- print(PARAGRAPH_PATTERN.sub(replace_paragraph, content))
- return
-
- mode = "a" if start_paragraph > 0 else "w"
- out_file = open(output_path, mode, encoding='utf-8')
+ if output_path:
+ mode = "a" if start_paragraph > 0 else "w"
+ out_file = open(output_path, mode, encoding='utf-8')
+ else:
+ out_file = sys.stdout
prev_end = matches[start_paragraph - 1].end() if start_paragraph > 0 else 0
for batch_start in range(start_paragraph, len(matches), FLITE_BATCH_SIZE):
@@ -588,7 +599,8 @@ def replace_paragraph(match):
out_file.write(content[prev_end:])
out_file.flush()
- out_file.close()
+ if output_path:
+ out_file.close()
if checkpoint_path:
remove_checkpoint(checkpoint_path)
@@ -671,3 +683,4 @@ def main():
if __name__ == "__main__":
main()
+
\ No newline at end of file