diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..2d2ecd68
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+.git/
diff --git a/Dockerfile b/Dockerfile
index 67102bdc..acf429c0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -32,9 +32,6 @@ RUN git clone https://github.com/Huanshere/VideoLingo.git .
# Install PyTorch and torchaudio
RUN pip install torch==2.0.0 torchaudio==2.0.0 --index-url https://download.pytorch.org/whl/cu118
-# Clean up unnecessary files
-RUN rm -rf .git
-
# Upgrade pip and install basic dependencies
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
pip install --no-cache-dir --upgrade pip requests rich ruamel.yaml
diff --git a/OneKeyInstall&Start.bat b/OneKeyInstall&Start.bat
index d1a4e77f..dcf0336c 100644
--- a/OneKeyInstall&Start.bat
+++ b/OneKeyInstall&Start.bat
@@ -2,6 +2,26 @@
cd /D "%~dp0"
+set INSTALL_ENV_DIR=%cd%\installer_files\env
+set CONDA_ROOT_PREFIX=%cd%\installer_files\conda
+
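+@rem Environment isolation and CUDA paths, set early so the quick-start branch below can use them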
+set PYTHONNOUSERSITE=1
+set PYTHONPATH=
+set PYTHONHOME=
+set "CUDA_PATH=%INSTALL_ENV_DIR%"
+set "CUDA_HOME=%CUDA_PATH%"
+
+@rem Check if conda environment exists
+if exist "%INSTALL_ENV_DIR%\python.exe" (
+ echo Conda environment found, starting directly...
+ echo If startup fails, please delete the 'installer_files' folder and reinstall.
+ @rem Activate environment
+ call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniconda hook not found. && goto end )
+ python -m streamlit run st.py
+ goto end
+)
+
+@rem Original installation path continues...
set PATH=%PATH%;%SystemRoot%\system32
echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniconda which can not be silently installed under a path with spaces. && goto end
@@ -40,28 +60,17 @@ if "%conda_exists%" == "F" (
@rem create the installer env
if not exist "%INSTALL_ENV_DIR%" (
echo Packages to install: python=3.10.0 requests rich ruamel.yaml
- call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10.0 requests rich "ruamel.yaml" || ( echo. && echo Conda environment creation failed. && goto end )
+ call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.10.0 requests rich ruamel.yaml || ( echo. && echo Conda environment creation failed. && goto end )
)
@rem check if conda environment was actually created
if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo Conda environment is empty. && goto end )
-@rem environment isolation
-set PYTHONNOUSERSITE=1
-set PYTHONPATH=
-set PYTHONHOME=
-@rem ! may cause error if we use cudnn on windows
-set "CUDA_PATH=%INSTALL_ENV_DIR%"
-set "CUDA_HOME=%CUDA_PATH%"
-
-@rem activate installer env
-call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniconda hook not found. && goto end )
-
-@rem Run pip setup
+:start
+@rem activate installer env before running pip setup
+call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniconda hook not found. && goto end )
call python pip_setup.py
echo.
-echo Done!
+echo ✅ Done!
:end
pause
diff --git a/OneKeyStart.bat b/OneKeyStart.bat
deleted file mode 100644
index 26ea90d4..00000000
--- a/OneKeyStart.bat
+++ /dev/null
@@ -1,30 +0,0 @@
-@echo off
-cd /D "%~dp0"
-
-set INSTALL_ENV_DIR=%cd%\installer_files\env
-set CONDA_ROOT_PREFIX=%cd%\installer_files\conda
-
-@rem Check if conda environment exists
-if not exist "%INSTALL_ENV_DIR%\python.exe" (
- echo Conda environment not found!
- echo Please run OneKeyInstall^&Start.bat first to set up the environment.
- goto end
-)
-
-@rem Environment isolation
-set PYTHONNOUSERSITE=1
-set PYTHONPATH=
-set PYTHONHOME=
-set "CUDA_PATH=%INSTALL_ENV_DIR%"
-set "CUDA_HOME=%CUDA_PATH%"
-
-@rem Activate conda environment and run streamlit
-call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" && (
- python -m streamlit run st.py
-) || (
- echo Failed to activate conda environment!
- echo Please run OneKeyInstall^&Start.bat to reinstall the environment.
-)
-
-:end
-pause
\ No newline at end of file
diff --git a/README.md b/README.md
index 6572bc94..24231920 100644
--- a/README.md
+++ b/README.md
@@ -68,9 +68,14 @@ https://github.com/user-attachments/assets/47d965b2-b4ab-4a0b-9d08-b49a7bf3508c
## Installation
-Windows users can double-click `OneKeyInstall&Start.bat` to install (requires Git). The script will download Miniconda and install the complete environment. For NVIDIA GPU users, you need to first install [CUDA 12.6](https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe) and [CUDNN 9.3.0](https://developer.download.nvidia.com/compute/cudnn/9.3.0/local_installers/cudnn_9.3.0_windows.exe), then add `C:\Program Files\NVIDIA\CUDNN\v9.3\bin\12.6` to system environment variables and restart.
+### Windows
+Simply double-click `OneKeyInstall&Start.bat` to get started. The script will:
+- Download and install Miniconda automatically
+- Install all required dependencies (GPU or CPU, detected automatically)
-MacOS/Linux users should install from source, requiring `python=3.10.0` environment.
+Prerequisites: Git must be installed on your system.
+
+### macOS/Linux
1. Clone the repository
@@ -79,10 +84,10 @@ git clone https://github.com/Huanshere/VideoLingo.git
cd VideoLingo
```
-2. Install dependencies
+2. Install dependencies (requires `python=3.10.0`)
```bash
-conda create -n videolingo python=3.10.0
+conda create -n videolingo python=3.10.0 -y
conda activate videolingo
python install.py
```
@@ -93,6 +98,7 @@ python install.py
streamlit run st.py
```
+### Docker
Alternatively, you can use Docker (requires CUDA 12.4 and NVIDIA Driver version >550), see [Docker docs](/docs/pages/docs/docker.en-US.md):
```bash
diff --git a/core/step2_whisperX.py b/core/step2_whisperX.py
index 064b7750..5ce68212 100644
--- a/core/step2_whisperX.py
+++ b/core/step2_whisperX.py
@@ -50,13 +50,24 @@ def check_hf_mirror() -> str:
rprint("[yellow]⚠️ All mirrors failed, using default[/yellow]")
rprint(f"[cyan]🚀 Selected mirror:[/cyan] {fastest_url} ({best_time:.2f}s)")
return fastest_url
-
def transcribe_audio(audio_file: str, start: float, end: float) -> Dict:
+ """
+    Transcribe an audio segment with the WhisperX model.
+
+    Args:
+        audio_file (str): Path to the audio file to transcribe.
+        start (float): Start time of the segment in seconds.
+        end (float): End time of the segment in seconds.
+
+    Returns:
+        Dict: Transcription result, including text and timestamps.
+ """
os.environ['HF_ENDPOINT'] = check_hf_mirror() #? don't know if it's working...
WHISPER_LANGUAGE = load_key("whisper.language")
device = "cuda" if torch.cuda.is_available() else "cpu"
rprint(f"🚀 Starting WhisperX using device: {device} ...")
-
+
+    # Choose batch size and compute type based on GPU memory
if device == "cuda":
gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
batch_size = 16 if gpu_mem > 8 else 2
@@ -67,52 +78,59 @@ def transcribe_audio(audio_file: str, start: float, end: float) -> Dict:
compute_type = "int8"
rprint(f"[cyan]📦 Batch size:[/cyan] {batch_size}, [cyan]⚙️ Compute type:[/cyan] {compute_type}")
rprint(f"[green]▶️ Starting WhisperX for segment {start:.2f}s to {end:.2f}s...[/green]")
-
+
try:
+        # Pick the Whisper model based on the transcription language
if WHISPER_LANGUAGE == 'zh':
model_name = "Huan69/Belle-whisper-large-v3-zh-punct-fasterwhisper"
local_model = os.path.join(MODEL_DIR, "Belle-whisper-large-v3-zh-punct-fasterwhisper")
else:
model_name = load_key("whisper.model")
local_model = os.path.join(MODEL_DIR, model_name)
-
+
+        # Load the Whisper model from a local path if available, otherwise from HuggingFace
if os.path.exists(local_model):
rprint(f"[green]📥 Loading local WHISPER model:[/green] {local_model} ...")
model_name = local_model
else:
rprint(f"[green]📥 Using WHISPER model from HuggingFace:[/green] {model_name} ...")
+        # Configure VAD and ASR options
vad_options = {"vad_onset": 0.500,"vad_offset": 0.363}
asr_options = {"temperatures": [0],"initial_prompt": "",}
whisper_language = None if 'auto' in WHISPER_LANGUAGE else WHISPER_LANGUAGE
rprint("[bold yellow]**You can ignore warning of `Model was trained with torch 1.10.0+cu102, yours is 2.0.0+cu118...`**[/bold yellow]")
model = whisperx.load_model(model_name, device, compute_type=compute_type, language=whisper_language, vad_options=vad_options, asr_options=asr_options, download_root=MODEL_DIR)
- # Create temporary file to store audio segment
- temp_audio = tempfile.NamedTemporaryFile(suffix='.mp3', delete=False)
- temp_audio_path = temp_audio.name
- temp_audio.close()
- # Use ffmpeg to cut audio
- ffmpeg_cmd = f'ffmpeg -y -i "{audio_file}" -ss {start} -t {end-start} -vn -b:a 64k -ar 16000 -ac 1 -metadata encoding=UTF-8 -f mp3 "{temp_audio_path}"'
+        # Create a temporary WAV file for better compatibility
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_audio:
+ temp_audio_path = temp_audio.name
+
+        # Extract the audio segment with ffmpeg
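+        # 16 kHz mono matches the sample rate Whisper models expect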
+ ffmpeg_cmd = f'ffmpeg -y -i "{audio_file}" -ss {start} -t {end-start} -vn -ar 16000 -ac 1 "{temp_audio_path}"'
subprocess.run(ffmpeg_cmd, shell=True, check=True, capture_output=True)
- # Load the cut audio
- audio_segment, sample_rate = librosa.load(temp_audio_path, sr=16000)
- # Delete temporary file
- os.unlink(temp_audio_path)
+
+ try:
+            # Load the extracted segment with librosa
+ audio_segment, sample_rate = librosa.load(temp_audio_path, sr=16000)
+ finally:
+            # Clean up the temporary file
+ if os.path.exists(temp_audio_path):
+ os.unlink(temp_audio_path)
rprint("[bold green]note: You will see Progress if working correctly[/bold green]")
result = model.transcribe(audio_segment, batch_size=batch_size, print_progress=True)
- # Free GPU resources
+        # Free GPU resources
del model
torch.cuda.empty_cache()
- # Save language
+        # Save the detected language and check it matches the requested one
save_language(result['language'])
if result['language'] == 'zh' and WHISPER_LANGUAGE != 'zh':
- raise ValueError("请指定转录语言为 zh 后重试!")
+ raise ValueError("Please specify the transcription language as zh and try again!")
- # Align whisper output
+        # Align the Whisper output
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio_segment, device, return_char_alignments=False)
diff --git a/docs/pages/docs/start.en-US.md b/docs/pages/docs/start.en-US.md
index 24c68765..1687ff45 100644
--- a/docs/pages/docs/start.en-US.md
+++ b/docs/pages/docs/start.en-US.md
@@ -114,13 +114,6 @@ After configuration, select `Reference Audio Mode` in the sidebar (see Yuque doc
VideoLingo supports Windows, macOS and Linux systems, and can run on CPU or GPU.
-For GPU acceleration on Windows, install these dependencies:
-
-- [CUDA Toolkit 12.6](https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe)
-- [CUDNN 9.3.0](https://developer.download.nvidia.com/compute/cudnn/9.3.0/local_installers/cudnn_9.3.0_windows.exe)
-
-> Note: After installing, add `C:\Program Files\NVIDIA\CUDNN\v9.3\bin\12.6` to system path and restart computer 🔄
-
### Windows One-Click Install
Make sure [Git](https://git-scm.com/downloads) is installed,
@@ -131,48 +124,37 @@ Make sure [Git](https://git-scm.com/downloads) is installed,
### Source Installation
-Before installing VideoLingo, ensure:
-1. **25GB** free disk space
-2. [Anaconda](https://www.anaconda.com/download) installed (for Python environment management)
-3. [Git](https://git-scm.com/downloads) installed (for cloning project code, or download manually)
+Before installing VideoLingo, ensure you have **25GB** of free disk space and have Git and Anaconda installed.
-Basic Python knowledge required. For any issues, ask the AI assistant at [videolingo.io](https://videolingo.io) bottom right~
-1. Open Anaconda Prompt and navigate to installation directory, e.g. desktop:
- ```bash
- cd desktop
- ```
-
-2. Clone project and enter directory:
+1. Clone the project:
```bash
git clone https://github.com/Huanshere/VideoLingo.git
cd VideoLingo
```
-3. Create and activate virtual environment (**must be 3.10.0**):
+2. Create and activate virtual environment (**must be python=3.10.0**):
```bash
conda create -n videolingo python=3.10.0 -y
conda activate videolingo
```
-4. Run installation script:
+3. Run installation script:
```bash
python install.py
```
Script will automatically install appropriate torch version
-5. 🎉 Enter command to launch Streamlit app:
+4. 🎉 Launch Streamlit app:
```bash
streamlit run st.py
```
-6. Set key in sidebar of popup webpage and start using~
+5. Set key in sidebar of popup webpage and start using~

-7. Transcription step will automatically download models from huggingface, or you can download manually and place `_model_cache` folder in VideoLingo directory: [Baidu Drive](https://pan.baidu.com/s/1Igo_FvFV4Xcb8tSYT0ktpA?pwd=e1c7)
-
-8. (Optional) More settings can be manually modified in `config.yaml`, watch command line output during operation
+6. (Optional) More settings can be manually modified in `config.yaml`, watch command line output during operation
## 🏭 Batch Mode (beta)
diff --git a/docs/pages/docs/start.zh-CN.md b/docs/pages/docs/start.zh-CN.md
index 04e76a77..422c46d3 100644
--- a/docs/pages/docs/start.zh-CN.md
+++ b/docs/pages/docs/start.zh-CN.md
@@ -115,13 +115,6 @@ VideoLingo提供了多种 tts 接入方式,以下是对比(如不使用配
VideoLingo 支持 Windows、macOS 和 Linux 系统,可使用 CPU 或 GPU 运行。
-对于 Windows 系统使用 GPU 加速,需要安装以下依赖:
-
-- [CUDA Toolkit 12.6](https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe)
-- [CUDNN 9.3.0](https://developer.download.nvidia.com/compute/cudnn/9.3.0/local_installers/cudnn_9.3.0_windows.exe)
-
-> 注意:安装后需要将 `C:\Program Files\NVIDIA\CUDNN\v9.3\bin\12.6` 添加至系统环境变量,并重启计算机 🔄
-
### Windows 一键安装
请确保已安装 [Git](https://git-scm.com/downloads),
@@ -135,56 +128,44 @@ VideoLingo 支持 Windows、macOS 和 Linux 系统,可使用 CPU 或 GPU 运
3. 双击 `OneKeyInstall&Start.bat` 即可完成安装并启动网页
-### 源码安装
-
-开始安装 VideoLingo 之前,请确保:
-1. 预留 **25G** 硬盘空间
-2. 已安装 [Anaconda](https://www.anaconda.com/download) (用于 Python 环境管理)
-3. 已安装 [Git](https://git-scm.com/downloads) (用于克隆项目代码,也可以手动下载)
+### macOS/Linux 源码安装
-需要一定的 python 基础,遇到任何问题可以询问官方网站 [videolingo.io](https://videolingo.io) 右下角的AI助手~
+开始安装 VideoLingo 之前,请确保预留 **25G** 硬盘空间,并安装了 Git 和 Anaconda。
-1. 打开 `Anaconda Prompt` 并切换到你想安装的目录,例如桌面:
- ```bash
- cd desktop
- ```
-
-2. 克隆项目并切换至项目目录:
+1. 克隆项目:
```bash
git clone https://github.com/Huanshere/VideoLingo.git
cd VideoLingo
```
-3. 创建并激活虚拟环境(**必须 3.10.0**):
+2. 创建并激活虚拟环境(**必须 3.10.0**):
```bash
conda create -n videolingo python=3.10.0 -y
conda activate videolingo
```
-4. (可选)应用汉化补丁:
+3. (可选)应用汉化补丁:
参照 **一键安装** 中的说明
(注意:Mac系统会删除整个目标文件夹后再复制,而Windows只会替换重复的文件。Mac用户建议手动将文件逐个移动到目标位置)
-5. 运行安装脚本:
+4. 运行安装脚本:
```bash
python install.py
```
脚本将自动安装相应的 torch 版本
-6. 🎉 输入命令或点击 `一键启动.bat` 启动 Streamlit 应用:
+5. 🎉 输入命令或点击 `一键启动.bat` 启动 Streamlit 应用:
```bash
streamlit run st.py
```
-7. 在弹出网页的侧边栏中设置key,开始使用~
+6. 在弹出网页的侧边栏中设置key,开始使用~

-8. 转录步骤会自动从 huggingface 下载模型,也可以手动下载,将 `_model_cache` 文件夹放置在 VideoLingo 目录下:[百度网盘](https://pan.baidu.com/s/1Igo_FvFV4Xcb8tSYT0ktpA?pwd=e1c7)
-
-9. (可选)更多设置可以在 `config.yaml` 中手动修改,运行过程请注意命令行输出
+7. (可选)更多设置可以在 `config.yaml` 中手动修改,运行过程请注意命令行输出
## 🏭 批量模式(beta)
diff --git a/i18n/README.zh.md b/i18n/README.zh.md
index 486f915c..d479b69d 100644
--- a/i18n/README.zh.md
+++ b/i18n/README.zh.md
@@ -70,32 +70,38 @@ https://github.com/user-attachments/assets/47d965b2-b4ab-4a0b-9d08-b49a7bf3508c
## 安装
-Windows 用户可以双击运行 `OneKeyInstall&Start.bat` 一键安装(确保 Git 已安装),该脚本会下载 Miniconda 并安装完整环境。对于使用 NVIDIA GPU 的用户,还需要先安装 [CUDA 12.6](https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda_12.6.0_560.76_windows.exe) 和 [CUDNN 9.3.0](https://developer.download.nvidia.com/compute/cudnn/9.3.0/local_installers/cudnn_9.3.0_windows.exe),并在安装完成后添加 `C:\Program Files\NVIDIA\CUDNN\v9.3\bin\12.6` 到系统环境变量并重启。
+### Windows
+直接双击运行 `OneKeyInstall&Start.bat` 即可开始安装。该脚本会:
+- 自动下载并安装 Miniconda
+- 安装所有 GPU 或 CPU 所需的依赖
-MacOS/Linux 用户请从源码安装,需要 `python=3.10.0` 环境。
+前置要求:系统需已安装 Git。
-1. 下载项目代码
+### macOS/Linux
+
+1. 克隆仓库
```bash
git clone https://github.com/Huanshere/VideoLingo.git
cd VideoLingo
```
-2. 安装依赖
+2. 安装依赖(需要 `python=3.10.0`)
```bash
-conda create -n videolingo python=3.10.0
+conda create -n videolingo python=3.10.0 -y
conda activate videolingo
python install.py
```
-3. 启动
+3. 启动应用
```bash
streamlit run st.py
```
-还可以选择使用 Docker,要求CUDA版本为12.4,NVIDIA Driver版本大于550,详见[Docker文档](/docs/pages/docs/docker.zh-CN.md):
+### Docker
+还可以选择使用 Docker(要求 CUDA 12.4 和 NVIDIA Driver 版本 >550),详见[Docker文档](/docs/pages/docs/docker.zh-CN.md):
```bash
docker build -t videolingo .
diff --git "a/i18n/\344\270\255\346\226\207/install.py" "b/i18n/\344\270\255\346\226\207/install.py"
index 6e05ad73..f89bec73 100644
--- "a/i18n/\344\270\255\346\226\207/install.py"
+++ "b/i18n/\344\270\255\346\226\207/install.py"
@@ -2,8 +2,6 @@
import platform
import subprocess
import sys
-import zipfile
-import shutil
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
ascii_logo = """
@@ -49,123 +47,60 @@ def main():
choose_mirror()
# 检测系统和GPU
- if platform.system() == 'Darwin':
- console.print(Panel("🍎 检测到 MacOS,正在安装 CPU 版本的 PyTorch... 但转写速度会慢很多", style="cyan"))
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.1.2", "torchaudio==2.1.2"])
+ has_gpu = platform.system() != 'Darwin' and check_gpu()
+ if has_gpu:
+ console.print(Panel("🎮 检测到 NVIDIA GPU,正在安装 CUDA 版本的 PyTorch...", style="cyan"))
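+        # 从 pytorch / nvidia conda 频道安装 CUDA 11.8 版本的 PyTorch 2.0.0(此前通过 pip 安装)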
+ subprocess.check_call(["conda", "install", "-y", "pytorch==2.0.0", "torchaudio==2.0.0", "pytorch-cuda=11.8", "-c", "pytorch", "-c", "nvidia"])
else:
- has_gpu = check_gpu()
- if has_gpu:
- console.print(Panel("🎮 检测到 NVIDIA GPU,正在安装 CUDA 版本的 PyTorch...", style="cyan"))
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.0.0", "torchaudio==2.0.0", "--index-url", "https://download.pytorch.org/whl/cu118"])
- else:
- console.print(Panel("💻 未检测到 NVIDIA GPU,正在安装 CPU 版本的 PyTorch... 但转写速度会慢很多", style="cyan"))
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.1.2", "torchaudio==2.1.2"])
-
- # 安装 WhisperX
- console.print(Panel("📦 正在安装 WhisperX...", style="cyan"))
- current_dir = os.getcwd()
- whisperx_dir = os.path.join(current_dir, "third_party", "whisperX")
- os.chdir(whisperx_dir)
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
- os.chdir(current_dir)
+ system_name = "🍎 MacOS" if platform.system() == 'Darwin' else "💻 未检测到 NVIDIA GPU"
+ console.print(Panel(f"{system_name},正在安装 CPU 版本的 PyTorch... 但转写速度会慢很多", style="cyan"))
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.1.2", "torchaudio==2.1.2"])
def install_requirements():
try:
- with open("requirements.txt", "r", encoding="utf-8") as file:
- content = file.read()
- with open("requirements.txt", "w", encoding="gbk") as file:
- file.write(content)
- except Exception as e:
- print(f"转换 requirements.txt 时出错: {str(e)}")
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
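+        # 通过环境变量保持 pip 缓存启用并强制 UTF-8 输出,避免中文控制台下的编码错误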
+ subprocess.check_call([
+ sys.executable,
+ "-m",
+ "pip",
+ "install",
+ "-r",
+ "requirements.txt"
+ ], env={**os.environ, "PIP_NO_CACHE_DIR": "0", "PYTHONIOENCODING": "utf-8"})
+ except subprocess.CalledProcessError as e:
+ console.print(Panel(f"❌ 安装依赖失败: {str(e)}", style="red"))
- def download_and_extract_ffmpeg():
- # VL requires both conda/system ffmpeg and ffmpeg.exe...
- system = platform.system()
- if system == "Linux":
- # Linux: use apt or yum to install ffmpeg
- try:
- console.print(Panel("📦 正在通过 apt 安装 ffmpeg...", style="cyan"))
- subprocess.check_call(["sudo", "apt", "install", "-y", "ffmpeg"])
- except subprocess.CalledProcessError:
- try:
- console.print(Panel("📦 正在通过 yum 安装 ffmpeg...", style="cyan"))
- subprocess.check_call(["sudo", "yum", "install", "-y", "ffmpeg"], shell=True)
- except subprocess.CalledProcessError:
- console.print(Panel("❌ 通过包管理器安装 ffmpeg 失败", style="red"))
- else:
- # Windows/MacOS: use conda to install ffmpeg
- console.print(Panel("📦 正在通过 conda 安装 ffmpeg...", style="cyan"))
+ def install_ffmpeg():
+ console.print(Panel("📦 正在通过 conda 安装 ffmpeg...", style="cyan"))
+ try:
subprocess.check_call(["conda", "install", "-y", "ffmpeg"], shell=True)
+ console.print(Panel("✅ FFmpeg 安装完成", style="green"))
+ except subprocess.CalledProcessError:
+ console.print(Panel("❌ 通过 conda 安装 FFmpeg 失败", style="red"))
- import requests
- system = platform.system()
- if system == "Windows":
- ffmpeg_exe = "ffmpeg.exe"
- url = "https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-win64-gpl.zip"
- elif system == "Darwin":
- ffmpeg_exe = "ffmpeg"
- url = "https://evermeet.cx/ffmpeg/getrelease/zip"
- elif system == "Linux":
- ffmpeg_exe = "ffmpeg"
- url = "https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz"
+ def install_noto_font():
+ # 检测 Linux 发行版类型
+ if os.path.exists('/etc/debian_version'):
+ # Debian/Ubuntu 系统
+ cmd = ['sudo', 'apt-get', 'install', '-y', 'fonts-noto']
+ pkg_manager = "apt-get"
+ elif os.path.exists('/etc/redhat-release'):
+ # RHEL/CentOS/Fedora 系统
+ cmd = ['sudo', 'yum', 'install', '-y', 'google-noto*']
+ pkg_manager = "yum"
else:
+ console.print("⚠️ 无法识别的 Linux 发行版,请手动安装 Noto 字体", style="yellow")
return
-
- if os.path.exists(ffmpeg_exe):
- print(f"{ffmpeg_exe} 已存在")
- return
-
- console.print(Panel("📦 正在下载 FFmpeg...", style="cyan"))
- response = requests.get(url)
- if response.status_code == 200:
- filename = "ffmpeg.zip" if system in ["Windows", "Darwin"] else "ffmpeg.tar.xz"
- with open(filename, 'wb') as f:
- f.write(response.content)
- console.print(Panel(f"FFmpeg 下载完成: {filename}", style="cyan"))
-
- console.print(Panel("📦 正在解压 FFmpeg...", style="cyan"))
- if system == "Linux":
- import tarfile
- with tarfile.open(filename) as tar_ref:
- for member in tar_ref.getmembers():
- if member.name.endswith("ffmpeg"):
- member.name = os.path.basename(member.name)
- tar_ref.extract(member)
- else:
- with zipfile.ZipFile(filename, 'r') as zip_ref:
- for file in zip_ref.namelist():
- if file.endswith(ffmpeg_exe):
- zip_ref.extract(file)
- shutil.move(os.path.join(*file.split('/')[:-1], os.path.basename(file)), os.path.basename(file))
- console.print(Panel("📦 正在清理...", style="cyan"))
- os.remove(filename)
- if system == "Windows":
- for item in os.listdir():
- if os.path.isdir(item) and "ffmpeg" in item.lower():
- shutil.rmtree(item)
- console.print(Panel("FFmpeg 解压完成", style="cyan"))
- else:
- console.print(Panel("❌ FFmpeg 下载失败", style="red"))
-
- def install_noto_font():
- if platform.system() == 'Linux':
- try:
- # 首先尝试 apt-get (基于 Debian 的系统)
- subprocess.run(['sudo', 'apt-get', 'install', '-y', 'fonts-noto'], check=True)
- print("使用 apt-get 成功安装了 Noto 字体。")
- except subprocess.CalledProcessError:
- try:
- # 如果 apt-get 失败,尝试 yum (基于 RPM 的系统)
- subprocess.run(['sudo', 'yum', 'install', '-y', 'fonts-noto'], check=True)
- print("使用 yum 成功安装了 Noto 字体。")
- except subprocess.CalledProcessError:
- print("自动安装 Noto 字体失败。请手动安装。")
+ try:
+ subprocess.run(cmd, check=True)
+ console.print(f"✅ 使用 {pkg_manager} 成功安装 Noto 字体", style="green")
+ except subprocess.CalledProcessError:
+ console.print("❌ 安装 Noto 字体失败,请手动安装", style="red")
- install_noto_font()
+ if platform.system() == 'Linux':
+ install_noto_font()
install_requirements()
- download_and_extract_ffmpeg()
+ install_ffmpeg()
console.print(Panel.fit("安装完成", style="bold green"))
console.print("要启动应用程序,请运行:")
diff --git a/install.py b/install.py
index 6c4f1a37..6c7efacb 100644
--- a/install.py
+++ b/install.py
@@ -2,8 +2,6 @@
import platform
import subprocess
import sys
-import zipfile
-import shutil
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
ascii_logo = """
@@ -49,122 +47,60 @@ def main():
choose_mirror()
# Detect system and GPU
- if platform.system() == 'Darwin':
- console.print(Panel("🍎 MacOS detected, installing CPU version of PyTorch... However, it would be extremely slow for transcription.", style="cyan"))
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.1.2", "torchaudio==2.1.2"])
+ has_gpu = platform.system() != 'Darwin' and check_gpu()
+ if has_gpu:
+ console.print(Panel("🎮 NVIDIA GPU detected, installing CUDA version of PyTorch...", style="cyan"))
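+        # Install the CUDA 11.8 build of PyTorch 2.0.0 from the pytorch/nvidia conda channels (previously installed via pip)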
+ subprocess.check_call(["conda", "install", "-y", "pytorch==2.0.0", "torchaudio==2.0.0", "pytorch-cuda=11.8", "-c", "pytorch", "-c", "nvidia"])
else:
- has_gpu = check_gpu()
- if has_gpu:
- console.print(Panel("🎮 NVIDIA GPU detected, installing CUDA version of PyTorch...", style="cyan"))
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.0.0", "torchaudio==2.0.0", "--index-url", "https://download.pytorch.org/whl/cu118"])
- else:
- console.print(Panel("💻 No NVIDIA GPU detected, installing CPU version of PyTorch... However, it would be extremely slow for transcription.", style="cyan"))
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.1.2", "torchaudio==2.1.2"])
-
- # Install WhisperX
- console.print(Panel("📦 Installing WhisperX...", style="cyan"))
- current_dir = os.getcwd()
- whisperx_dir = os.path.join(current_dir, "third_party", "whisperX")
- os.chdir(whisperx_dir)
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
- os.chdir(current_dir)
+ system_name = "🍎 MacOS" if platform.system() == 'Darwin' else "💻 No NVIDIA GPU"
+ console.print(Panel(f"{system_name} detected, installing CPU version of PyTorch... However, it would be extremely slow for transcription.", style="cyan"))
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.1.2", "torchaudio==2.1.2"])
def install_requirements():
try:
- with open("requirements.txt", "r", encoding="utf-8") as file:
- content = file.read()
- with open("requirements.txt", "w", encoding="gbk") as file:
- file.write(content)
- except Exception as e:
- print(f"Error converting requirements.txt: {str(e)}")
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])
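+        # PYTHONIOENCODING=utf-8 avoids Unicode errors in pip's console output; PIP_NO_CACHE_DIR=0 leaves the pip cache enabled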
+ subprocess.check_call([
+ sys.executable,
+ "-m",
+ "pip",
+ "install",
+ "-r",
+ "requirements.txt"
+ ], env={**os.environ, "PIP_NO_CACHE_DIR": "0", "PYTHONIOENCODING": "utf-8"})
+ except subprocess.CalledProcessError as e:
+ console.print(Panel(f"❌ Failed to install requirements: {str(e)}", style="red"))
- def download_and_extract_ffmpeg():
- # VL requires both conda/system ffmpeg and ffmpeg.exe...
- system = platform.system()
- if system == "Linux":
- # Linux: use apt or yum to install ffmpeg
- try:
- console.print(Panel("📦 Installing ffmpeg through apt...", style="cyan"))
- subprocess.check_call(["sudo", "apt", "install", "-y", "ffmpeg"])
- except subprocess.CalledProcessError:
- try:
- console.print(Panel("📦 Installing ffmpeg through yum...", style="cyan"))
- subprocess.check_call(["sudo", "yum", "install", "-y", "ffmpeg"], shell=True)
- except subprocess.CalledProcessError:
- console.print(Panel("❌ Failed to install ffmpeg through package manager", style="red"))
- else:
- # Windows/MacOS: use conda to install ffmpeg
- console.print(Panel("📦 Installing ffmpeg through conda...", style="cyan"))
+ def install_ffmpeg():
+ console.print(Panel("📦 Installing ffmpeg through conda...", style="cyan"))
+ try:
subprocess.check_call(["conda", "install", "-y", "ffmpeg"], shell=True)
+ console.print(Panel("✅ FFmpeg installation completed", style="green"))
+ except subprocess.CalledProcessError:
+ console.print(Panel("❌ Failed to install FFmpeg through conda", style="red"))
- import requests
- if system == "Windows":
- ffmpeg_exe = "ffmpeg.exe"
- url = "https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-win64-gpl.zip"
- elif system == "Darwin":
- ffmpeg_exe = "ffmpeg"
- url = "https://evermeet.cx/ffmpeg/getrelease/zip"
- elif system == "Linux":
- ffmpeg_exe = "ffmpeg"
- url = "https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz"
+ def install_noto_font():
+ # Detect Linux distribution type
+ if os.path.exists('/etc/debian_version'):
+ # Debian/Ubuntu systems
+ cmd = ['sudo', 'apt-get', 'install', '-y', 'fonts-noto']
+ pkg_manager = "apt-get"
+ elif os.path.exists('/etc/redhat-release'):
+ # RHEL/CentOS/Fedora systems
+ cmd = ['sudo', 'yum', 'install', '-y', 'google-noto*']
+ pkg_manager = "yum"
else:
+ console.print("⚠️ Unrecognized Linux distribution, please install Noto fonts manually", style="yellow")
return
-
- if os.path.exists(ffmpeg_exe):
- print(f"{ffmpeg_exe} already exists")
- return
-
- console.print(Panel("📦 Downloading FFmpeg...", style="cyan"))
- response = requests.get(url)
- if response.status_code == 200:
- filename = "ffmpeg.zip" if system in ["Windows", "Darwin"] else "ffmpeg.tar.xz"
- with open(filename, 'wb') as f:
- f.write(response.content)
- console.print(Panel(f"FFmpeg downloaded: {filename}", style="cyan"))
-
- console.print(Panel("📦 Extracting FFmpeg...", style="cyan"))
- if system == "Linux":
- import tarfile
- with tarfile.open(filename) as tar_ref:
- for member in tar_ref.getmembers():
- if member.name.endswith("ffmpeg"):
- member.name = os.path.basename(member.name)
- tar_ref.extract(member)
- else:
- with zipfile.ZipFile(filename, 'r') as zip_ref:
- for file in zip_ref.namelist():
- if file.endswith(ffmpeg_exe):
- zip_ref.extract(file)
- shutil.move(os.path.join(*file.split('/')[:-1], os.path.basename(file)), os.path.basename(file))
- console.print(Panel("📦 Cleaning up...", style="cyan"))
- os.remove(filename)
- if system == "Windows":
- for item in os.listdir():
- if os.path.isdir(item) and "ffmpeg" in item.lower():
- shutil.rmtree(item)
- console.print(Panel("FFmpeg extraction completed", style="cyan"))
- else:
- console.print(Panel("❌ Failed to download FFmpeg", style="red"))
-
- def install_noto_font():
- if platform.system() == 'Linux':
- try:
- # Try apt-get first (Debian-based systems)
- subprocess.run(['sudo', 'apt-get', 'install', '-y', 'fonts-noto'], check=True)
- print("Noto fonts installed successfully using apt-get.")
- except subprocess.CalledProcessError:
- try:
- # If apt-get fails, try yum (RPM-based systems)
- subprocess.run(['sudo', 'yum', 'install', '-y', 'fonts-noto'], check=True)
- print("Noto fonts installed successfully using yum.")
- except subprocess.CalledProcessError:
- print("Failed to install Noto fonts automatically. Please install them manually.")
+ try:
+ subprocess.run(cmd, check=True)
+ console.print(f"✅ Successfully installed Noto fonts using {pkg_manager}", style="green")
+ except subprocess.CalledProcessError:
+ console.print("❌ Failed to install Noto fonts, please install manually", style="red")
- install_noto_font()
+ if platform.system() == 'Linux':
+ install_noto_font()
install_requirements()
- download_and_extract_ffmpeg()
+ install_ffmpeg()
console.print(Panel.fit("Installation completed", style="bold green"))
console.print("To start the application, run:")
diff --git a/requirements.txt b/requirements.txt
index 287d5daf..5a848cf7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,8 @@
azure-cognitiveservices-speech==1.40.0
librosa==0.10.2.post1
+pytorch-lightning==2.3.3
+lightning==2.3.3
+transformers==4.39.3
# moviepy only in fishtts
moviepy==1.0.3
numpy==1.26.4
@@ -12,14 +15,18 @@ PyYAML==6.0.2
replicate==0.33.0
requests==2.32.3
resampy==0.4.3
-spacy==3.7.6
+spacy==3.7.4
streamlit==1.38.0
yt-dlp
json-repair
ruamel.yaml
autocorrect-py
+
demucs[dev] @ git+https://github.com/adefossez/demucs
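+# whisperx is now installed from upstream git instead of the vendored third_party/whisperX copy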
+whisperx @ git+https://github.com/m-bain/whisperx.git
syllables
pypinyin
g2p-en
+
+
diff --git a/third_party/whisperX/.github/FUNDING.yml b/third_party/whisperX/.github/FUNDING.yml
deleted file mode 100644
index d517cfb8..00000000
--- a/third_party/whisperX/.github/FUNDING.yml
+++ /dev/null
@@ -1 +0,0 @@
-custom: https://www.buymeacoffee.com/maxhbain
diff --git a/third_party/whisperX/.gitignore b/third_party/whisperX/.gitignore
deleted file mode 100644
index 540c1326..00000000
--- a/third_party/whisperX/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-whisperx.egg-info/
-**/__pycache__/
-.ipynb_checkpoints
diff --git a/third_party/whisperX/EXAMPLES.md b/third_party/whisperX/EXAMPLES.md
deleted file mode 100644
index d9dc8e41..00000000
--- a/third_party/whisperX/EXAMPLES.md
+++ /dev/null
@@ -1,37 +0,0 @@
-# More Examples
-
-## Other Languages
-
-For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).
-
-Currently support default models tested for {en, fr, de, es, it, ja, zh, nl}
-
-
-If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
-
-### French
- whisperx --model large --language fr examples/sample_fr_01.wav
-
-
-https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov
-
-
-### German
- whisperx --model large --language de examples/sample_de_01.wav
-
-
-https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
-
-
-### Italian
- whisperx --model large --language de examples/sample_it_01.wav
-
-
-https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov
-
-
-### Japanese
- whisperx --model large --language ja examples/sample_ja_01.wav
-
-
-https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov
diff --git a/third_party/whisperX/LICENSE b/third_party/whisperX/LICENSE
deleted file mode 100644
index 21ec9f0a..00000000
--- a/third_party/whisperX/LICENSE
+++ /dev/null
@@ -1,24 +0,0 @@
-BSD 2-Clause License
-
-Copyright (c) 2024, Max Bain
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/third_party/whisperX/MANIFEST.in b/third_party/whisperX/MANIFEST.in
deleted file mode 100644
index 96ef3406..00000000
--- a/third_party/whisperX/MANIFEST.in
+++ /dev/null
@@ -1,4 +0,0 @@
-include whisperx/assets/*
-include whisperx/assets/gpt2/*
-include whisperx/assets/multilingual/*
-include whisperx/normalizers/english.json
diff --git a/third_party/whisperX/README.md b/third_party/whisperX/README.md
deleted file mode 100644
index 32e86665..00000000
--- a/third_party/whisperX/README.md
+++ /dev/null
@@ -1,302 +0,0 @@
-# WhisperX
-
-This repository provides fast automatic speech recognition (70x realtime with large-v2) with word-level timestamps and speaker diarization.
-
-- ⚡️ Batched inference for 70x realtime transcription using whisper large-v2
-- 🪶 [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend, requires <8GB gpu memory for large-v2 with beam_size=5
-- 🎯 Accurate word-level timestamps using wav2vec2 alignment
-- 👯♂️ Multispeaker ASR using speaker diarization from [pyannote-audio](https://github.com/pyannote/pyannote-audio) (speaker ID labels)
-- 🗣️ VAD preprocessing, reduces hallucination & batching with no WER degradation
-
-
-
-**Whisper** is an ASR model [developed by OpenAI](https://github.com/openai/whisper), trained on a large dataset of diverse audio. Whilst it does produces highly accurate transcriptions, the corresponding timestamps are at the utterance-level, not per word, and can be inaccurate by several seconds. OpenAI's whisper does not natively support batching.
-
-**Phoneme-Based ASR** A suite of models finetuned to recognise the smallest unit of speech distinguishing one word from another, e.g. the element p in "tap". A popular example model is [wav2vec2.0](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self).
-
-**Forced Alignment** refers to the process by which orthographic transcriptions are aligned to audio recordings to automatically generate phone level segmentation.
-
-**Voice Activity Detection (VAD)** is the detection of the presence or absence of human speech.
-
-**Speaker Diarization** is the process of partitioning an audio stream containing human speech into homogeneous segments according to the identity of each speaker.
-
-## New🚨
-
-- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) 🏆
-- _WhisperX_ accepted at INTERSPEECH 2023
-- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
-- v3 released, 70x speed-up open-sourced. Using batched whisper with [faster-whisper](https://github.com/guillaumekln/faster-whisper) backend!
-- v2 released, code cleanup, imports whisper library VAD filtering is now turned on by default, as in the paper.
-- Paper drop🎓👨🏫! Please see our [ArxiV preprint](https://arxiv.org/abs/2303.00747) for benchmarking and details of WhisperX. We also introduce more efficient batch inference resulting in large-v2 with *60-70x REAL TIME speed.
-
-## Setup ⚙️
-Tested for PyTorch 2.0, Python 3.10 (use other versions at your own risk!)
-
-GPU execution requires the NVIDIA libraries cuBLAS 11.x and cuDNN 8.x to be installed on the system. Please refer to the [CTranslate2 documentation](https://opennmt.net/CTranslate2/installation.html).
-
-
-### 1. Create Python3.10 environment
-
-`conda create --name whisperx python=3.10`
-
-`conda activate whisperx`
-
-
-### 2. Install PyTorch, e.g. for Linux and Windows CUDA11.8:
-
-`conda install pytorch==2.0.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia`
-
-See other methods [here.](https://pytorch.org/get-started/previous-versions/#v200)
-
-### 3. Install this repo
-
-`pip install git+https://github.com/m-bain/whisperx.git`
-
-If already installed, update package to most recent commit
-
-`pip install git+https://github.com/m-bain/whisperx.git --upgrade`
-
-If wishing to modify this package, clone and install in editable mode:
-```
-$ git clone https://github.com/m-bain/whisperX.git
-$ cd whisperX
-$ pip install -e .
-```
-
-You may also need to install ffmpeg, rust etc. Follow openAI instructions here https://github.com/openai/whisper#setup.
-
-### Speaker Diarization
-To **enable Speaker Diarization**, include your Hugging Face access token (read) that you can generate from [Here](https://huggingface.co/settings/tokens) after the `--hf_token` argument and accept the user agreement for the following models: [Segmentation](https://huggingface.co/pyannote/segmentation-3.0) and [Speaker-Diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1) (if you choose to use Speaker-Diarization 2.x, follow requirements [here](https://huggingface.co/pyannote/speaker-diarization) instead.)
-
-> **Note**
-> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.
-
-
-## Usage 💬 (command line)
-
-### English
-
-Run whisper on example segment (using default params, whisper small) add `--highlight_words True` to visualise word timings in the .srt file.
-
- whisperx examples/sample01.wav
-
-
-Result using *WhisperX* with forced alignment to wav2vec2.0 large:
-
-https://user-images.githubusercontent.com/36994049/208253969-7e35fe2a-7541-434a-ae91-8e919540555d.mp4
-
-Compare this to original whisper out the box, where many transcriptions are out of sync:
-
-https://user-images.githubusercontent.com/36994049/207743923-b4f0d537-29ae-4be2-b404-bb941db73652.mov
-
-
-For increased timestamp accuracy, at the cost of higher gpu mem, use bigger models (bigger alignment model not found to be that helpful, see paper) e.g.
-
- whisperx examples/sample01.wav --model large-v2 --align_model WAV2VEC2_ASR_LARGE_LV60K_960H --batch_size 4
-
-
-To label the transcript with speaker ID's (set number of speakers if known e.g. `--min_speakers 2` `--max_speakers 2`):
-
- whisperx examples/sample01.wav --model large-v2 --diarize --highlight_words True
-
-To run on CPU instead of GPU (and for running on Mac OS X):
-
- whisperx examples/sample01.wav --compute_type int8
-
-### Other languages
-
-The phoneme ASR alignment model is *language-specific*, for tested languages these models are [automatically picked from torchaudio pipelines or huggingface](https://github.com/m-bain/whisperX/blob/e909f2f766b23b2000f2d95df41f9b844ac53e49/whisperx/transcribe.py#L22).
-Just pass in the `--language` code, and use the whisper `--model large`.
-
-Currently default models provided for `{en, fr, de, es, it, ja, zh, nl, uk, pt}`. If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
-
-
-#### E.g. German
- whisperx --model large-v2 --language de examples/sample_de_01.wav
-
-https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
-
-
-See more examples in other languages [here](EXAMPLES.md).
-
-## Python usage 🐍
-
-```python
-import whisperx
-import gc
-
-device = "cuda"
-audio_file = "audio.mp3"
-batch_size = 16 # reduce if low on GPU mem
-compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)
-
-# 1. Transcribe with original whisper (batched)
-model = whisperx.load_model("large-v2", device, compute_type=compute_type)
-
-# save model to local path (optional)
-# model_dir = "/path/"
-# model = whisperx.load_model("large-v2", device, compute_type=compute_type, download_root=model_dir)
-
-audio = whisperx.load_audio(audio_file)
-result = model.transcribe(audio, batch_size=batch_size)
-print(result["segments"]) # before alignment
-
-# delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model
-
-# 2. Align whisper output
-model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
-
-print(result["segments"]) # after alignment
-
-# delete model if low on GPU resources
-# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
-
-# 3. Assign speaker labels
-diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
-
-# add min/max number of speakers if known
-diarize_segments = diarize_model(audio)
-# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
-
-result = whisperx.assign_word_speakers(diarize_segments, result)
-print(diarize_segments)
-print(result["segments"]) # segments are now assigned speaker IDs
-```
-
-## Demos 🚀
-
-[](https://replicate.com/victor-upmeet/whisperx)
-[](https://replicate.com/daanelson/whisperx)
-[](https://replicate.com/carnifexer/whisperx)
-
-If you don't have access to your own GPUs, use the links above to try out WhisperX.
-
-## Technical Details 👷♂️
-
-For specific details on the batching and alignment, the effect of VAD, as well as the chosen alignment model, see the preprint [paper](https://www.robots.ox.ac.uk/~vgg/publications/2023/Bain23/bain23.pdf).
-
-To reduce GPU memory requirements, try any of the following (2. & 3. can affect quality):
-1. reduce batch size, e.g. `--batch_size 4`
-2. use a smaller ASR model `--model base`
-3. Use lighter compute type `--compute_type int8`
-
-Transcription differences from openai's whisper:
-1. Transcription without timestamps. To enable single pass batching, whisper inference is performed `--without_timestamps True`, this ensures 1 forward pass per sample in the batch. However, this can cause discrepancies the default whisper output.
-2. VAD-based segment transcription, unlike the buffered transcription of openai's. In Wthe WhisperX paper we show this reduces WER, and enables accurate batched inference
-3. `--condition_on_prev_text` is set to `False` by default (reduces hallucination)
-
-## Limitations ⚠️
-
-- Transcript words which do not contain characters in the alignment models dictionary e.g. "2014." or "£13.60" cannot be aligned and therefore are not given a timing.
-- Overlapping speech is not handled particularly well by whisper nor whisperx
-- Diarization is far from perfect
-- Language specific wav2vec2 model is needed
-
-
-## Contribute 🧑🏫
-
-If you are multilingual, a major way you can contribute to this project is to find phoneme models on huggingface (or train your own) and test them on speech for the target language. If the results look good send a pull request and some examples showing its success.
-
-Bug finding and pull requests are also highly appreciated to keep this project going, since it's already diverging from the original research scope.
-
-## TODO 🗓
-
-* [x] Multilingual init
-
-* [x] Automatic align model selection based on language detection
-
-* [x] Python usage
-
-* [x] Incorporating speaker diarization
-
-* [x] Model flush, for low gpu mem resources
-
-* [x] Faster-whisper backend
-
-* [x] Add max-line etc. see (openai's whisper utils.py)
-
-* [x] Sentence-level segments (nltk toolbox)
-
-* [x] Improve alignment logic
-
-* [ ] update examples with diarization and word highlighting
-
-* [ ] Subtitle .ass output <- bring this back (removed in v3)
-
-* [ ] Add benchmarking code (TEDLIUM for spd/WER & word segmentation)
-
-* [ ] Allow silero-vad as alternative VAD option
-
-* [ ] Improve diarization (word level). *Harder than first thought...*
-
-
-
-
-
-Contact maxhbain@gmail.com for queries.
-
-
-
-
-## Acknowledgements 🙏
-
-This work, and my PhD, is supported by the [VGG (Visual Geometry Group)](https://www.robots.ox.ac.uk/~vgg/) and the University of Oxford.
-
-Of course, this is builds on [openAI's whisper](https://github.com/openai/whisper).
-Borrows important alignment code from [PyTorch tutorial on forced alignment](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html)
-And uses the wonderful pyannote VAD / Diarization https://github.com/pyannote/pyannote-audio
-
-
-Valuable VAD & Diarization Models from [pyannote audio][https://github.com/pyannote/pyannote-audio]
-
-Great backend from [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2)
-
-Those who have [supported this work financially](https://www.buymeacoffee.com/maxhbain) 🙏
-
-Finally, thanks to the OS [contributors](https://github.com/m-bain/whisperX/graphs/contributors) of this project, keeping it going and identifying bugs.
-
-## Citation
-If you use this in your research, please cite the paper:
-
-```bibtex
-@article{bain2022whisperx,
- title={WhisperX: Time-Accurate Speech Transcription of Long-Form Audio},
- author={Bain, Max and Huh, Jaesung and Han, Tengda and Zisserman, Andrew},
- journal={INTERSPEECH 2023},
- year={2023}
-}
-```
diff --git a/third_party/whisperX/figures/pipeline.png b/third_party/whisperX/figures/pipeline.png
deleted file mode 100644
index 232ea788..00000000
Binary files a/third_party/whisperX/figures/pipeline.png and /dev/null differ
diff --git a/third_party/whisperX/requirements.txt b/third_party/whisperX/requirements.txt
deleted file mode 100644
index 865abd1f..00000000
--- a/third_party/whisperX/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-torch>=2
-torchaudio>=2
-faster-whisper==1.0.0
-transformers
-pandas
-setuptools>=65
-nltk
diff --git a/third_party/whisperX/setup.py b/third_party/whisperX/setup.py
deleted file mode 100644
index 40db6cc9..00000000
--- a/third_party/whisperX/setup.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import os
-import platform
-
-import pkg_resources
-from setuptools import find_packages, setup
-
-setup(
- name="whisperx",
- py_modules=["whisperx"],
- version="3.1.1",
- description="Time-Accurate Automatic Speech Recognition using Whisper.",
- readme="README.md",
- python_requires=">=3.8",
- author="Max Bain",
- url="https://github.com/m-bain/whisperx",
- license="MIT",
- packages=find_packages(exclude=["tests*"]),
- install_requires=[
- str(r)
- for r in pkg_resources.parse_requirements(
- open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
- )
- ]
- + [f"pyannote.audio==3.1.1"],
- entry_points={
- "console_scripts": ["whisperx=whisperx.transcribe:cli"],
- },
- include_package_data=True,
- extras_require={"dev": ["pytest"]},
-)
diff --git a/third_party/whisperX/whisperx/SubtitlesProcessor.py b/third_party/whisperX/whisperx/SubtitlesProcessor.py
deleted file mode 100644
index 420699e2..00000000
--- a/third_party/whisperX/whisperx/SubtitlesProcessor.py
+++ /dev/null
@@ -1,227 +0,0 @@
-import math
-from conjunctions import get_conjunctions, get_comma
-from typing import TextIO
-
-def normal_round(n):
- if n - math.floor(n) < 0.5:
- return math.floor(n)
- return math.ceil(n)
-
-
-def format_timestamp(seconds: float, is_vtt: bool = False):
-
- assert seconds >= 0, "non-negative timestamp expected"
- milliseconds = round(seconds * 1000.0)
-
- hours = milliseconds // 3_600_000
- milliseconds -= hours * 3_600_000
-
- minutes = milliseconds // 60_000
- milliseconds -= minutes * 60_000
-
- seconds = milliseconds // 1_000
- milliseconds -= seconds * 1_000
-
- separator = '.' if is_vtt else ','
-
- hours_marker = f"{hours:02d}:"
- return (
- f"{hours_marker}{minutes:02d}:{seconds:02d}{separator}{milliseconds:03d}"
- )
-
-
-
-class SubtitlesProcessor:
- def __init__(self, segments, lang, max_line_length = 45, min_char_length_splitter = 30, is_vtt = False):
- self.comma = get_comma(lang)
- self.conjunctions = set(get_conjunctions(lang))
- self.segments = segments
- self.lang = lang
- self.max_line_length = max_line_length
- self.min_char_length_splitter = min_char_length_splitter
- self.is_vtt = is_vtt
- complex_script_languages = ['th', 'lo', 'my', 'km', 'am', 'ko', 'ja', 'zh', 'ti', 'ta', 'te', 'kn', 'ml', 'hi', 'ne', 'mr', 'ar', 'fa', 'ur', 'ka']
- if self.lang in complex_script_languages:
- self.max_line_length = 30
- self.min_char_length_splitter = 20
-
- def estimate_timestamp_for_word(self, words, i, next_segment_start_time=None):
- k = 0.25
- has_prev_end = i > 0 and 'end' in words[i - 1]
- has_next_start = i < len(words) - 1 and 'start' in words[i + 1]
-
- if has_prev_end:
- words[i]['start'] = words[i - 1]['end']
- if has_next_start:
- words[i]['end'] = words[i + 1]['start']
- else:
- if next_segment_start_time:
- words[i]['end'] = next_segment_start_time if next_segment_start_time - words[i - 1]['end'] <= 1 else next_segment_start_time - 0.5
- else:
- words[i]['end'] = words[i]['start'] + len(words[i]['word']) * k
-
- elif has_next_start:
- words[i]['start'] = words[i + 1]['start'] - len(words[i]['word']) * k
- words[i]['end'] = words[i + 1]['start']
-
- else:
- if next_segment_start_time:
- words[i]['start'] = next_segment_start_time - 1
- words[i]['end'] = next_segment_start_time - 0.5
- else:
- words[i]['start'] = 0
- words[i]['end'] = 0
-
-
-
- def process_segments(self, advanced_splitting=True):
- subtitles = []
- for i, segment in enumerate(self.segments):
- next_segment_start_time = self.segments[i + 1]['start'] if i + 1 < len(self.segments) else None
-
- if advanced_splitting:
-
- split_points = self.determine_advanced_split_points(segment, next_segment_start_time)
- subtitles.extend(self.generate_subtitles_from_split_points(segment, split_points, next_segment_start_time))
- else:
- words = segment['words']
- for i, word in enumerate(words):
- if 'start' not in word or 'end' not in word:
- self.estimate_timestamp_for_word(words, i, next_segment_start_time)
-
- subtitles.append({
- 'start': segment['start'],
- 'end': segment['end'],
- 'text': segment['text']
- })
-
- return subtitles
-
- def determine_advanced_split_points(self, segment, next_segment_start_time=None):
- split_points = []
- last_split_point = 0
- char_count = 0
-
- words = segment.get('words', segment['text'].split())
- add_space = 0 if self.lang in ['zh', 'ja'] else 1
-
- total_char_count = sum(len(word['word']) if isinstance(word, dict) else len(word) + add_space for word in words)
- char_count_after = total_char_count
-
- for i, word in enumerate(words):
- word_text = word['word'] if isinstance(word, dict) else word
- word_length = len(word_text) + add_space
- char_count += word_length
- char_count_after -= word_length
-
- char_count_before = char_count - word_length
-
- if isinstance(word, dict) and ('start' not in word or 'end' not in word):
- self.estimate_timestamp_for_word(words, i, next_segment_start_time)
-
- if char_count >= self.max_line_length:
- midpoint = normal_round((last_split_point + i) / 2)
- if char_count_before >= self.min_char_length_splitter:
- split_points.append(midpoint)
- last_split_point = midpoint + 1
- char_count = sum(len(words[j]['word']) if isinstance(words[j], dict) else len(words[j]) + add_space for j in range(last_split_point, i + 1))
-
- elif word_text.endswith(self.comma) and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
- split_points.append(i)
- last_split_point = i + 1
- char_count = 0
-
- elif word_text.lower() in self.conjunctions and char_count_before >= self.min_char_length_splitter and char_count_after >= self.min_char_length_splitter:
- split_points.append(i - 1)
- last_split_point = i
- char_count = word_length
-
- return split_points
-
-
- def generate_subtitles_from_split_points(self, segment, split_points, next_start_time=None):
- subtitles = []
-
- words = segment.get('words', segment['text'].split())
- total_word_count = len(words)
- total_time = segment['end'] - segment['start']
- elapsed_time = segment['start']
- prefix = ' ' if self.lang not in ['zh', 'ja'] else ''
- start_idx = 0
- for split_point in split_points:
-
- fragment_words = words[start_idx:split_point + 1]
- current_word_count = len(fragment_words)
-
-
- if isinstance(fragment_words[0], dict):
- start_time = fragment_words[0]['start']
- end_time = fragment_words[-1]['end']
- next_start_time_for_word = words[split_point + 1]['start'] if split_point + 1 < len(words) else None
- if next_start_time_for_word and (next_start_time_for_word - end_time) <= 0.8:
- end_time = next_start_time_for_word
- else:
- fragment = prefix.join(fragment_words).strip()
- current_duration = (current_word_count / total_word_count) * total_time
- start_time = elapsed_time
- end_time = elapsed_time + current_duration
- elapsed_time += current_duration
-
-
- subtitles.append({
- 'start': start_time,
- 'end': end_time,
- 'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
- })
-
- start_idx = split_point + 1
-
- # Handle the last fragment
- if start_idx < len(words):
- fragment_words = words[start_idx:]
- current_word_count = len(fragment_words)
-
- if isinstance(fragment_words[0], dict):
- start_time = fragment_words[0]['start']
- end_time = fragment_words[-1]['end']
- else:
- fragment = prefix.join(fragment_words).strip()
- current_duration = (current_word_count / total_word_count) * total_time
- start_time = elapsed_time
- end_time = elapsed_time + current_duration
-
- if next_start_time and (next_start_time - end_time) <= 0.8:
- end_time = next_start_time
-
- subtitles.append({
- 'start': start_time,
- 'end': end_time if end_time is not None else segment['end'],
- 'text': fragment if not isinstance(fragment_words[0], dict) else prefix.join(word['word'] for word in fragment_words)
- })
-
- return subtitles
-
-
-
- def save(self, filename="subtitles.srt", advanced_splitting=True):
-
- subtitles = self.process_segments(advanced_splitting)
-
- def write_subtitle(file, idx, start_time, end_time, text):
-
- file.write(f"{idx}\n")
- file.write(f"{start_time} --> {end_time}\n")
- file.write(text + "\n\n")
-
- with open(filename, 'w', encoding='utf-8') as file:
- if self.is_vtt:
- file.write("WEBVTT\n\n")
-
- if advanced_splitting:
- for idx, subtitle in enumerate(subtitles, 1):
- start_time = format_timestamp(subtitle['start'], self.is_vtt)
- end_time = format_timestamp(subtitle['end'], self.is_vtt)
- text = subtitle['text'].strip()
- write_subtitle(file, idx, start_time, end_time, text)
-
- return len(subtitles)
\ No newline at end of file
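For context on the subtitle writer deleted above: `save()` numbers each cue and delegates to `write_subtitle`, which emits the standard `index / start --> end / text` block. Below is a minimal sketch of that mapping; the `format_timestamp` here is only a stand-in for the module's own helper (defined elsewhere in the deleted file), not the removed implementation.

```python
# Minimal sketch of the SRT cue format produced by save()/write_subtitle above.
# This format_timestamp is a stand-in, not the deleted module's implementation.
def format_timestamp(seconds: float, is_vtt: bool = False) -> str:
    ms = round(seconds * 1000)
    h, ms = divmod(ms, 3_600_000)
    m, ms = divmod(ms, 60_000)
    s, ms = divmod(ms, 1_000)
    sep = "." if is_vtt else ","      # VTT uses '.', SRT uses ','
    return f"{h:02d}:{m:02d}:{s:02d}{sep}{ms:03d}"

subtitle = {"start": 1.25, "end": 3.8, "text": "Hello world"}
print(f"1\n{format_timestamp(subtitle['start'])} --> {format_timestamp(subtitle['end'])}\n{subtitle['text']}\n")
# 1
# 00:00:01,250 --> 00:00:03,800
# Hello world
```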
diff --git a/third_party/whisperX/whisperx/__init__.py b/third_party/whisperX/whisperx/__init__.py
deleted file mode 100644
index 20abaaed..00000000
--- a/third_party/whisperX/whisperx/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .transcribe import load_model
-from .alignment import load_align_model, align
-from .audio import load_audio
-from .diarize import assign_word_speakers, DiarizationPipeline
\ No newline at end of file
diff --git a/third_party/whisperX/whisperx/__main__.py b/third_party/whisperX/whisperx/__main__.py
deleted file mode 100644
index bc9b04a3..00000000
--- a/third_party/whisperX/whisperx/__main__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .transcribe import cli
-
-
-cli()
diff --git a/third_party/whisperX/whisperx/alignment.py b/third_party/whisperX/whisperx/alignment.py
deleted file mode 100644
index 964217e2..00000000
--- a/third_party/whisperX/whisperx/alignment.py
+++ /dev/null
@@ -1,470 +0,0 @@
-""""
-Forced Alignment with Whisper
-C. Max Bain
-"""
-from dataclasses import dataclass
-from typing import Iterable, Union, List
-
-import numpy as np
-import pandas as pd
-import torch
-import torchaudio
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-
-from .audio import SAMPLE_RATE, load_audio
-from .utils import interpolate_nans
-from .types import AlignedTranscriptionResult, SingleSegment, SingleAlignedSegment, SingleWordSegment
-import nltk
-from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters
-
-PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
-
-LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
-
-DEFAULT_ALIGN_MODELS_TORCH = {
- "en": "WAV2VEC2_ASR_BASE_960H",
- "fr": "VOXPOPULI_ASR_BASE_10K_FR",
- "de": "VOXPOPULI_ASR_BASE_10K_DE",
- "es": "VOXPOPULI_ASR_BASE_10K_ES",
- "it": "VOXPOPULI_ASR_BASE_10K_IT",
-}
-
-DEFAULT_ALIGN_MODELS_HF = {
- "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
- "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
- "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
- "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
- "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
- "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
- "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
- "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
- "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
- "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
- "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
- "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
- "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
- "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
- "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
- "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
- "vi": 'nguyenvulebinh/wav2vec2-base-vi',
- "ko": "kresnik/wav2vec2-large-xlsr-korean",
- "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
- "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
- "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
- "ca": "softcatala/wav2vec2-large-xlsr-catala",
- "ml": "gvs/wav2vec2-large-xlsr-malayalam",
- "no": "NbAiLab/nb-wav2vec2-1b-bokmaal",
- "nn": "NbAiLab/nb-wav2vec2-300m-nynorsk",
- "sk": "comodoro/wav2vec2-xls-r-300m-sk-cv8",
- "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
- "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
-}
-
-
-def load_align_model(language_code, device, model_name=None, model_dir=None):
- if model_name is None:
- # use default model
- if language_code in DEFAULT_ALIGN_MODELS_TORCH:
- model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
- elif language_code in DEFAULT_ALIGN_MODELS_HF:
- model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
- else:
- print(f"There is no default alignment model set for this language ({language_code}).\
- Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
- raise ValueError(f"No default align-model for language: {language_code}")
-
- if model_name in torchaudio.pipelines.__all__:
- pipeline_type = "torchaudio"
- bundle = torchaudio.pipelines.__dict__[model_name]
- align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device)
- labels = bundle.get_labels()
- align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
- else:
- try:
- processor = Wav2Vec2Processor.from_pretrained(model_name)
- align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
- except Exception as e:
- print(e)
- print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
- raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
- pipeline_type = "huggingface"
- align_model = align_model.to(device)
- labels = processor.tokenizer.get_vocab()
- align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
-
- align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
-
- return align_model, align_metadata
-
-
-def align(
- transcript: Iterable[SingleSegment],
- model: torch.nn.Module,
- align_model_metadata: dict,
- audio: Union[str, np.ndarray, torch.Tensor],
- device: str,
- interpolate_method: str = "nearest",
- return_char_alignments: bool = False,
- print_progress: bool = False,
- combined_progress: bool = False,
-) -> AlignedTranscriptionResult:
- """
- Align phoneme recognition predictions to known transcription.
- """
-
- if not torch.is_tensor(audio):
- if isinstance(audio, str):
- audio = load_audio(audio)
- audio = torch.from_numpy(audio)
- if len(audio.shape) == 1:
- audio = audio.unsqueeze(0)
-
- MAX_DURATION = audio.shape[1] / SAMPLE_RATE
-
- model_dictionary = align_model_metadata["dictionary"]
- model_lang = align_model_metadata["language"]
- model_type = align_model_metadata["type"]
-
- # 1. Preprocess to keep only characters in dictionary
- total_segments = len(transcript)
- for sdx, segment in enumerate(transcript):
- # strip spaces at beginning / end, but keep track of the amount.
- if print_progress:
- base_progress = ((sdx + 1) / total_segments) * 100
- percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
- print(f"Progress: {percent_complete:.2f}%...")
-
- num_leading = len(segment["text"]) - len(segment["text"].lstrip())
- num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
- text = segment["text"]
-
- # split into words
- if model_lang not in LANGUAGES_WITHOUT_SPACES:
- per_word = text.split(" ")
- else:
- per_word = text
-
- clean_char, clean_cdx = [], []
- for cdx, char in enumerate(text):
- char_ = char.lower()
- # wav2vec2 models use "|" character to represent spaces
- if model_lang not in LANGUAGES_WITHOUT_SPACES:
- char_ = char_.replace(" ", "|")
-
- # ignore whitespace at beginning and end of transcript
- if cdx < num_leading:
- pass
- elif cdx > len(text) - num_trailing - 1:
- pass
- elif char_ in model_dictionary.keys():
- clean_char.append(char_)
- clean_cdx.append(cdx)
-
- clean_wdx = []
- for wdx, wrd in enumerate(per_word):
- if any([c in model_dictionary.keys() for c in wrd]):
- clean_wdx.append(wdx)
-
-
- punkt_param = PunktParameters()
- punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
- sentence_splitter = PunktSentenceTokenizer(punkt_param)
- sentence_spans = list(sentence_splitter.span_tokenize(text))
-
- segment["clean_char"] = clean_char
- segment["clean_cdx"] = clean_cdx
- segment["clean_wdx"] = clean_wdx
- segment["sentence_spans"] = sentence_spans
-
- aligned_segments: List[SingleAlignedSegment] = []
-
- # 2. Get prediction matrix from alignment model & align
- for sdx, segment in enumerate(transcript):
-
- t1 = segment["start"]
- t2 = segment["end"]
- text = segment["text"]
-
- aligned_seg: SingleAlignedSegment = {
- "start": t1,
- "end": t2,
- "text": text,
- "words": [],
- }
-
- if return_char_alignments:
- aligned_seg["chars"] = []
-
- # check we can align
- if len(segment["clean_char"]) == 0:
- print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
- aligned_segments.append(aligned_seg)
- continue
-
- if t1 >= MAX_DURATION:
- print(f'Failed to align segment ("{segment["text"]}"): original start time longer than audio duration, skipping...')
- aligned_segments.append(aligned_seg)
- continue
-
- text_clean = "".join(segment["clean_char"])
- tokens = [model_dictionary[c] for c in text_clean]
-
- f1 = int(t1 * SAMPLE_RATE)
- f2 = int(t2 * SAMPLE_RATE)
-
- # TODO: Probably can get some speedup gain with batched inference here
- waveform_segment = audio[:, f1:f2]
- # Handle the minimum input length for wav2vec2 models
- if waveform_segment.shape[-1] < 400:
- lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device)
- waveform_segment = torch.nn.functional.pad(
- waveform_segment, (0, 400 - waveform_segment.shape[-1])
- )
- else:
- lengths = None
-
- with torch.inference_mode():
- if model_type == "torchaudio":
- emissions, _ = model(waveform_segment.to(device), lengths=lengths)
- elif model_type == "huggingface":
- emissions = model(waveform_segment.to(device)).logits
- else:
- raise NotImplementedError(f"Align model of type {model_type} not supported.")
- emissions = torch.log_softmax(emissions, dim=-1)
-
- emission = emissions[0].cpu().detach()
-
- blank_id = 0
- for char, code in model_dictionary.items():
-            if char == '[pad]' or char == '<pad>':
- blank_id = code
-
- trellis = get_trellis(emission, tokens, blank_id)
- path = backtrack(trellis, emission, tokens, blank_id)
-
- if path is None:
- print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
- aligned_segments.append(aligned_seg)
- continue
-
- char_segments = merge_repeats(path, text_clean)
-
-        duration = t2 - t1
- ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
-
- # assign timestamps to aligned characters
- char_segments_arr = []
- word_idx = 0
- for cdx, char in enumerate(text):
- start, end, score = None, None, None
- if cdx in segment["clean_cdx"]:
- char_seg = char_segments[segment["clean_cdx"].index(cdx)]
- start = round(char_seg.start * ratio + t1, 3)
- end = round(char_seg.end * ratio + t1, 3)
- score = round(char_seg.score, 3)
-
- char_segments_arr.append(
- {
- "char": char,
- "start": start,
- "end": end,
- "score": score,
- "word-idx": word_idx,
- }
- )
-
-            # increment word_idx; nltk word tokenization would probably be more robust here, but use spaces for now...
- if model_lang in LANGUAGES_WITHOUT_SPACES:
- word_idx += 1
- elif cdx == len(text) - 1 or text[cdx+1] == " ":
- word_idx += 1
-
- char_segments_arr = pd.DataFrame(char_segments_arr)
-
- aligned_subsegments = []
- # assign sentence_idx to each character index
- char_segments_arr["sentence-idx"] = None
- for sdx, (sstart, send) in enumerate(segment["sentence_spans"]):
- curr_chars = char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send)]
- char_segments_arr.loc[(char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), "sentence-idx"] = sdx
-
- sentence_text = text[sstart:send]
- sentence_start = curr_chars["start"].min()
- end_chars = curr_chars[curr_chars["char"] != ' ']
- sentence_end = end_chars["end"].max()
- sentence_words = []
-
- for word_idx in curr_chars["word-idx"].unique():
- word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx]
- word_text = "".join(word_chars["char"].tolist()).strip()
- if len(word_text) == 0:
- continue
-
- # dont use space character for alignment
- word_chars = word_chars[word_chars["char"] != " "]
-
- word_start = word_chars["start"].min()
- word_end = word_chars["end"].max()
- word_score = round(word_chars["score"].mean(), 3)
-
- # -1 indicates unalignable
- word_segment = {"word": word_text}
-
- if not np.isnan(word_start):
- word_segment["start"] = word_start
- if not np.isnan(word_end):
- word_segment["end"] = word_end
- if not np.isnan(word_score):
- word_segment["score"] = word_score
-
- sentence_words.append(word_segment)
-
- aligned_subsegments.append({
- "text": sentence_text,
- "start": sentence_start,
- "end": sentence_end,
- "words": sentence_words,
- })
-
- if return_char_alignments:
- curr_chars = curr_chars[["char", "start", "end", "score"]]
- curr_chars.fillna(-1, inplace=True)
- curr_chars = curr_chars.to_dict("records")
- curr_chars = [{key: val for key, val in char.items() if val != -1} for char in curr_chars]
- aligned_subsegments[-1]["chars"] = curr_chars
-
- aligned_subsegments = pd.DataFrame(aligned_subsegments)
- aligned_subsegments["start"] = interpolate_nans(aligned_subsegments["start"], method=interpolate_method)
- aligned_subsegments["end"] = interpolate_nans(aligned_subsegments["end"], method=interpolate_method)
- # concatenate sentences with same timestamps
- agg_dict = {"text": " ".join, "words": "sum"}
- if model_lang in LANGUAGES_WITHOUT_SPACES:
- agg_dict["text"] = "".join
- if return_char_alignments:
- agg_dict["chars"] = "sum"
- aligned_subsegments= aligned_subsegments.groupby(["start", "end"], as_index=False).agg(agg_dict)
- aligned_subsegments = aligned_subsegments.to_dict('records')
- aligned_segments += aligned_subsegments
-
- # create word_segments list
- word_segments: List[SingleWordSegment] = []
- for segment in aligned_segments:
- word_segments += segment["words"]
-
- return {"segments": aligned_segments, "word_segments": word_segments}
-
-"""
-source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
-"""
-def get_trellis(emission, tokens, blank_id=0):
- num_frame = emission.size(0)
- num_tokens = len(tokens)
-
-    # Trellis has extra dimensions for both time axis and tokens.
-    # The extra dim for tokens represents <SoS> (start-of-sentence)
- # The extra dim for time axis is for simplification of the code.
- trellis = torch.empty((num_frame + 1, num_tokens + 1))
- trellis[0, 0] = 0
- trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
- trellis[0, -num_tokens:] = -float("inf")
- trellis[-num_tokens:, 0] = float("inf")
-
- for t in range(num_frame):
- trellis[t + 1, 1:] = torch.maximum(
- # Score for staying at the same token
- trellis[t, 1:] + emission[t, blank_id],
- # Score for changing to the next token
- trellis[t, :-1] + emission[t, tokens],
- )
- return trellis
-
-@dataclass
-class Point:
- token_index: int
- time_index: int
- score: float
-
-def backtrack(trellis, emission, tokens, blank_id=0):
- # Note:
- # j and t are indices for trellis, which has extra dimensions
- # for time and tokens at the beginning.
- # When referring to time frame index `T` in trellis,
- # the corresponding index in emission is `T-1`.
- # Similarly, when referring to token index `J` in trellis,
- # the corresponding index in transcript is `J-1`.
- j = trellis.size(1) - 1
- t_start = torch.argmax(trellis[:, j]).item()
-
- path = []
- for t in range(t_start, 0, -1):
- # 1. Figure out if the current position was stay or change
- # Note (again):
- # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
- # Score for token staying the same from time frame J-1 to T.
- stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
- # Score for token changing from C-1 at T-1 to J at T.
- changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
-
- # 2. Store the path with frame-wise probability.
- prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
- # Return token index and time index in non-trellis coordinate.
- path.append(Point(j - 1, t - 1, prob))
-
- # 3. Update the token
- if changed > stayed:
- j -= 1
- if j == 0:
- break
- else:
- # failed
- return None
- return path[::-1]
-
-# Merge the labels
-@dataclass
-class Segment:
- label: str
- start: int
- end: int
- score: float
-
- def __repr__(self):
- return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
-
- @property
- def length(self):
- return self.end - self.start
-
-def merge_repeats(path, transcript):
- i1, i2 = 0, 0
- segments = []
- while i1 < len(path):
- while i2 < len(path) and path[i1].token_index == path[i2].token_index:
- i2 += 1
- score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
- segments.append(
- Segment(
- transcript[path[i1].token_index],
- path[i1].time_index,
- path[i2 - 1].time_index + 1,
- score,
- )
- )
- i1 = i2
- return segments
-
-def merge_words(segments, separator="|"):
- words = []
- i1, i2 = 0, 0
- while i1 < len(segments):
- if i2 >= len(segments) or segments[i2].label == separator:
- if i1 != i2:
- segs = segments[i1:i2]
- word = "".join([seg.label for seg in segs])
- score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
- words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
- i1 = i2 + 1
- i2 = i1
- else:
- i2 += 1
- return words
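The `alignment.py` module removed above implements CTC forced alignment: `load_align_model` picks a wav2vec2 checkpoint per language, `align` builds a trellis over emission frames and tokens, and `backtrack` recovers the monotonic path that yields word-level timestamps. A usage sketch of that API as it existed before removal (the file path and language code are placeholders):

```python
# Sketch of the whisperx alignment API removed above (exports per the deleted
# __init__.py); "example.wav" and the language code are placeholders.
import whisperx

device = "cuda"
audio = whisperx.load_audio("example.wav")            # 16 kHz mono float32
model_a, metadata = whisperx.load_align_model(language_code="en", device=device)

segments = [{"start": 0.0, "end": 2.5, "text": "hello world"}]
aligned = whisperx.align(segments, model_a, metadata, audio, device,
                         return_char_alignments=False)
print(aligned["word_segments"])  # [{'word': 'hello', 'start': ..., 'end': ..., 'score': ...}, ...]
```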
diff --git a/third_party/whisperX/whisperx/asr.py b/third_party/whisperX/whisperx/asr.py
deleted file mode 100644
index 0ccaf92b..00000000
--- a/third_party/whisperX/whisperx/asr.py
+++ /dev/null
@@ -1,357 +0,0 @@
-import os
-import warnings
-from typing import List, Union, Optional, NamedTuple
-
-import ctranslate2
-import faster_whisper
-import numpy as np
-import torch
-from transformers import Pipeline
-from transformers.pipelines.pt_utils import PipelineIterator
-
-from .audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
-from .vad import load_vad_model, merge_chunks
-from .types import TranscriptionResult, SingleSegment
-
-def find_numeral_symbol_tokens(tokenizer):
- numeral_symbol_tokens = []
- for i in range(tokenizer.eot):
- token = tokenizer.decode([i]).removeprefix(" ")
- has_numeral_symbol = any(c in "0123456789%$£" for c in token)
- if has_numeral_symbol:
- numeral_symbol_tokens.append(i)
- return numeral_symbol_tokens
-
-class WhisperModel(faster_whisper.WhisperModel):
- '''
- FasterWhisperModel provides batched inference for faster-whisper.
-    Currently only works in non-timestamp mode, with a fixed prompt for all samples in the batch.
- '''
-
- def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
- batch_size = features.shape[0]
- all_tokens = []
- prompt_reset_since = 0
- if options.initial_prompt is not None:
- initial_prompt = " " + options.initial_prompt.strip()
- initial_prompt_tokens = tokenizer.encode(initial_prompt)
- all_tokens.extend(initial_prompt_tokens)
- previous_tokens = all_tokens[prompt_reset_since:]
- prompt = self.get_prompt(
- tokenizer,
- previous_tokens,
- without_timestamps=options.without_timestamps,
- prefix=options.prefix,
- )
-
- encoder_output = self.encode(features)
-
- max_initial_timestamp_index = int(
- round(options.max_initial_timestamp / self.time_precision)
- )
-
- result = self.model.generate(
- encoder_output,
- [prompt] * batch_size,
- beam_size=options.beam_size,
- patience=options.patience,
- length_penalty=options.length_penalty,
- max_length=self.max_length,
- suppress_blank=options.suppress_blank,
- suppress_tokens=options.suppress_tokens,
- )
-
- tokens_batch = [x.sequences_ids[0] for x in result]
-
- def decode_batch(tokens: List[List[int]]) -> str:
- res = []
- for tk in tokens:
- res.append([token for token in tk if token < tokenizer.eot])
- # text_tokens = [token for token in tokens if token < self.eot]
- return tokenizer.tokenizer.decode_batch(res)
-
- text = decode_batch(tokens_batch)
-
- return text
-
- def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
- # When the model is running on multiple GPUs, the encoder output should be moved
- # to the CPU since we don't know which GPU will handle the next job.
- to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
- # unsqueeze if batch size = 1
- if len(features.shape) == 2:
- features = np.expand_dims(features, 0)
- features = faster_whisper.transcribe.get_ctranslate2_storage(features)
-
- return self.model.encode(features, to_cpu=to_cpu)
-
-class FasterWhisperPipeline(Pipeline):
- """
- Huggingface Pipeline wrapper for FasterWhisperModel.
- """
- # TODO:
- # - add support for timestamp mode
- # - add support for custom inference kwargs
-
- def __init__(
- self,
- model,
- vad,
- vad_params: dict,
- options : NamedTuple,
- tokenizer=None,
- device: Union[int, str, "torch.device"] = -1,
- framework = "pt",
- language : Optional[str] = None,
- suppress_numerals: bool = False,
- **kwargs
- ):
- self.model = model
- self.tokenizer = tokenizer
- self.options = options
- self.preset_language = language
- self.suppress_numerals = suppress_numerals
- self._batch_size = kwargs.pop("batch_size", None)
- self._num_workers = 1
- self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
- self.call_count = 0
- self.framework = framework
- if self.framework == "pt":
- if isinstance(device, torch.device):
- self.device = device
- elif isinstance(device, str):
- self.device = torch.device(device)
- elif device < 0:
- self.device = torch.device("cpu")
- else:
- self.device = torch.device(f"cuda:{device}")
- else:
- self.device = device
-
- super(Pipeline, self).__init__()
- self.vad_model = vad
- self._vad_params = vad_params
-
- def _sanitize_parameters(self, **kwargs):
- preprocess_kwargs = {}
- if "tokenizer" in kwargs:
- preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
- return preprocess_kwargs, {}, {}
-
- def preprocess(self, audio):
- audio = audio['inputs']
- model_n_mels = self.model.feat_kwargs.get("feature_size")
- features = log_mel_spectrogram(
- audio,
- n_mels=model_n_mels if model_n_mels is not None else 80,
- padding=N_SAMPLES - audio.shape[0],
- )
- return {'inputs': features}
-
- def _forward(self, model_inputs):
- outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
- return {'text': outputs}
-
- def postprocess(self, model_outputs):
- return model_outputs
-
- def get_iterator(
- self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
- ):
- dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
- if "TOKENIZERS_PARALLELISM" not in os.environ:
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- # TODO hack by collating feature_extractor and image_processor
-
- def stack(items):
- return {'inputs': torch.stack([x['inputs'] for x in items])}
- dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
- model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
- final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
- return final_iterator
-
- def transcribe(
- self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
- ) -> TranscriptionResult:
- if isinstance(audio, str):
- audio = load_audio(audio)
-
- def data(audio, segments):
- for seg in segments:
- f1 = int(seg['start'] * SAMPLE_RATE)
- f2 = int(seg['end'] * SAMPLE_RATE)
- # print(f2-f1)
- yield {'inputs': audio[f1:f2]}
-
- vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
- vad_segments = merge_chunks(
- vad_segments,
- chunk_size,
- onset=self._vad_params["vad_onset"],
- offset=self._vad_params["vad_offset"],
- )
- if self.tokenizer is None:
- language = language or self.detect_language(audio)
- task = task or "transcribe"
- self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
- self.model.model.is_multilingual, task=task,
- language=language)
- else:
- language = language or self.tokenizer.language_code
- task = task or self.tokenizer.task
- if task != self.tokenizer.task or language != self.tokenizer.language_code:
- self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
- self.model.model.is_multilingual, task=task,
- language=language)
-
- if self.suppress_numerals:
- previous_suppress_tokens = self.options.suppress_tokens
- numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
- print(f"Suppressing numeral and symbol tokens")
- new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
- new_suppressed_tokens = list(set(new_suppressed_tokens))
- self.options = self.options._replace(suppress_tokens=new_suppressed_tokens)
-
- segments: List[SingleSegment] = []
- batch_size = batch_size or self._batch_size
- total_segments = len(vad_segments)
- for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
- if print_progress:
- base_progress = ((idx + 1) / total_segments) * 100
- percent_complete = base_progress / 2 if combined_progress else base_progress
- print(f"Progress: {percent_complete:.2f}%...")
- text = out['text']
- if batch_size in [0, 1, None]:
- text = text[0]
- segments.append(
- {
- "text": text,
- "start": round(vad_segments[idx]['start'], 3),
- "end": round(vad_segments[idx]['end'], 3)
- }
- )
-
- # revert the tokenizer if multilingual inference is enabled
- if self.preset_language is None:
- self.tokenizer = None
-
- # revert suppressed tokens if suppress_numerals is enabled
- if self.suppress_numerals:
- self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
-
- return {"segments": segments, "language": language}
-
-
- def detect_language(self, audio: np.ndarray):
- if audio.shape[0] < N_SAMPLES:
- print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
- model_n_mels = self.model.feat_kwargs.get("feature_size")
- segment = log_mel_spectrogram(audio[: N_SAMPLES],
- n_mels=model_n_mels if model_n_mels is not None else 80,
- padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
- encoder_output = self.model.encode(segment)
- results = self.model.model.detect_language(encoder_output)
- language_token, language_probability = results[0][0]
- language = language_token[2:-2]
- print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
- return language
-
-def load_model(whisper_arch,
- device,
- device_index=0,
- compute_type="float16",
- asr_options=None,
- language : Optional[str] = None,
- vad_model=None,
- vad_options=None,
- model : Optional[WhisperModel] = None,
- task="transcribe",
- download_root=None,
- threads=4):
- '''Load a Whisper model for inference.
- Args:
- whisper_arch: str - The name of the Whisper model to load.
- device: str - The device to load the model on.
- compute_type: str - The compute type to use for the model.
- options: dict - A dictionary of options to use for the model.
- language: str - The language of the model. (use English for now)
- model: Optional[WhisperModel] - The WhisperModel instance to use.
- download_root: Optional[str] - The root directory to download the model to.
- threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
- Returns:
- A Whisper pipeline.
- '''
-
- if whisper_arch.endswith(".en"):
- language = "en"
-
- model = model or WhisperModel(whisper_arch,
- device=device,
- device_index=device_index,
- compute_type=compute_type,
- download_root=download_root,
- cpu_threads=threads)
- if language is not None:
- tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
- else:
- print("No language specified, language will be first be detected for each audio file (increases inference time).")
- tokenizer = None
-
- default_asr_options = {
- "beam_size": 5,
- "best_of": 5,
- "patience": 1,
- "length_penalty": 1,
- "repetition_penalty": 1,
- "no_repeat_ngram_size": 0,
- "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
- "compression_ratio_threshold": 2.4,
- "log_prob_threshold": -1.0,
- "no_speech_threshold": 0.6,
- "condition_on_previous_text": False,
- "prompt_reset_on_temperature": 0.5,
- "initial_prompt": None,
- "prefix": None,
- "suppress_blank": True,
- "suppress_tokens": [-1],
- "without_timestamps": True,
- "max_initial_timestamp": 0.0,
- "word_timestamps": False,
- "prepend_punctuations": "\"'“¿([{-",
- "append_punctuations": "\"'.。,,!!??::”)]}、",
- "suppress_numerals": False,
- "max_new_tokens": None,
- "clip_timestamps": None,
- "hallucination_silence_threshold": None,
- }
-
- if asr_options is not None:
- default_asr_options.update(asr_options)
-
- suppress_numerals = default_asr_options["suppress_numerals"]
- del default_asr_options["suppress_numerals"]
-
- default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
-
- default_vad_options = {
- "vad_onset": 0.500,
- "vad_offset": 0.363
- }
-
- if vad_options is not None:
- default_vad_options.update(vad_options)
-
- if vad_model is not None:
- vad_model = vad_model
- else:
- vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
-
- return FasterWhisperPipeline(
- model=model,
- vad=vad_model,
- options=default_asr_options,
- tokenizer=tokenizer,
- language=language,
- suppress_numerals=suppress_numerals,
- vad_params=default_vad_options,
- )
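`asr.py` (removed above) wrapped faster-whisper in a Hugging Face `Pipeline`, using VAD to cut the audio into chunks of at most `chunk_size` seconds that are then transcribed in batches. A sketch of how the deleted `load_model`/`transcribe` API was driven (model size, device, and audio path are placeholders):

```python
# Sketch of the whisperx ASR API removed above; model size, device and audio
# path are placeholders.
import whisperx

model = whisperx.load_model("large-v2", device="cuda", compute_type="float16",
                            language="en")
audio = whisperx.load_audio("example.wav")
result = model.transcribe(audio, batch_size=16, chunk_size=30)
for seg in result["segments"]:
    print(f"[{seg['start']:.2f} -> {seg['end']:.2f}] {seg['text']}")
```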
diff --git a/third_party/whisperX/whisperx/assets/mel_filters.npz b/third_party/whisperX/whisperx/assets/mel_filters.npz
deleted file mode 100644
index 28ea2690..00000000
Binary files a/third_party/whisperX/whisperx/assets/mel_filters.npz and /dev/null differ
diff --git a/third_party/whisperX/whisperx/audio.py b/third_party/whisperX/whisperx/audio.py
deleted file mode 100644
index db210fb9..00000000
--- a/third_party/whisperX/whisperx/audio.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import os
-import subprocess
-from functools import lru_cache
-from typing import Optional, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .utils import exact_div
-
-# hard-coded audio hyperparameters
-SAMPLE_RATE = 16000
-N_FFT = 400
-HOP_LENGTH = 160
-CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
-N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
-
-N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
-FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
-TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
-
-
-def load_audio(file: str, sr: int = SAMPLE_RATE):
- """
- Open an audio file and read as mono waveform, resampling as necessary
-
- Parameters
- ----------
- file: str
- The audio file to open
-
- sr: int
- The sample rate to resample the audio if necessary
-
- Returns
- -------
- A NumPy array containing the audio waveform, in float32 dtype.
- """
- try:
- # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
- # Requires the ffmpeg CLI to be installed.
- cmd = [
- "ffmpeg",
- "-nostdin",
- "-threads",
- "0",
- "-i",
- file,
- "-f",
- "s16le",
- "-ac",
- "1",
- "-acodec",
- "pcm_s16le",
- "-ar",
- str(sr),
- "-",
- ]
- out = subprocess.run(cmd, capture_output=True, check=True).stdout
- except subprocess.CalledProcessError as e:
- raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
-
- return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
-
-
-def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
- """
- Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
- """
- if torch.is_tensor(array):
- if array.shape[axis] > length:
- array = array.index_select(
- dim=axis, index=torch.arange(length, device=array.device)
- )
-
- if array.shape[axis] < length:
- pad_widths = [(0, 0)] * array.ndim
- pad_widths[axis] = (0, length - array.shape[axis])
- array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
- else:
- if array.shape[axis] > length:
- array = array.take(indices=range(length), axis=axis)
-
- if array.shape[axis] < length:
- pad_widths = [(0, 0)] * array.ndim
- pad_widths[axis] = (0, length - array.shape[axis])
- array = np.pad(array, pad_widths)
-
- return array
-
-
-@lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int) -> torch.Tensor:
- """
- load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
- Allows decoupling librosa dependency; saved using:
-
- np.savez_compressed(
- "mel_filters.npz",
- mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
- )
- """
- assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
- with np.load(
- os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
- ) as f:
- return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
-
-
-def log_mel_spectrogram(
- audio: Union[str, np.ndarray, torch.Tensor],
- n_mels: int,
- padding: int = 0,
- device: Optional[Union[str, torch.device]] = None,
-):
- """
-    Compute the log-Mel spectrogram of the given audio.
-
- Parameters
- ----------
- audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
-        The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
-
- n_mels: int
-        The number of Mel-frequency filters; only 80 and 128 are supported
-
- padding: int
- Number of zero samples to pad to the right
-
- device: Optional[Union[str, torch.device]]
- If given, the audio tensor is moved to this device before STFT
-
- Returns
- -------
-    torch.Tensor, shape = (n_mels, n_frames)
- A Tensor that contains the Mel spectrogram
- """
- if not torch.is_tensor(audio):
- if isinstance(audio, str):
- audio = load_audio(audio)
- audio = torch.from_numpy(audio)
-
- if device is not None:
- audio = audio.to(device)
- if padding > 0:
- audio = F.pad(audio, (0, padding))
- window = torch.hann_window(N_FFT).to(audio.device)
- stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
- magnitudes = stft[..., :-1].abs() ** 2
-
- filters = mel_filters(audio.device, n_mels)
- mel_spec = filters @ magnitudes
-
- log_spec = torch.clamp(mel_spec, min=1e-10).log10()
- log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
- log_spec = (log_spec + 4.0) / 4.0
- return log_spec
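`audio.py` (removed above) decodes audio through the ffmpeg CLI to 16 kHz mono float32 and computes the log-Mel features Whisper expects. A small sketch using the deleted helpers (the file name is a placeholder and ffmpeg must be on PATH):

```python
# Sketch using the deleted audio helpers; "example.wav" is a placeholder and
# the ffmpeg CLI must be installed.
from whisperx.audio import N_SAMPLES, load_audio, log_mel_spectrogram

audio = load_audio("example.wav")                     # float32 mono at 16 kHz
mel = log_mel_spectrogram(audio, n_mels=80,
                          padding=max(0, N_SAMPLES - audio.shape[0]))
print(audio.shape, mel.shape)                         # e.g. (160000,) (80, 3000) for a 10 s clip
```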
diff --git a/third_party/whisperX/whisperx/conjunctions.py b/third_party/whisperX/whisperx/conjunctions.py
deleted file mode 100644
index 5af3cff0..00000000
--- a/third_party/whisperX/whisperx/conjunctions.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# conjunctions.py
-
-conjunctions_by_language = {
- 'en': {'and', 'whether', 'or', 'as', 'but', 'so', 'for', 'nor', 'which', 'yet', 'although', 'since', 'unless', 'when', 'while', 'because', 'if', 'how', 'that', 'than', 'who', 'where', 'what', 'near', 'before', 'after', 'across', 'through', 'until', 'once', 'whereas', 'even', 'both', 'either', 'neither', 'though'},
- 'fr': {'et', 'ou', 'mais', 'parce', 'bien', 'pendant', 'quand', 'où', 'comme', 'si', 'que', 'avant', 'après', 'aussitôt', 'jusqu’à', 'à', 'malgré', 'donc', 'tant', 'puisque', 'ni', 'soit', 'bien', 'encore', 'dès', 'lorsque'},
- 'de': {'und', 'oder', 'aber', 'weil', 'obwohl', 'während', 'wenn', 'wo', 'wie', 'dass', 'bevor', 'nachdem', 'sobald', 'bis', 'außer', 'trotzdem', 'also', 'sowie', 'indem', 'weder', 'sowohl', 'zwar', 'jedoch'},
- 'es': {'y', 'o', 'pero', 'porque', 'aunque', 'sin', 'mientras', 'cuando', 'donde', 'como', 'si', 'que', 'antes', 'después', 'tan', 'hasta', 'a', 'a', 'por', 'ya', 'ni', 'sino'},
- 'it': {'e', 'o', 'ma', 'perché', 'anche', 'mentre', 'quando', 'dove', 'come', 'se', 'che', 'prima', 'dopo', 'appena', 'fino', 'a', 'nonostante', 'quindi', 'poiché', 'né', 'ossia', 'cioè'},
- 'ja': {'そして', 'または', 'しかし', 'なぜなら', 'もし', 'それとも', 'だから', 'それに', 'なのに', 'そのため', 'かつ', 'それゆえに', 'ならば', 'もしくは', 'ため'},
- 'zh': {'和', '或', '但是', '因为', '任何', '也', '虽然', '而且', '所以', '如果', '除非', '尽管', '既然', '即使', '只要', '直到', '然后', '因此', '不但', '而是', '不过'},
- 'nl': {'en', 'of', 'maar', 'omdat', 'hoewel', 'terwijl', 'wanneer', 'waar', 'zoals', 'als', 'dat', 'voordat', 'nadat', 'zodra', 'totdat', 'tenzij', 'ondanks', 'dus', 'zowel', 'noch', 'echter', 'toch'},
- 'uk': {'та', 'або', 'але', 'тому', 'хоча', 'поки', 'бо', 'коли', 'де', 'як', 'якщо', 'що', 'перш', 'після', 'доки', 'незважаючи', 'тому', 'ані'},
- 'pt': {'e', 'ou', 'mas', 'porque', 'embora', 'enquanto', 'quando', 'onde', 'como', 'se', 'que', 'antes', 'depois', 'assim', 'até', 'a', 'apesar', 'portanto', 'já', 'pois', 'nem', 'senão'},
- 'ar': {'و', 'أو', 'لكن', 'لأن', 'مع', 'بينما', 'عندما', 'حيث', 'كما', 'إذا', 'الذي', 'قبل', 'بعد', 'فور', 'حتى', 'إلا', 'رغم', 'لذلك', 'بما'},
- 'cs': {'a', 'nebo', 'ale', 'protože', 'ačkoli', 'zatímco', 'když', 'kde', 'jako', 'pokud', 'že', 'než', 'poté', 'jakmile', 'dokud', 'pokud ne', 'navzdory', 'tak', 'stejně', 'ani', 'tudíž'},
- 'ru': {'и', 'или', 'но', 'потому', 'хотя', 'пока', 'когда', 'где', 'как', 'если', 'что', 'перед', 'после', 'несмотря', 'таким', 'также', 'ни', 'зато'},
- 'pl': {'i', 'lub', 'ale', 'ponieważ', 'chociaż', 'podczas', 'kiedy', 'gdzie', 'jak', 'jeśli', 'że', 'zanim', 'po', 'jak tylko', 'dopóki', 'chyba', 'pomimo', 'więc', 'tak', 'ani', 'czyli'},
- 'hu': {'és', 'vagy', 'de', 'mert', 'habár', 'míg', 'amikor', 'ahol', 'ahogy', 'ha', 'hogy', 'mielőtt', 'miután', 'amint', 'amíg', 'hacsak', 'ellenére', 'tehát', 'úgy', 'sem', 'vagyis'},
- 'fi': {'ja', 'tai', 'mutta', 'koska', 'vaikka', 'kun', 'missä', 'kuten', 'jos', 'että', 'ennen', 'sen jälkeen', 'heti', 'kunnes', 'ellei', 'huolimatta', 'siis', 'sekä', 'eikä', 'vaan'},
- 'fa': {'و', 'یا', 'اما', 'چون', 'اگرچه', 'در حالی', 'وقتی', 'کجا', 'چگونه', 'اگر', 'که', 'قبل', 'پس', 'به محض', 'تا زمانی', 'مگر', 'با وجود', 'پس', 'همچنین', 'نه'},
- 'el': {'και', 'ή', 'αλλά', 'επειδή', 'αν', 'ενώ', 'όταν', 'όπου', 'όπως', 'αν', 'που', 'προτού', 'αφού', 'μόλις', 'μέχρι', 'εκτός', 'παρά', 'έτσι', 'όπως', 'ούτε', 'δηλαδή'},
- 'tr': {'ve', 'veya', 'ama', 'çünkü', 'her ne', 'iken', 'nerede', 'nasıl', 'eğer', 'ki', 'önce', 'sonra', 'hemen', 'kadar', 'rağmen', 'hem', 'ne', 'yani'},
- 'da': {'og', 'eller', 'men', 'fordi', 'selvom', 'mens', 'når', 'hvor', 'som', 'hvis', 'at', 'før', 'efter', 'indtil', 'medmindre', 'således', 'ligesom', 'hverken', 'altså'},
- 'he': {'ו', 'או', 'אבל', 'כי', 'אף', 'בזמן', 'כאשר', 'היכן', 'כיצד', 'אם', 'ש', 'לפני', 'אחרי', 'ברגע', 'עד', 'אלא', 'למרות', 'לכן', 'כמו', 'לא', 'אז'},
- 'vi': {'và', 'hoặc', 'nhưng', 'bởi', 'mặc', 'trong', 'khi', 'ở', 'như', 'nếu', 'rằng', 'trước', 'sau', 'ngay', 'cho', 'trừ', 'mặc', 'vì', 'giống', 'cũng', 'tức'},
- 'ko': {'그리고', '또는','그런데','그래도', '이나', '결국', '마지막으로', '마찬가지로', '반면에', '아니면', '거나', '또는', '그럼에도', '그렇기', '때문에', '덧붙이자면', '게다가', '그러나', '고', '그래서', '랑', '한다면', '하지만', '무엇', '왜냐하면', '비록', '동안', '언제', '어디서', '어떻게', '만약', '그', '전에', '후에', '즉시', '까지', '아니라면', '불구하고', '따라서', '같은', '도'},
- 'ur': {'اور', 'یا', 'مگر', 'کیونکہ', 'اگرچہ', 'جبکہ', 'جب', 'کہاں', 'کس طرح', 'اگر', 'کہ', 'سے پہلے', 'کے بعد', 'جیسے ہی', 'تک', 'اگر نہیں تو', 'کے باوجود', 'اس لئے', 'جیسے', 'نہ'},
- 'hi': {'और', 'या', 'पर', 'तो', 'न', 'फिर', 'हालांकि', 'चूंकि', 'अगर', 'कैसे', 'वह', 'से', 'जो', 'जहां', 'क्या', 'नजदीक', 'पहले', 'बाद', 'के', 'पार', 'माध्यम', 'तक', 'एक', 'जबकि', 'यहां', 'तक', 'दोनों', 'या', 'न', 'हालांकि'}
-
-}
-
-commas_by_language = {
- 'ja': '、',
- 'zh': ',',
- 'fa': '،',
- 'ur': '،'
-}
-
-def get_conjunctions(lang_code):
- return conjunctions_by_language.get(lang_code, set())
-
-def get_comma(lang_code):
- return commas_by_language.get(lang_code, ',')
\ No newline at end of file
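`conjunctions.py` (removed above) only supplies per-language conjunction sets and comma characters for the subtitle splitter; its two lookups behave as sketched here:

```python
# Sketch of the deleted lookup helpers.
from whisperx.conjunctions import get_comma, get_conjunctions

print('because' in get_conjunctions('en'))   # True
print(get_comma('zh'))                       # ',' (full-width comma)
print(get_comma('xx'))                       # ',' (ASCII fallback for unknown languages)
```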
diff --git a/third_party/whisperX/whisperx/diarize.py b/third_party/whisperX/whisperx/diarize.py
deleted file mode 100644
index c327c932..00000000
--- a/third_party/whisperX/whisperx/diarize.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import numpy as np
-import pandas as pd
-from pyannote.audio import Pipeline
-from typing import Optional, Union
-import torch
-
-from .audio import load_audio, SAMPLE_RATE
-
-
-class DiarizationPipeline:
- def __init__(
- self,
- model_name="pyannote/speaker-diarization-3.1",
- use_auth_token=None,
- device: Optional[Union[str, torch.device]] = "cpu",
- ):
- if isinstance(device, str):
- device = torch.device(device)
- self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
-
- def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
- if isinstance(audio, str):
- audio = load_audio(audio)
- audio_data = {
- 'waveform': torch.from_numpy(audio[None, :]),
- 'sample_rate': SAMPLE_RATE
- }
- segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
- diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
- diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
- diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
- return diarize_df
-
-
-def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
- transcript_segments = transcript_result["segments"]
- for seg in transcript_segments:
- # assign speaker to segment (if any)
- diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start'])
- diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
- # remove no hit, otherwise we look for closest (even negative intersection...)
- if not fill_nearest:
- dia_tmp = diarize_df[diarize_df['intersection'] > 0]
- else:
- dia_tmp = diarize_df
- if len(dia_tmp) > 0:
- # sum over speakers
- speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
- seg["speaker"] = speaker
-
- # assign speaker to words
- if 'words' in seg:
- for word in seg['words']:
- if 'start' in word:
- diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start'])
- diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start'])
- # remove no hit
- if not fill_nearest:
- dia_tmp = diarize_df[diarize_df['intersection'] > 0]
- else:
- dia_tmp = diarize_df
- if len(dia_tmp) > 0:
- # sum over speakers
- speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
- word["speaker"] = speaker
-
- return transcript_result
-
-
-class Segment:
- def __init__(self, start, end, speaker=None):
- self.start = start
- self.end = end
- self.speaker = speaker
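`diarize.py` (removed above) wraps pyannote's speaker-diarization pipeline and assigns speakers to segments and words by maximum temporal overlap. A usage sketch of the deleted helpers (the Hugging Face token and audio path are placeholders, and `result` stands in for a `transcribe()`+`align()` output):

```python
# Sketch of the deleted diarization helpers; the token and audio path are
# placeholders, and `result` is a hand-built stand-in for an aligned transcript.
import whisperx

diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_xxx", device="cuda")
diarize_df = diarize_model("example.wav", min_speakers=1, max_speakers=3)

result = {"segments": [{"start": 0.0, "end": 2.0, "text": "hello",
                        "words": [{"word": "hello", "start": 0.1, "end": 0.6, "score": 0.9}]}]}
result = whisperx.assign_word_speakers(diarize_df, result)
print(result["segments"][0].get("speaker"))           # e.g. "SPEAKER_00"
```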
diff --git a/third_party/whisperX/whisperx/transcribe.py b/third_party/whisperX/whisperx/transcribe.py
deleted file mode 100644
index edd27648..00000000
--- a/third_party/whisperX/whisperx/transcribe.py
+++ /dev/null
@@ -1,230 +0,0 @@
-import argparse
-import gc
-import os
-import warnings
-
-import numpy as np
-import torch
-
-from .alignment import align, load_align_model
-from .asr import load_model
-from .audio import load_audio
-from .diarize import DiarizationPipeline, assign_word_speakers
-from .utils import (LANGUAGES, TO_LANGUAGE_CODE, get_writer, optional_float,
- optional_int, str2bool)
-
-
-def cli():
- # fmt: off
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
- parser.add_argument("--model", default="small", help="name of the Whisper model to use")
- parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
- parser.add_argument("--device_index", default=0, type=int, help="device index to use for FasterWhisper inference")
- parser.add_argument("--batch_size", default=8, type=int, help="the preferred batch size for inference")
- parser.add_argument("--compute_type", default="float16", type=str, choices=["float16", "float32", "int8"], help="compute type for computation")
-
- parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
- parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "vtt", "txt", "tsv", "json", "aud"], help="format of the output file; if not specified, all available formats will be produced")
- parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
-
- parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
- parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
-
- # alignment params
- parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
- parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
- parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
- parser.add_argument("--return_char_alignments", action='store_true', help="Return character-level alignments in the output json file")
-
- # vad params
- parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
- parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
- parser.add_argument("--chunk_size", type=int, default=30, help="Chunk size for merging VAD segments. Default is 30, reduce this if the chunk is too long.")
-
- # diarization params
- parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
- parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
- parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
-
- parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
- parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
- parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
- parser.add_argument("--patience", type=float, default=1.0, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
- parser.add_argument("--length_penalty", type=float, default=1.0, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
-
- parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
- parser.add_argument("--suppress_numerals", action="store_true", help="whether to suppress numeric symbols and currency symbols during sampling, since wav2vec2 cannot align them correctly")
-
- parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
- parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
- parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
-
- parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
- parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
- parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
- parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
-
- parser.add_argument("--max_line_width", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
- parser.add_argument("--max_line_count", type=optional_int, default=None, help="(not possible with --no_align) the maximum number of lines in a segment")
- parser.add_argument("--highlight_words", type=str2bool, default=False, help="(not possible with --no_align) underline each word as it is spoken in srt and vtt")
- parser.add_argument("--segment_resolution", type=str, default="sentence", choices=["sentence", "chunk"], help="(not possible with --no_align) the maximum number of characters in a line before breaking the line")
-
- parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
-
- parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
-
- parser.add_argument("--print_progress", type=str2bool, default = False, help = "if True, progress will be printed in transcribe() and align() methods.")
- # fmt: on
-
- args = parser.parse_args().__dict__
- model_name: str = args.pop("model")
- batch_size: int = args.pop("batch_size")
- model_dir: str = args.pop("model_dir")
- output_dir: str = args.pop("output_dir")
- output_format: str = args.pop("output_format")
- device: str = args.pop("device")
- device_index: int = args.pop("device_index")
- compute_type: str = args.pop("compute_type")
-
- # model_flush: bool = args.pop("model_flush")
- os.makedirs(output_dir, exist_ok=True)
-
- align_model: str = args.pop("align_model")
- interpolate_method: str = args.pop("interpolate_method")
- no_align: bool = args.pop("no_align")
- task : str = args.pop("task")
- if task == "translate":
- # translation cannot be aligned
- no_align = True
-
- return_char_alignments: bool = args.pop("return_char_alignments")
-
- hf_token: str = args.pop("hf_token")
- vad_onset: float = args.pop("vad_onset")
- vad_offset: float = args.pop("vad_offset")
-
- chunk_size: int = args.pop("chunk_size")
-
- diarize: bool = args.pop("diarize")
- min_speakers: int = args.pop("min_speakers")
- max_speakers: int = args.pop("max_speakers")
- print_progress: bool = args.pop("print_progress")
-
- if args["language"] is not None:
- args["language"] = args["language"].lower()
- if args["language"] not in LANGUAGES:
- if args["language"] in TO_LANGUAGE_CODE:
- args["language"] = TO_LANGUAGE_CODE[args["language"]]
- else:
- raise ValueError(f"Unsupported language: {args['language']}")
-
- if model_name.endswith(".en") and args["language"] != "en":
- if args["language"] is not None:
- warnings.warn(
- f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
- )
- args["language"] = "en"
- align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
-
- temperature = args.pop("temperature")
- if (increment := args.pop("temperature_increment_on_fallback")) is not None:
- temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
- else:
- temperature = [temperature]
-
- faster_whisper_threads = 4
- if (threads := args.pop("threads")) > 0:
- torch.set_num_threads(threads)
- faster_whisper_threads = threads
-
- asr_options = {
- "beam_size": args.pop("beam_size"),
- "patience": args.pop("patience"),
- "length_penalty": args.pop("length_penalty"),
- "temperatures": temperature,
- "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
- "log_prob_threshold": args.pop("logprob_threshold"),
- "no_speech_threshold": args.pop("no_speech_threshold"),
- "condition_on_previous_text": False,
- "initial_prompt": args.pop("initial_prompt"),
- "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
- "suppress_numerals": args.pop("suppress_numerals"),
- }
-
- writer = get_writer(output_format, output_dir)
- word_options = ["highlight_words", "max_line_count", "max_line_width"]
- if no_align:
- for option in word_options:
- if args[option]:
- parser.error(f"--{option} not possible with --no_align")
- if args["max_line_count"] and not args["max_line_width"]:
- warnings.warn("--max_line_count has no effect without --max_line_width")
- writer_args = {arg: args.pop(arg) for arg in word_options}
-
- # Part 1: VAD & ASR Loop
- results = []
- tmp_results = []
- # model = load_model(model_name, device=device, download_root=model_dir)
- model = load_model(model_name, device=device, device_index=device_index, download_root=model_dir, compute_type=compute_type, language=args['language'], asr_options=asr_options, vad_options={"vad_onset": vad_onset, "vad_offset": vad_offset}, task=task, threads=faster_whisper_threads)
-
- for audio_path in args.pop("audio"):
- audio = load_audio(audio_path)
- # >> VAD & ASR
- print(">>Performing transcription...")
- result = model.transcribe(audio, batch_size=batch_size, chunk_size=chunk_size, print_progress=print_progress)
- results.append((result, audio_path))
-
- # Unload Whisper and VAD
- del model
- gc.collect()
- torch.cuda.empty_cache()
-
- # Part 2: Align Loop
- if not no_align:
- tmp_results = results
- results = []
- align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
- for result, audio_path in tmp_results:
- # >> Align
- if len(tmp_results) > 1:
- input_audio = audio_path
- else:
- # lazily load audio from part 1
- input_audio = audio
-
- if align_model is not None and len(result["segments"]) > 0:
- if result.get("language", "en") != align_metadata["language"]:
- # load new language
- print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
- align_model, align_metadata = load_align_model(result["language"], device)
- print(">>Performing alignment...")
- result = align(result["segments"], align_model, align_metadata, input_audio, device, interpolate_method=interpolate_method, return_char_alignments=return_char_alignments, print_progress=print_progress)
-
- results.append((result, audio_path))
-
- # Unload align model
- del align_model
- gc.collect()
- torch.cuda.empty_cache()
-
- # >> Diarize
- if diarize:
- if hf_token is None:
- print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
- tmp_results = results
- print(">>Performing diarization...")
- results = []
- diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
- for result, input_audio_path in tmp_results:
- diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
- result = assign_word_speakers(diarize_segments, result)
- results.append((result, input_audio_path))
- # >> Write
- for result, audio_path in results:
- result["language"] = align_language
- writer(result, audio_path, writer_args)
-
-if __name__ == "__main__":
- cli()
diff --git a/third_party/whisperX/whisperx/types.py b/third_party/whisperX/whisperx/types.py
deleted file mode 100644
index 68f2d783..00000000
--- a/third_party/whisperX/whisperx/types.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from typing import TypedDict, Optional, List
-
-
-class SingleWordSegment(TypedDict):
- """
- A single word of a speech.
- """
- word: str
- start: float
- end: float
- score: float
-
-class SingleCharSegment(TypedDict):
- """
- A single char of a speech.
- """
- char: str
- start: float
- end: float
- score: float
-
-
-class SingleSegment(TypedDict):
- """
- A single segment (up to multiple sentences) of a speech.
- """
-
- start: float
- end: float
- text: str
-
-
-class SingleAlignedSegment(TypedDict):
- """
- A single segment (up to multiple sentences) of a speech with word alignment.
- """
-
- start: float
- end: float
- text: str
- words: List[SingleWordSegment]
- chars: Optional[List[SingleCharSegment]]
-
-
-class TranscriptionResult(TypedDict):
- """
- A list of segments and word segments of a speech.
- """
- segments: List[SingleSegment]
- language: str
-
-
-class AlignedTranscriptionResult(TypedDict):
- """
- A list of segments and word segments of a speech.
- """
- segments: List[SingleAlignedSegment]
- word_segments: List[SingleWordSegment]
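
These TypedDicts are purely structural: they document the dictionaries the library already passes around and are never enforced at runtime. A small, hedged sketch of what a value matching TranscriptionResult looks like, with made-up numbers, re-declaring only the two types it needs so it stands alone:

from typing import List, TypedDict

class SingleSegment(TypedDict):
    """A single segment (up to multiple sentences) of a speech."""
    start: float
    end: float
    text: str

class TranscriptionResult(TypedDict):
    """A list of segments of a speech plus the detected language."""
    segments: List[SingleSegment]
    language: str

# Illustrative values only; real results come from model.transcribe(...)
result: TranscriptionResult = {
    "segments": [
        {"start": 0.00, "end": 2.40, "text": "Hello world."},
        {"start": 2.40, "end": 5.10, "text": "This is a transcription."},
    ],
    "language": "en",
}

for seg in result["segments"]:
    print(f"[{seg['start']:6.2f} -> {seg['end']:6.2f}] {seg['text']}")
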
diff --git a/third_party/whisperX/whisperx/utils.py b/third_party/whisperX/whisperx/utils.py
deleted file mode 100644
index 16ce116e..00000000
--- a/third_party/whisperX/whisperx/utils.py
+++ /dev/null
@@ -1,437 +0,0 @@
-import json
-import os
-import re
-import sys
-import zlib
-from typing import Callable, Optional, TextIO
-
-LANGUAGES = {
- "en": "english",
- "zh": "chinese",
- "de": "german",
- "es": "spanish",
- "ru": "russian",
- "ko": "korean",
- "fr": "french",
- "ja": "japanese",
- "pt": "portuguese",
- "tr": "turkish",
- "pl": "polish",
- "ca": "catalan",
- "nl": "dutch",
- "ar": "arabic",
- "sv": "swedish",
- "it": "italian",
- "id": "indonesian",
- "hi": "hindi",
- "fi": "finnish",
- "vi": "vietnamese",
- "he": "hebrew",
- "uk": "ukrainian",
- "el": "greek",
- "ms": "malay",
- "cs": "czech",
- "ro": "romanian",
- "da": "danish",
- "hu": "hungarian",
- "ta": "tamil",
- "no": "norwegian",
- "th": "thai",
- "ur": "urdu",
- "hr": "croatian",
- "bg": "bulgarian",
- "lt": "lithuanian",
- "la": "latin",
- "mi": "maori",
- "ml": "malayalam",
- "cy": "welsh",
- "sk": "slovak",
- "te": "telugu",
- "fa": "persian",
- "lv": "latvian",
- "bn": "bengali",
- "sr": "serbian",
- "az": "azerbaijani",
- "sl": "slovenian",
- "kn": "kannada",
- "et": "estonian",
- "mk": "macedonian",
- "br": "breton",
- "eu": "basque",
- "is": "icelandic",
- "hy": "armenian",
- "ne": "nepali",
- "mn": "mongolian",
- "bs": "bosnian",
- "kk": "kazakh",
- "sq": "albanian",
- "sw": "swahili",
- "gl": "galician",
- "mr": "marathi",
- "pa": "punjabi",
- "si": "sinhala",
- "km": "khmer",
- "sn": "shona",
- "yo": "yoruba",
- "so": "somali",
- "af": "afrikaans",
- "oc": "occitan",
- "ka": "georgian",
- "be": "belarusian",
- "tg": "tajik",
- "sd": "sindhi",
- "gu": "gujarati",
- "am": "amharic",
- "yi": "yiddish",
- "lo": "lao",
- "uz": "uzbek",
- "fo": "faroese",
- "ht": "haitian creole",
- "ps": "pashto",
- "tk": "turkmen",
- "nn": "nynorsk",
- "mt": "maltese",
- "sa": "sanskrit",
- "lb": "luxembourgish",
- "my": "myanmar",
- "bo": "tibetan",
- "tl": "tagalog",
- "mg": "malagasy",
- "as": "assamese",
- "tt": "tatar",
- "haw": "hawaiian",
- "ln": "lingala",
- "ha": "hausa",
- "ba": "bashkir",
- "jw": "javanese",
- "su": "sundanese",
- "yue": "cantonese",
-}
-
-# language code lookup by name, with a few language aliases
-TO_LANGUAGE_CODE = {
- **{language: code for code, language in LANGUAGES.items()},
- "burmese": "my",
- "valencian": "ca",
- "flemish": "nl",
- "haitian": "ht",
- "letzeburgesch": "lb",
- "pushto": "ps",
- "panjabi": "pa",
- "moldavian": "ro",
- "moldovan": "ro",
- "sinhalese": "si",
- "castilian": "es",
-}
-
-LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
-
-system_encoding = sys.getdefaultencoding()
-
-if system_encoding != "utf-8":
-
- def make_safe(string):
- # replaces any character not representable using the system default encoding with an '?',
- # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
- return string.encode(system_encoding, errors="replace").decode(system_encoding)
-
-else:
-
- def make_safe(string):
- # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
- return string
-
-
-def exact_div(x, y):
- assert x % y == 0
- return x // y
-
-
-def str2bool(string):
- str2val = {"True": True, "False": False}
- if string in str2val:
- return str2val[string]
- else:
- raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
-
-
-def optional_int(string):
- return None if string == "None" else int(string)
-
-
-def optional_float(string):
- return None if string == "None" else float(string)
-
-
-def compression_ratio(text) -> float:
- text_bytes = text.encode("utf-8")
- return len(text_bytes) / len(zlib.compress(text_bytes))
-
-
-def format_timestamp(
- seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
-):
- assert seconds >= 0, "non-negative timestamp expected"
- milliseconds = round(seconds * 1000.0)
-
- hours = milliseconds // 3_600_000
- milliseconds -= hours * 3_600_000
-
- minutes = milliseconds // 60_000
- milliseconds -= minutes * 60_000
-
- seconds = milliseconds // 1_000
- milliseconds -= seconds * 1_000
-
- hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
- return (
- f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
- )
-
-
-class ResultWriter:
- extension: str
-
- def __init__(self, output_dir: str):
- self.output_dir = output_dir
-
- def __call__(self, result: dict, audio_path: str, options: dict):
- audio_basename = os.path.basename(audio_path)
- audio_basename = os.path.splitext(audio_basename)[0]
- output_path = os.path.join(
- self.output_dir, audio_basename + "." + self.extension
- )
-
- with open(output_path, "w", encoding="utf-8") as f:
- self.write_result(result, file=f, options=options)
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- raise NotImplementedError
-
-
-class WriteTXT(ResultWriter):
- extension: str = "txt"
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- for segment in result["segments"]:
- print(segment["text"].strip(), file=file, flush=True)
-
-
-class SubtitlesWriter(ResultWriter):
- always_include_hours: bool
- decimal_marker: str
-
- def iterate_result(self, result: dict, options: dict):
- raw_max_line_width: Optional[int] = options["max_line_width"]
- max_line_count: Optional[int] = options["max_line_count"]
- highlight_words: bool = options["highlight_words"]
- max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width
- preserve_segments = max_line_count is None or raw_max_line_width is None
-
- if len(result["segments"]) == 0:
- return
-
- def iterate_subtitles():
- line_len = 0
- line_count = 1
- # the next subtitle to yield (a list of word timings with whitespace)
- subtitle: list[dict] = []
- times = []
- last = result["segments"][0]["start"]
- for segment in result["segments"]:
- for i, original_timing in enumerate(segment["words"]):
- timing = original_timing.copy()
- long_pause = not preserve_segments
- if "start" in timing:
- long_pause = long_pause and timing["start"] - last > 3.0
- else:
- long_pause = False
- has_room = line_len + len(timing["word"]) <= max_line_width
- seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
- if line_len > 0 and has_room and not long_pause and not seg_break:
- # line continuation
- line_len += len(timing["word"])
- else:
- # new line
- timing["word"] = timing["word"].strip()
- if (
- len(subtitle) > 0
- and max_line_count is not None
- and (long_pause or line_count >= max_line_count)
- or seg_break
- ):
- # subtitle break
- yield subtitle, times
- subtitle = []
- times = []
- line_count = 1
- elif line_len > 0:
- # line break
- line_count += 1
- timing["word"] = "\n" + timing["word"]
- line_len = len(timing["word"].strip())
- subtitle.append(timing)
- times.append((segment["start"], segment["end"], segment.get("speaker")))
- if "start" in timing:
- last = timing["start"]
- if len(subtitle) > 0:
- yield subtitle, times
-
- if "words" in result["segments"][0]:
- for subtitle, _ in iterate_subtitles():
- sstart, ssend, speaker = _[0]
- subtitle_start = self.format_timestamp(sstart)
- subtitle_end = self.format_timestamp(ssend)
- if result["language"] in LANGUAGES_WITHOUT_SPACES:
- subtitle_text = "".join([word["word"] for word in subtitle])
- else:
- subtitle_text = " ".join([word["word"] for word in subtitle])
- has_timing = any(["start" in word for word in subtitle])
-
- # add [$SPEAKER_ID]: to each subtitle if speaker is available
- prefix = ""
- if speaker is not None:
- prefix = f"[{speaker}]: "
-
- if highlight_words and has_timing:
- last = subtitle_start
- all_words = [timing["word"] for timing in subtitle]
- for i, this_word in enumerate(subtitle):
- if "start" in this_word:
- start = self.format_timestamp(this_word["start"])
- end = self.format_timestamp(this_word["end"])
- if last != start:
- yield last, start, prefix + subtitle_text
-
- yield start, end, prefix + " ".join(
- [
- re.sub(r"^(\s*)(.*)$", r"\1\2", word)
- if j == i
- else word
- for j, word in enumerate(all_words)
- ]
- )
- last = end
- else:
- yield subtitle_start, subtitle_end, prefix + subtitle_text
- else:
- for segment in result["segments"]:
- segment_start = self.format_timestamp(segment["start"])
- segment_end = self.format_timestamp(segment["end"])
- segment_text = segment["text"].strip().replace("-->", "->")
- if "speaker" in segment:
- segment_text = f"[{segment['speaker']}]: {segment_text}"
- yield segment_start, segment_end, segment_text
-
- def format_timestamp(self, seconds: float):
- return format_timestamp(
- seconds=seconds,
- always_include_hours=self.always_include_hours,
- decimal_marker=self.decimal_marker,
- )
-
-
-class WriteVTT(SubtitlesWriter):
- extension: str = "vtt"
- always_include_hours: bool = False
- decimal_marker: str = "."
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- print("WEBVTT\n", file=file)
- for start, end, text in self.iterate_result(result, options):
- print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
-
-
-class WriteSRT(SubtitlesWriter):
- extension: str = "srt"
- always_include_hours: bool = True
- decimal_marker: str = ","
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- for i, (start, end, text) in enumerate(
- self.iterate_result(result, options), start=1
- ):
- print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
-
-
-class WriteTSV(ResultWriter):
- """
- Write a transcript to a file in TSV (tab-separated values) format containing lines like:
-    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
-
- Using integer milliseconds as start and end times means there's no chance of interference from
- an environment setting a language encoding that causes the decimal in a floating point number
-    to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
- """
-
- extension: str = "tsv"
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- print("start", "end", "text", sep="\t", file=file)
- for segment in result["segments"]:
- print(round(1000 * segment["start"]), file=file, end="\t")
- print(round(1000 * segment["end"]), file=file, end="\t")
- print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
-
-class WriteAudacity(ResultWriter):
- """
-    Write a transcript to a text file that Audacity can import as labels.
-    The extension used is "aud" to distinguish it from the txt file produced by WriteTXT.
-    Note that this is not an Audacity project but only a label file!
-
-    Please note: Audacity uses seconds in its timestamps, not milliseconds!
-    Also, no header is expected.
-
-    If a speaker is provided, it is prepended to the text between double square brackets [[]].
- """
-
- extension: str = "aud"
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- ARROW = " "
- for segment in result["segments"]:
- print(segment["start"], file=file, end=ARROW)
- print(segment["end"], file=file, end=ARROW)
- print( ( ("[[" + segment["speaker"] + "]]") if "speaker" in segment else "") + segment["text"].strip().replace("\t", " "), file=file, flush=True)
-
-
-
-class WriteJSON(ResultWriter):
- extension: str = "json"
-
- def write_result(self, result: dict, file: TextIO, options: dict):
- json.dump(result, file, ensure_ascii=False)
-
-
-def get_writer(
- output_format: str, output_dir: str
-) -> Callable[[dict, TextIO, dict], None]:
- writers = {
- "txt": WriteTXT,
- "vtt": WriteVTT,
- "srt": WriteSRT,
- "tsv": WriteTSV,
- "json": WriteJSON,
- }
- optional_writers = {
- "aud": WriteAudacity,
- }
-
- if output_format == "all":
- all_writers = [writer(output_dir) for writer in writers.values()]
-
- def write_all(result: dict, file: TextIO, options: dict):
- for writer in all_writers:
- writer(result, file, options)
-
- return write_all
-
- if output_format in optional_writers:
- return optional_writers[output_format](output_dir)
- return writers[output_format](output_dir)
-
-def interpolate_nans(x, method='nearest'):
- if x.notnull().sum() > 1:
- return x.interpolate(method=method).ffill().bfill()
- else:
- return x.ffill().bfill()
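
All writers above share one calling convention: get_writer(format, output_dir) returns an object that is called as writer(result, audio_path, options) and writes <audio basename>.<extension> into output_dir. A hedged usage sketch, assuming the upstream whisperx package is installed now that this vendored utils module is deleted; the file names and segment values are illustrative:

from whisperx.utils import format_timestamp, get_writer  # upstream package assumed

# format_timestamp renders seconds as subtitle-style timestamps
print(format_timestamp(3.5))                               # 00:03.500
print(format_timestamp(3723.007, always_include_hours=True,
                       decimal_marker=","))                # 01:02:03,007

# A minimal, word-less result; the subtitle writers fall back to segment timings
result = {
    "segments": [{"start": 0.0, "end": 2.4, "text": "Hello world."}],
    "language": "en",
}

writer = get_writer("srt", output_dir=".")
# The subtitle writers read these three options even when they end up unused
options = {"max_line_width": None, "max_line_count": None, "highlight_words": False}
writer(result, "demo.wav", options)  # writes ./demo.srt
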
diff --git a/third_party/whisperX/whisperx/vad.py b/third_party/whisperX/whisperx/vad.py
deleted file mode 100644
index ab2c7bbf..00000000
--- a/third_party/whisperX/whisperx/vad.py
+++ /dev/null
@@ -1,311 +0,0 @@
-import hashlib
-import os
-import urllib
-from typing import Callable, Optional, Text, Union
-
-import numpy as np
-import pandas as pd
-import torch
-from pyannote.audio import Model
-from pyannote.audio.core.io import AudioFile
-from pyannote.audio.pipelines import VoiceActivityDetection
-from pyannote.audio.pipelines.utils import PipelineModel
-from pyannote.core import Annotation, Segment, SlidingWindowFeature
-from tqdm import tqdm
-
-from .diarize import Segment as SegmentX
-
-VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
-
-def load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
- model_dir = torch.hub._get_torch_home()
- os.makedirs(model_dir, exist_ok = True)
- if model_fp is None:
- model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
- if os.path.exists(model_fp) and not os.path.isfile(model_fp):
- raise RuntimeError(f"{model_fp} exists and is not a regular file")
-
- if not os.path.isfile(model_fp):
- with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
- with tqdm(
- total=int(source.info().get("Content-Length")),
- ncols=80,
- unit="iB",
- unit_scale=True,
- unit_divisor=1024,
- ) as loop:
- while True:
- buffer = source.read(8192)
- if not buffer:
- break
-
- output.write(buffer)
- loop.update(len(buffer))
-
- model_bytes = open(model_fp, "rb").read()
- if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
- raise RuntimeError(
- "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
- )
-
- vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
- hyperparameters = {"onset": vad_onset,
- "offset": vad_offset,
- "min_duration_on": 0.1,
- "min_duration_off": 0.1}
- vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
- vad_pipeline.instantiate(hyperparameters)
-
- return vad_pipeline
-
-class Binarize:
- """Binarize detection scores using hysteresis thresholding, with min-cut operation
-    to ensure no segment is longer than max_duration.
-
- Parameters
- ----------
- onset : float, optional
- Onset threshold. Defaults to 0.5.
- offset : float, optional
- Offset threshold. Defaults to `onset`.
- min_duration_on : float, optional
- Remove active regions shorter than that many seconds. Defaults to 0s.
- min_duration_off : float, optional
- Fill inactive regions shorter than that many seconds. Defaults to 0s.
- pad_onset : float, optional
- Extend active regions by moving their start time by that many seconds.
- Defaults to 0s.
- pad_offset : float, optional
- Extend active regions by moving their end time by that many seconds.
- Defaults to 0s.
- max_duration: float
-        The maximum length of an active segment; longer segments are divided at the timestamp with the lowest score.
- Reference
- ---------
- Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
- RNN-based Voice Activity Detection", InterSpeech 2015.
-
- Modified by Max Bain to include WhisperX's min-cut operation
- https://arxiv.org/abs/2303.00747
-
- Pyannote-audio
- """
-
- def __init__(
- self,
- onset: float = 0.5,
- offset: Optional[float] = None,
- min_duration_on: float = 0.0,
- min_duration_off: float = 0.0,
- pad_onset: float = 0.0,
- pad_offset: float = 0.0,
- max_duration: float = float('inf')
- ):
-
- super().__init__()
-
- self.onset = onset
- self.offset = offset or onset
-
- self.pad_onset = pad_onset
- self.pad_offset = pad_offset
-
- self.min_duration_on = min_duration_on
- self.min_duration_off = min_duration_off
-
- self.max_duration = max_duration
-
- def __call__(self, scores: SlidingWindowFeature) -> Annotation:
- """Binarize detection scores
- Parameters
- ----------
- scores : SlidingWindowFeature
- Detection scores.
- Returns
- -------
- active : Annotation
- Binarized scores.
- """
-
- num_frames, num_classes = scores.data.shape
- frames = scores.sliding_window
- timestamps = [frames[i].middle for i in range(num_frames)]
-
- # annotation meant to store 'active' regions
- active = Annotation()
- for k, k_scores in enumerate(scores.data.T):
-
- label = k if scores.labels is None else scores.labels[k]
-
- # initial state
- start = timestamps[0]
- is_active = k_scores[0] > self.onset
- curr_scores = [k_scores[0]]
- curr_timestamps = [start]
- t = start
- for t, y in zip(timestamps[1:], k_scores[1:]):
- # currently active
- if is_active:
- curr_duration = t - start
- if curr_duration > self.max_duration:
- search_after = len(curr_scores) // 2
- # divide segment
- min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
- min_score_t = curr_timestamps[min_score_div_idx]
- region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
- active[region, k] = label
- start = curr_timestamps[min_score_div_idx]
- curr_scores = curr_scores[min_score_div_idx+1:]
- curr_timestamps = curr_timestamps[min_score_div_idx+1:]
- # switching from active to inactive
- elif y < self.offset:
- region = Segment(start - self.pad_onset, t + self.pad_offset)
- active[region, k] = label
- start = t
- is_active = False
- curr_scores = []
- curr_timestamps = []
- curr_scores.append(y)
- curr_timestamps.append(t)
- # currently inactive
- else:
- # switching from inactive to active
- if y > self.onset:
- start = t
- is_active = True
-
- # if active at the end, add final region
- if is_active:
- region = Segment(start - self.pad_onset, t + self.pad_offset)
- active[region, k] = label
-
- # because of padding, some active regions might be overlapping: merge them.
- # also: fill same speaker gaps shorter than min_duration_off
- if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
- if self.max_duration < float("inf"):
- raise NotImplementedError(f"This would break current max_duration param")
- active = active.support(collar=self.min_duration_off)
-
- # remove tracks shorter than min_duration_on
- if self.min_duration_on > 0:
- for segment, track in list(active.itertracks()):
- if segment.duration < self.min_duration_on:
- del active[segment, track]
-
- return active
-
-
-class VoiceActivitySegmentation(VoiceActivityDetection):
- def __init__(
- self,
- segmentation: PipelineModel = "pyannote/segmentation",
- fscore: bool = False,
- use_auth_token: Union[Text, None] = None,
- **inference_kwargs,
- ):
-
- super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
-
- def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
- """Apply voice activity detection
-
- Parameters
- ----------
- file : AudioFile
- Processed file.
- hook : callable, optional
- Hook called after each major step of the pipeline with the following
- signature: hook("step_name", step_artefact, file=file)
-
- Returns
- -------
- speech : Annotation
- Speech regions.
- """
-
- # setup hook (e.g. for debugging purposes)
- hook = self.setup_hook(file, hook=hook)
-
- # apply segmentation model (only if needed)
- # output shape is (num_chunks, num_frames, 1)
- if self.training:
- if self.CACHED_SEGMENTATION in file:
- segmentations = file[self.CACHED_SEGMENTATION]
- else:
- segmentations = self._segmentation(file)
- file[self.CACHED_SEGMENTATION] = segmentations
- else:
- segmentations: SlidingWindowFeature = self._segmentation(file)
-
- return segmentations
-
-
-def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
-
- active = Annotation()
- for k, vad_t in enumerate(vad_arr):
- region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
- active[region, k] = 1
-
-
- if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
- active = active.support(collar=min_duration_off)
-
- # remove tracks shorter than min_duration_on
- if min_duration_on > 0:
- for segment, track in list(active.itertracks()):
- if segment.duration < min_duration_on:
- del active[segment, track]
-
- active = active.for_json()
- active_segs = pd.DataFrame([x['segment'] for x in active['content']])
- return active_segs
-
-def merge_chunks(
- segments,
- chunk_size,
- onset: float = 0.5,
- offset: Optional[float] = None,
-):
- """
-    Merge operation described in the WhisperX paper (https://arxiv.org/abs/2303.00747).
- """
- curr_end = 0
- merged_segments = []
- seg_idxs = []
- speaker_idxs = []
-
- assert chunk_size > 0
- binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
- segments = binarize(segments)
- segments_list = []
- for speech_turn in segments.get_timeline():
- segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
-
- if len(segments_list) == 0:
- print("No active speech found in audio")
- return []
- # assert segments_list, "segments_list is empty."
-    # Make sure the starting point is the start of the segment.
- curr_start = segments_list[0].start
-
- for seg in segments_list:
- if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
- merged_segments.append({
- "start": curr_start,
- "end": curr_end,
- "segments": seg_idxs,
- })
- curr_start = seg.start
- seg_idxs = []
- speaker_idxs = []
- curr_end = seg.end
- seg_idxs.append((seg.start, seg.end))
- speaker_idxs.append(seg.speaker)
- # add final
- merged_segments.append({
- "start": curr_start,
- "end": curr_end,
- "segments": seg_idxs,
- })
- return merged_segments
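
Taken together, the deleted VAD module works in three steps: load_vad_model builds a VoiceActivitySegmentation pipeline around the downloaded segmentation checkpoint, Binarize turns its frame scores into speech regions no longer than max_duration, and merge_chunks packs those regions into windows of at most chunk_size seconds for batched ASR. A hedged sketch of that chunking step, again assuming the upstream whisperx package plus an illustrative input file and device:

import torch
from whisperx.audio import SAMPLE_RATE, load_audio   # upstream package assumed
from whisperx.vad import load_vad_model, merge_chunks

device = "cuda"                  # illustrative; "cpu" also works, just slower
audio = load_audio("input.wav")  # hypothetical input, mono float32 at 16 kHz

# Downloads the segmentation checkpoint on first use, then builds the pipeline
vad_model = load_vad_model(device, vad_onset=0.500, vad_offset=0.363)

# The pyannote pipeline expects a dict with a (channel, time) waveform tensor
vad_segments = vad_model({
    "waveform": torch.from_numpy(audio).unsqueeze(0),
    "sample_rate": SAMPLE_RATE,
})

# Pack speech regions into windows of at most 30 s (Whisper's context length)
chunks = merge_chunks(vad_segments, chunk_size=30, onset=0.500, offset=0.363)
for chunk in chunks:
    print(f"{chunk['start']:7.2f}s -> {chunk['end']:7.2f}s "
          f"({len(chunk['segments'])} VAD segments)")
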