|
|
|
|
|
""" |
|
|
Convert IndexTTS-2 PyTorch models to ONNX format for Rust inference! |
|
|
|
|
|
This script converts the three main models: |
|
|
1. GPT model (gpt.pth) - Autoregressive text-to-semantic generation |
|
|
2. S2Mel model (s2mel.pth) - Semantic-to-mel spectrogram conversion |
|
|
3. BigVGAN - Mel-to-waveform vocoder (already available as ONNX from NVIDIA) |
|
|
|
|
|
Usage: |
|
|
python tools/convert_to_onnx.py |
|
|
|
|
|
Output: |
|
|
models/gpt.onnx |
|
|
models/s2mel.onnx |
|
|
models/bigvgan.onnx (if needed, otherwise use NVIDIA's) |
|
|
|
|
|
Why ONNX? |
|
|
- Cross-platform: Works on Windows, Linux, macOS, M1/M2 Macs |
|
|
- Fast: ONNX Runtime is highly optimized |
|
|
- Rust-native: ort crate provides excellent ONNX Runtime bindings |
|
|
- No Python: Production inference without Python dependency hell! |
|
|
|
|
|
Author: Aye & Hue @ 8b.is |
|
|
""" |
|
|
|
|
|
import os
import sys

# Run everything relative to the project root so checkpoint paths resolve
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
os.chdir(project_root)

# Keep HuggingFace downloads in the project-local cache
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'

print("=" * 70)
print(" IndexTTS-2 PyTorch to ONNX Converter")
print(" For Rust inference with ort crate!")
print("=" * 70)
print()

# Make sure the PyTorch checkpoints are present before loading heavy deps
if not os.path.exists("checkpoints/gpt.pth"):
    print("ERROR: Models not found!")
    print("Run: python tools/download_files.py -s huggingface")
    sys.exit(1)

import torch
import torch.onnx
import numpy as np
from pathlib import Path

# Make the indextts package and infer_v2 importable
sys.path.insert(0, project_root)

output_dir = Path("models")
output_dir.mkdir(exist_ok=True)

print(f"PyTorch version: {torch.__version__}")
print(f"Output directory: {output_dir}")
print()


def export_speaker_encoder():
    """
    Export the CAM++ speaker encoder to ONNX.

    This model extracts speaker embeddings from reference audio.
    Input: mel features [batch, time, n_mels]
    Output: speaker embedding [batch, 192]
    """
    print("\n" + "=" * 50)
    print("Exporting Speaker Encoder (CAM++)")
    print("=" * 50)

    try:
        from omegaconf import OmegaConf
        from indextts.s2mel.modules.campplus.DTDNN import CAMPPlus

        cfg = OmegaConf.load("checkpoints/config.yaml")

        # CAM++: 80-dim mel features in, 192-dim speaker embedding out
        model = CAMPPlus(feat_dim=80, embedding_size=192)

        weights_path = "./checkpoints/hf_cache/models--funasr--campplus/snapshots/fb71fe990cbf6031ae6987a2d76fe64f94377b7e/campplus_cn_common.bin"
        if os.path.exists(weights_path):
            state_dict = torch.load(weights_path, map_location='cpu')
            model.load_state_dict(state_dict)
            print(f"Loaded weights from: {weights_path}")
        else:
            print(f"WARNING: weights not found at {weights_path}; exporting random init")

        model.eval()

        # CAM++ expects [batch, time, feat_dim]
        dummy_input = torch.randn(1, 100, 80)

        # Sanity-check the forward pass before tracing
        with torch.no_grad():
            test_output = model(dummy_input)
        print(f"Forward pass works! Output shape: {test_output.shape}")

        output_path = output_dir / "speaker_encoder.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            str(output_path),
            input_names=['mel_spectrogram'],
            output_names=['speaker_embedding'],
            dynamic_axes={
                'mel_spectrogram': {0: 'batch', 1: 'time'},
                'speaker_embedding': {0: 'batch'}
            },
            opset_version=18,
            do_constant_folding=True,
        )

        # Structural check of the exported graph
        import onnx
        onnx_model = onnx.load(str(output_path))
        onnx.checker.check_model(onnx_model)

        print(f"✓ Exported: {output_path}")
        print(" Input: mel_spectrogram [batch, time, 80]")
        print(" Output: speaker_embedding [batch, 192]")
        print("✓ ONNX model verified!")
|
|
return True |
|
|
|
|
|
    except Exception as e:
        print(f"✗ Failed to export speaker encoder: {e}")
        import traceback
        traceback.print_exc()
        return False


def export_gpt_model():
    """
    Export the GPT autoregressive model to ONNX.

    This is the most complex model - generates semantic tokens from text.
    We may need to export it in parts due to KV caching.

    Input: text_tokens [batch, seq_len], speaker_embedding [batch, 192]
    Output: semantic_codes [batch, code_len]
    """
    print("\n" + "=" * 50)
    print("Exporting GPT Model (Autoregressive)")
    print("=" * 50)

    try:
        from omegaconf import OmegaConf

        cfg = OmegaConf.load("checkpoints/config.yaml")

print("GPT model export is complex due to:") |
|
|
print(" - Autoregressive generation with KV caching") |
|
|
print(" - Dynamic sequence lengths") |
|
|
print(" - Multiple internal components") |
|
|
print() |
|
|
print("Options:") |
|
|
print(" A) Export without KV cache (slower but simpler)") |
|
|
print(" B) Export encoder + single-step decoder (efficient)") |
|
|
print(" C) Use torch.compile + ONNX tracing") |
|
|
print() |
|
|
|
|
|
|
|
|
        # Load the full pipeline to get at the GPT submodule
        from infer_v2 import IndexTTS2

        tts = IndexTTS2(
            cfg_path="checkpoints/config.yaml",
            model_dir="checkpoints",
            use_fp16=False,
            device="cpu"
        )

        gpt = tts.gpt
        gpt.eval()

        print(f"GPT model loaded: {type(gpt)}")
        print(f"Parameters: {sum(p.numel() for p in gpt.parameters()):,}")

        output_path = output_dir / "gpt_encoder.onnx"

        # Dummy text tokens for tracing (placeholder vocab size and length)
        text_tokens = torch.randint(0, 30000, (1, 32), dtype=torch.int64)

        print("Attempting GPT export (may require modifications)...")
        print()
        print("Note: Full GPT export requires modifying the model code")
        print("to remove dynamic control flow. Creating a wrapper")
        print("(see the sketch after this function) is the likely path.")

        return False

    except Exception as e:
        print(f"✗ Failed to export GPT: {e}")
        import traceback
        traceback.print_exc()
        return False
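

# Hedged sketch of option B above, for reference only: wrap one decode step so
# ONNX sees a static graph, then run the autoregressive sampling loop on the
# Rust side. The call signature used on `self.gpt` is a HYPOTHETICAL
# placeholder and must be adapted to the real UnifiedVoice interface before
# this can actually be exported.
class GPTSingleStepSketch(torch.nn.Module):
    """One decode step: tokens generated so far -> logits for the next token."""

    def __init__(self, gpt):
        super().__init__()
        self.gpt = gpt

    def forward(self, semantic_tokens, conditioning):
        # HYPOTHETICAL forward: the real model likely also needs speaker
        # conditioning and KV-cache tensors threaded through as explicit
        # inputs/outputs to make per-step inference efficient.
        logits = self.gpt(semantic_tokens, conditioning)
        return logits[:, -1, :]  # keep only the next-token logits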
|
|
|
|
|
|
|
|
def export_s2mel_model():
    """
    Export the Semantic-to-Mel model (flow matching).

    This converts semantic codes to mel spectrograms.
    Input: semantic_codes [batch, code_len], speaker_embedding [batch, 192]
    Output: mel_spectrogram [batch, 80, mel_len]
    """
    print("\n" + "=" * 50)
    print("Exporting S2Mel Model (Flow Matching)")
    print("=" * 50)

    try:
        from omegaconf import OmegaConf

        cfg = OmegaConf.load("checkpoints/config.yaml")

        print("S2Mel model (Diffusion/Flow Matching) is also complex:")
        print(" - Multiple denoising steps (iterative)")
        print(" - CFM (Conditional Flow Matching) requires ODE solving")
        print()
        print("Export strategy:")
        print(" 1. Export the single denoising step")
        print(" 2. Run the iteration loop in Rust (see the sketch after this function)")
        print()

        return False
    except Exception as e:
        print(f"✗ Failed to export S2Mel: {e}")
        import traceback
        traceback.print_exc()
        return False
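

# Hedged sketch of the strategy above, for reference only: if the s2mel CFM
# network can be exported as a single velocity-field evaluation v(x_t, t, cond),
# the caller integrates it with a fixed-step Euler loop. `estimator` and its
# signature are HYPOTHETICAL placeholders, not the real s2mel API; the Rust
# side would mirror this loop with one ONNX session.run() per step.
def euler_cfm_reference(estimator, x, cond, n_steps=10):
    """Euler integration from t=0 to t=1 of dx/dt = v(x, t, cond)."""
    dt = 1.0 / n_steps
    t = torch.zeros(x.shape[0])
    for _ in range(n_steps):
        v = estimator(x, t, cond)  # the exported graph evaluates one step
        x = x + dt * v
        t = t + dt
    return x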
|
|
|
|
|
|
|
|
def export_bigvgan():
    """
    Export BigVGAN vocoder to ONNX.

    Good news: NVIDIA provides pre-trained BigVGAN models!
    Even better: They're designed for easy ONNX export.

    Input: mel_spectrogram [batch, 80, mel_len]
    Output: waveform [batch, 1, wave_len]
    """
    print("\n" + "=" * 50)
    print("Exporting BigVGAN Vocoder")
    print("=" * 50)

    try:
        print("BigVGAN options:")
        print(" 1. Use NVIDIA's pre-exported ONNX (recommended)")
        print("    https://github.com/NVIDIA/BigVGAN")
        print()
        print(" 2. Export from PyTorch weights (we'll do this)")
        print()

        try:
            from bigvgan import bigvgan

            model = bigvgan.BigVGAN.from_pretrained(
                'nvidia/bigvgan_v2_22khz_80band_256x',
                use_cuda_kernel=False
            )
            model.eval()
            # Weight norm must be removed before inference/export
            model.remove_weight_norm()

            print("BigVGAN loaded from HuggingFace")
            print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

            # 100 frames of 80-band mel input for tracing
            dummy_mel = torch.randn(1, 80, 100)

            output_path = output_dir / "bigvgan.onnx"
            torch.onnx.export(
                model,
                dummy_mel,
                str(output_path),
                input_names=['mel_spectrogram'],
                output_names=['waveform'],
                dynamic_axes={
                    'mel_spectrogram': {0: 'batch', 2: 'mel_length'},
                    'waveform': {0: 'batch', 2: 'wave_length'}
                },
                opset_version=18,
                do_constant_folding=True,
            )

print(f"✓ Exported: {output_path}") |
|
|
print(f" Input: mel_spectrogram [batch, 80, mel_len]") |
|
|
print(f" Output: waveform [batch, 1, wave_len]") |
|
|
|
|
|
|
|
|
import onnx |
|
|
onnx_model = onnx.load(str(output_path)) |
|
|
onnx.checker.check_model(onnx_model) |
|
|
print(f"✓ ONNX model verified!") |
|
|
|
|
|
return True |
|
|
|
|
|
        except ImportError:
            print("bigvgan package not installed, installing...")
            os.system(f"{sys.executable} -m pip install bigvgan")
            print("Please re-run the script.")
            return False
    except Exception as e:
        print(f"✗ Failed to export BigVGAN: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    print("\nStarting ONNX conversion...\n")

    results = {}

    # Run every export; each one catches its own errors, so a failure in one
    # model doesn't block the others
    results['speaker_encoder'] = export_speaker_encoder()
    results['gpt'] = export_gpt_model()
    results['s2mel'] = export_s2mel_model()
    results['bigvgan'] = export_bigvgan()

print("\n" + "=" * 70) |
|
|
print(" CONVERSION SUMMARY") |
|
|
print("=" * 70) |
|
|
|
|
|
for name, success in results.items(): |
|
|
status = "✓ SUCCESS" if success else "✗ NEEDS WORK" |
|
|
print(f" {name:20} {status}") |
|
|
|
|
|
print() |
|
|
|
|
|
if all(results.values()): |
|
|
print("All models converted! Ready for Rust inference.") |
|
|
else: |
|
|
print("Some models need manual intervention.") |
|
|
print() |
|
|
print("For complex models (GPT, S2Mel), consider:") |
|
|
print(" 1. Modifying the Python code to remove dynamic control flow") |
|
|
print(" 2. Using torch.jit.trace with concrete inputs") |
|
|
print(" 3. Exporting subcomponents separately") |
|
|
print(" 4. Using ONNX Runtime's transformer optimizations") |
|
|
|
|
|
print() |
|
|
print("Output directory:", output_dir.absolute()) |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    main()