Upload folder using huggingface_hub
- __init__.py +20 -0
- assets.py +65 -0
- audio/__init__.py +13 -0
- audio/codec.py +58 -0
- audio/grid.py +79 -0
- cli.py +122 -0
- config.py +180 -0
- core/__init__.py +10 -0
- core/cache.py +106 -0
- core/depformer.py +264 -0
- core/layers.py +209 -0
- core/model.py +72 -0
- core/precision.py +23 -0
- core/transformer.py +140 -0
- engine.py +230 -0
- generation.py +158 -0
- runtime/__init__.py +7 -0
- runtime/audio_io.py +69 -0
- runtime/context.py +138 -0
- runtime/generator.py +420 -0
- runtime/guidance.py +38 -0
- runtime/logger.py +33 -0
- runtime/sampler.py +37 -0
- runtime/script_parser.py +69 -0
- runtime/state_machine.py +170 -0
- runtime/voice_clone.py +190 -0
__init__.py
ADDED
@@ -0,0 +1,20 @@
from .config import DiaConfig, load_config
from .core.model import Dia2Model
from .engine import Dia2
from .generation import (
    GenerationConfig,
    GenerationResult,
    PrefixConfig,
    SamplingConfig,
)

__all__ = [
    "DiaConfig",
    "Dia2Model",
    "load_config",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    "SamplingConfig",
    "Dia2",
]
assets.py
ADDED
@@ -0,0 +1,65 @@
from __future__ import annotations

import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download

ASSET_MANIFEST = os.environ.get("DIA2_ASSET_MANIFEST", "dia2_assets.json")


@dataclass(frozen=True)
class AssetBundle:
    config_path: str
    weights_path: str
    tokenizer_id: Optional[str]
    mimi_id: Optional[str]
    repo_id: Optional[str]


def resolve_assets(
    *,
    repo: Optional[str],
    config_path: Optional[str | Path],
    weights_path: Optional[str | Path],
    manifest_name: Optional[str] = None,
) -> AssetBundle:
    repo_id = repo
    manifest_name = manifest_name or ASSET_MANIFEST
    if repo_id and (config_path or weights_path):
        raise ValueError("Provide either repo or config+weights, not both")
    if config_path is None or weights_path is None:
        if repo_id is None:
            raise ValueError("Must specify repo or config+weights")
        manifest = load_manifest(repo_id, manifest_name)
        config_name = manifest.get("config", "config.json")
        weights_name = manifest.get("weights", "model.safetensors")
        config_local = hf_hub_download(repo_id, config_name)
        weights_local = hf_hub_download(repo_id, weights_name)
        return AssetBundle(
            config_path=config_local,
            weights_path=weights_local,
            tokenizer_id=manifest.get("tokenizer") or repo_id,
            mimi_id=manifest.get("mimi"),
            repo_id=repo_id,
        )
    return AssetBundle(str(config_path), str(weights_path), None, None, repo_id)


def load_manifest(repo_id: str, manifest_name: str) -> dict:
    if not manifest_name:
        return {}
    try:
        path = hf_hub_download(repo_id, manifest_name)
    except Exception:
        return {}
    try:
        return json.loads(Path(path).read_text())
    except json.JSONDecodeError:
        return {}


__all__ = ["AssetBundle", "ASSET_MANIFEST", "resolve_assets", "load_manifest"]
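Usage sketch (illustrative, not part of this commit), assuming the folder is importable as a package named dia2; the Hub path needs network access, and the repo id is the example from the CLI help further down:

    from dia2.assets import resolve_assets

    # Hub-backed resolution: the manifest (if present) names the config and
    # weights files; otherwise config.json / model.safetensors are assumed.
    bundle = resolve_assets(
        repo="nari-labs/Dia2-2B",
        config_path=None,
        weights_path=None,
    )
    print(bundle.config_path, bundle.weights_path, bundle.tokenizer_id)

    # Local resolution: no Hub access; tokenizer_id and mimi_id stay None.
    local = resolve_assets(repo=None, config_path="config.json",
                           weights_path="model.safetensors")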
audio/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .codec import MimiCodec, DEFAULT_MIMI_MODEL_ID, MimiConfig
from .grid import delay_frames, undelay_frames, mask_audio_logits, fill_audio_channels, write_wav

__all__ = [
    "MimiCodec",
    "DEFAULT_MIMI_MODEL_ID",
    "MimiConfig",
    "delay_frames",
    "undelay_frames",
    "mask_audio_logits",
    "fill_audio_channels",
    "write_wav",
]
audio/codec.py
ADDED
@@ -0,0 +1,58 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import MimiModel


DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"


@dataclass(frozen=True)
class MimiConfig:
    model_id: str = DEFAULT_MIMI_MODEL_ID
    dtype: Optional[torch.dtype] = None


class MimiCodec(nn.Module):
    """Thin wrapper around transformers' MimiModel for decoding audio tokens."""

    def __init__(self, model: MimiModel, device: torch.device) -> None:
        super().__init__()
        self.model = model
        self.device = device
        cfg = getattr(model, "config", None)
        self.sample_rate = getattr(cfg, "sampling_rate", 24000)
        self.frame_rate = getattr(cfg, "frame_rate", 12.5)
        self.samples_per_frame = int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = DEFAULT_MIMI_MODEL_ID,
        *,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> "MimiCodec":
        model = MimiModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model = model.to(device)
        model.eval()
        return cls(model, device)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        codes = codes.to(self.device)
        with torch.inference_mode():
            audio, _ = self.model.decode(codes, return_dict=False)
        return torch.clamp(audio, -1.0, 1.0)

    def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
        audio = audio.to(self.device)
        with torch.inference_mode():
            return self.model.encode(audio, return_dict=return_dict)
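Decode-side usage sketch (illustrative, not part of this commit); the 32-codebook, 2048-entry code shape is an assumption about the kyutai/mimi checkpoint, and the dia2 import path is assumed:

    import torch
    from dia2.audio.codec import MimiCodec

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    codec = MimiCodec.from_pretrained(device=device)  # downloads kyutai/mimi

    # Mimi expects codes shaped (batch, num_quantizers, frames).
    codes = torch.randint(0, 2048, (1, 32, 25), device=device)
    audio = codec.decode(codes)  # waveform clamped to [-1, 1]
    print(audio.shape, codec.sample_rate, codec.samples_per_frame)  # 24000 Hz, 1920 samples/frame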
audio/grid.py
ADDED
@@ -0,0 +1,79 @@
from __future__ import annotations

import wave
from pathlib import Path
from typing import Sequence

import numpy as np
import torch


def delay_frames(aligned: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
    channels, total = aligned.shape
    max_delay = max(delays) if delays else 0
    out = aligned.new_full((channels, total + max_delay), pad_id)
    for idx, delay in enumerate(delays):
        out[idx, delay : delay + total] = aligned[idx]
    return out


def undelay_frames(delayed: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
    channels, total = delayed.shape
    max_delay = max(delays) if delays else 0
    target = max(0, total - max_delay)
    out = delayed.new_full((channels, target), pad_id)
    for idx, delay in enumerate(delays):
        out[idx] = delayed[idx, delay : delay + target]
    return out


def mask_audio_logits(logits: torch.Tensor, pad_idx: int, bos_idx: int) -> torch.Tensor:
    if logits.shape[-1] == 0:
        return logits
    max_idx = logits.shape[-1] - 1
    targets = [idx for idx in (pad_idx, bos_idx) if 0 <= idx <= max_idx]
    if not targets:
        return logits
    masked = logits.clone()
    neg_inf = torch.finfo(masked.dtype).min
    for idx in targets:
        masked[..., idx] = neg_inf
    return masked


def fill_audio_channels(
    delays: Sequence[int],
    constants,
    step: int,
    step_tokens: torch.Tensor,
    audio_buf: torch.Tensor,
) -> None:
    for cb, delay in enumerate(delays):
        idx = step - delay
        in_bounds = idx >= 0 and step < audio_buf.shape[-1]
        if in_bounds:
            step_tokens[:, 2 + cb, 0] = audio_buf[:, cb, step]
        else:
            step_tokens[:, 2 + cb, 0] = constants.audio_bos


def write_wav(path: str | Path, audio: np.ndarray, sample_rate: int) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    audio = np.clip(audio, -1.0, 1.0)
    pcm16 = (audio * 32767.0).astype(np.int16)

    with wave.open(str(path), "wb") as handle:
        handle.setnchannels(1)
        handle.setsampwidth(2)
        handle.setframerate(sample_rate)
        handle.writeframes(pcm16.tobytes())


__all__ = [
    "delay_frames",
    "undelay_frames",
    "mask_audio_logits",
    "fill_audio_channels",
    "write_wav",
]
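Worked round trip through the delay pattern (illustrative, not part of this commit; dia2 import path assumed):

    import torch
    from dia2.audio.grid import delay_frames, undelay_frames

    pad = -1
    aligned = torch.arange(8).reshape(2, 4)   # 2 channels, 4 frames
    delays = [0, 2]

    delayed = delay_frames(aligned, delays, pad)
    # channel 0: [0, 1, 2, 3, -1, -1]
    # channel 1: [-1, -1, 4, 5, 6, 7]  (shifted right by its delay of 2)
    restored = undelay_frames(delayed, delays, pad)
    assert torch.equal(restored, aligned)     # undelay inverts delay exactly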
cli.py
ADDED
@@ -0,0 +1,122 @@
from __future__ import annotations

import argparse

import torch

from .engine import Dia2
from .generation import (
    build_generation_config,
    load_script_text,
    validate_generation_params,
)


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate audio with Dia2")
    parser.add_argument("--config", help="Path to config.json (overrides repo lookup)")
    parser.add_argument(
        "--weights", help="Path to model.safetensors (overrides repo lookup)"
    )
    parser.add_argument(
        "--hf",
        required=False,
        help="Hugging Face repo id to download config/weights from (e.g. nari-labs/Dia2-2B)",
    )
    parser.add_argument(
        "--input", default="input.txt", help="Script text file (default: input.txt)"
    )
    parser.add_argument("output", help="Output WAV path")
    parser.add_argument(
        "--device",
        default=None,
        help="Computation device (defaults to cuda if available, else cpu)",
    )
    parser.add_argument(
        "--dtype",
        choices=["auto", "float32", "bfloat16"],
        default="bfloat16",
        help="Computation dtype (default: bfloat16)",
    )
    parser.add_argument("--topk", type=int, default=50)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--cfg", type=float, default=1.0)
    parser.add_argument("--tokenizer", help="Tokenizer repo or local path override")
    parser.add_argument(
        "--mimi", help="Mimi repo id override (defaults to config/assets)"
    )
    parser.add_argument("--prefix-speaker-1", help="Prefix audio file for speaker 1")
    parser.add_argument("--prefix-speaker-2", help="Prefix audio file for speaker 2")
    parser.add_argument(
        "--include-prefix",
        action="store_true",
        help="Keep prefix audio in the final waveform (default: trimmed)",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print generation progress logs"
    )
    parser.add_argument(
        "--cuda-graph",
        action="store_true",
        help="Run generation with CUDA graph capture",
    )
    args = parser.parse_args()

    device = args.device
    if device is None or device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = args.dtype or "bfloat16"

    repo = args.hf
    if repo:
        dia = Dia2(
            repo=repo,
            device=device,
            dtype=dtype,
            tokenizer_id=args.tokenizer,
            mimi_id=args.mimi,
        )
    elif args.config and args.weights:
        dia = Dia2.from_local(
            config_path=args.config,
            weights_path=args.weights,
            device=device,
            dtype=dtype,
            tokenizer_id=args.tokenizer,
            mimi_id=args.mimi,
        )
    else:
        raise ValueError("Provide --hf or both --config and --weights")

    script = load_script_text(args.input)
    temperature, top_k, cfg_scale = validate_generation_params(
        temperature=args.temperature,
        top_k=args.topk,
        cfg_scale=args.cfg,
    )
    config = build_generation_config(
        temperature=temperature,
        top_k=top_k,
        cfg_scale=cfg_scale,
    )
    overrides = {}
    if args.cuda_graph:
        overrides["use_cuda_graph"] = True
    if args.prefix_speaker_1:
        overrides["prefix_speaker_1"] = args.prefix_speaker_1
    if args.prefix_speaker_2:
        overrides["prefix_speaker_2"] = args.prefix_speaker_2
    if args.include_prefix:
        overrides["include_prefix"] = True

    dia.generate(
        script,
        config=config,
        output_wav=args.output,
        verbose=args.verbose,
        **overrides,
    )


if __name__ == "__main__":
    main()
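Example invocation and its programmatic equivalent (a sketch, not part of this commit; it assumes the folder is installed as a package named dia2, and mirrors the calls main() makes into engine.py and generation.py of this same commit):

    # Shell equivalent:
    #   python -m dia2.cli output.wav --hf nari-labs/Dia2-2B --input input.txt --verbose
    from dia2.engine import Dia2
    from dia2.generation import build_generation_config, load_script_text

    dia = Dia2(repo="nari-labs/Dia2-2B", device="cuda", dtype="bfloat16",
               tokenizer_id=None, mimi_id=None)
    script = load_script_text("input.txt")
    config = build_generation_config(temperature=0.8, top_k=50, cfg_scale=1.0)
    dia.generate(script, config=config, output_wav="output.wav", verbose=True)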
config.py
ADDED
@@ -0,0 +1,180 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional


@dataclass(frozen=True)
class DataConfig:
    channels: int
    text_vocab_size: int
    audio_vocab_size: int
    action_vocab_size: int
    text_pad_token_id: int
    text_new_word_token_id: int
    text_zero_token_id: int
    audio_pad_token_id: int
    audio_bos_token_id: int
    action_pad_token_id: int
    action_new_word_token_id: int
    delay_pattern: List[int]
    first_word_min_start: int
    max_pad: int
    second_stream_ahead: int
    tokenizer_path: Optional[str] = None


@dataclass(frozen=True)
class DecoderConfig:
    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int
    dropout: float
    low_rank_dim: int | None = None


@dataclass(frozen=True)
class DepformerConfig:
    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int
    apply_rope: bool
    text_embedding: bool
    mlp_activations: List[str]


@dataclass(frozen=True)
class LinearHeadConfig:
    mlp_activations: List[str]


@dataclass(frozen=True)
class ModelConfig:
    decoder: DecoderConfig
    depformer: DepformerConfig
    linear: LinearHeadConfig
    dropout: float
    rope_min_timescale: int
    rope_max_timescale: int
    normalization_layer_epsilon: float


@dataclass(frozen=True)
class RuntimeConfig:
    weights_schedule: List[int]
    max_context_steps: int


@dataclass(frozen=True)
class AssetsConfig:
    tokenizer: Optional[str]
    mimi: Optional[str]


@dataclass(frozen=True)
class DiaConfig:
    data: DataConfig
    model: ModelConfig
    runtime: RuntimeConfig
    assets: AssetsConfig


def _resolve_runtime(block: dict | None, data_cfg: DataConfig) -> RuntimeConfig:
    block = block or {}
    weights_schedule = block.get("weights_schedule")
    if weights_schedule is None:
        audio_channels = max(0, data_cfg.channels - 2)
        weights_schedule = list(range(max(audio_channels - 1, 0)))
    max_context = block.get("max_context_steps", 1500)
    return RuntimeConfig(
        weights_schedule=list(weights_schedule),
        max_context_steps=int(max_context),
    )


def load_config(path: str | Path) -> DiaConfig:
    cfg = json.loads(Path(path).read_text())
    data = cfg["data"]
    model = cfg["model"]
    runtime_cfg_raw = cfg.get("runtime")
    if runtime_cfg_raw is None:
        raise ValueError(f"Config '{path}' is missing a runtime block")

    decoder_cfg = DecoderConfig(
        n_layer=model["decoder"]["n_layer"],
        n_embd=model["decoder"]["n_embd"],
        n_hidden=model["decoder"]["n_hidden"],
        gqa_query_heads=model["decoder"]["gqa_query_heads"],
        kv_heads=model["decoder"]["kv_heads"],
        gqa_head_dim=model["decoder"]["gqa_head_dim"],
        dropout=model.get("dropout", 0.0),
        low_rank_dim=model["decoder"].get("low_rank_dim"),
    )

    depformer_cfg = DepformerConfig(
        n_layer=model["depformer"]["n_layer"],
        n_embd=model["depformer"]["n_embd"],
        n_hidden=model["depformer"]["n_hidden"],
        gqa_query_heads=model["depformer"]["gqa_query_heads"],
        kv_heads=model["depformer"]["kv_heads"],
        gqa_head_dim=model["depformer"]["gqa_head_dim"],
        apply_rope=model["depformer"].get("apply_rope", True),
        text_embedding=model["depformer"].get("text_embedding", True),
        mlp_activations=model["depformer"].get("mlp_activations", ["silu", "linear"]),
    )

    data_cfg = DataConfig(
        channels=data["channels"],
        text_vocab_size=data["text_vocab_size"],
        audio_vocab_size=data["audio_vocab_size"],
        action_vocab_size=data["action_vocab_size"],
        text_pad_token_id=data["text_pad_token_id"],
        text_new_word_token_id=data["text_new_word_token_id"],
        text_zero_token_id=data.get("text_zero_token_id", 7),
        audio_pad_token_id=data.get("audio_pad_token_id", data["audio_vocab_size"] - 1),
        audio_bos_token_id=data.get("audio_bos_token_id", data["audio_vocab_size"] - 2),
        action_pad_token_id=data["action_pad_token_id"],
        action_new_word_token_id=data["action_new_word_token_id"],
        delay_pattern=list(data.get("delay_pattern", [])),
        first_word_min_start=data.get("first_word_min_start", 0),
        max_pad=data.get("max_pad", 0),
        second_stream_ahead=data.get("second_stream_ahead", 0),
        tokenizer_path=data.get("tokenizer_path"),
    )

    runtime_cfg = _resolve_runtime(runtime_cfg_raw, data_cfg)

    linear_cfg = LinearHeadConfig(
        mlp_activations=model.get("linear", {}).get("mlp_activations", ["silu", "linear"]),
    )

    model_cfg = ModelConfig(
        decoder=decoder_cfg,
        depformer=depformer_cfg,
        linear=linear_cfg,
        dropout=model.get("dropout", 0.0),
        rope_min_timescale=model.get("rope_min_timescale", 1),
        rope_max_timescale=model.get("rope_max_timescale", 10000),
        normalization_layer_epsilon=model.get("normalization_layer_epsilon", 1e-5),
    )

    assets_raw = cfg.get("assets") or {}
    assets_cfg = AssetsConfig(
        tokenizer=assets_raw.get("tokenizer") or data_cfg.tokenizer_path,
        mimi=assets_raw.get("mimi"),
    )

    return DiaConfig(
        data=data_cfg,
        model=model_cfg,
        runtime=runtime_cfg,
        assets=assets_cfg,
    )
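A minimal config sketch satisfying load_config's required keys (all field values are placeholders; only the key layout matters, and an empty "runtime" block falls back to the defaults in _resolve_runtime):

    import json
    from pathlib import Path

    minimal = {
        "data": {
            "channels": 10,
            "text_vocab_size": 256,
            "audio_vocab_size": 2050,
            "action_vocab_size": 4,
            "text_pad_token_id": 0,
            "text_new_word_token_id": 1,
            "action_pad_token_id": 0,
            "action_new_word_token_id": 1,
        },
        "model": {
            "decoder": {"n_layer": 2, "n_embd": 64, "n_hidden": 128,
                        "gqa_query_heads": 4, "kv_heads": 1, "gqa_head_dim": 16},
            "depformer": {"n_layer": 1, "n_embd": 64, "n_hidden": 128,
                          "gqa_query_heads": 4, "kv_heads": 1, "gqa_head_dim": 16},
        },
        "runtime": {},  # omitting this key entirely raises in load_config
    }
    Path("config.json").write_text(json.dumps(minimal, indent=2))

    # Assuming the folder is importable as the package "dia2":
    # from dia2.config import load_config
    # cfg = load_config("config.json")   # cfg.runtime.max_context_steps == 1500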
core/__init__.py
ADDED
@@ -0,0 +1,10 @@
from .model import Dia2Model, DecodeState
from .transformer import TransformerDecoder
from .depformer import Depformer

__all__ = [
    "Dia2Model",
    "DecodeState",
    "TransformerDecoder",
    "Depformer",
]
core/cache.py
ADDED
@@ -0,0 +1,106 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class CacheSlot:
    keys: torch.Tensor
    values: torch.Tensor

    def __post_init__(self) -> None:
        self.max_steps = self.keys.shape[2]
        self.head_dim = self.keys.shape[3]
        self.flat_heads = self.keys.shape[0] * self.keys.shape[1]
        device = self.keys.device
        self.length = torch.zeros((), dtype=torch.long, device=device)
        self.positions = torch.arange(self.max_steps, dtype=torch.long, device=device)

    @classmethod
    def allocate(
        cls,
        *,
        batch_size: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "CacheSlot":
        keys = torch.zeros(batch_size, heads, max_steps, head_dim, device=device, dtype=dtype)
        values = torch.zeros_like(keys)
        return cls(keys, values)

    def reset(self) -> None:
        self.length.zero_()

    def write_and_view(
        self,
        key_chunk: torch.Tensor,
        value_chunk: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        step = key_chunk.shape[2]
        start = self.length
        indices = self.positions[:step] + start
        expanded = indices.unsqueeze(0).expand(self.flat_heads, -1)

        flat_keys = self.keys.view(self.flat_heads, self.max_steps, self.head_dim)
        flat_values = self.values.view(self.flat_heads, self.max_steps, self.head_dim)
        flat_key_chunk = key_chunk.reshape(self.flat_heads, step, self.head_dim)
        flat_value_chunk = value_chunk.reshape(self.flat_heads, step, self.head_dim)
        scatter_index = expanded.unsqueeze(-1).expand_as(flat_key_chunk)
        flat_keys.scatter_(1, scatter_index, flat_key_chunk)
        flat_values.scatter_(1, scatter_index, flat_value_chunk)

        self.length.add_(step)
        bool_mask = (self.positions >= self.length).view(1, 1, 1, self.max_steps)
        mask_dtype = self.keys.dtype
        mask_value = torch.finfo(mask_dtype).min
        attn_mask = torch.zeros_like(bool_mask, dtype=mask_dtype)
        attn_mask = attn_mask.masked_fill(bool_mask, mask_value)
        return self.keys, self.values, attn_mask


class KVCache:
    def __init__(self, slots: List[CacheSlot]) -> None:
        self.slots = slots

    @classmethod
    def allocate(
        cls,
        *,
        num_layers: int,
        batch_size: int,
        heads: int,
        max_steps: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "KVCache":
        slots = [
            CacheSlot.allocate(
                batch_size=batch_size,
                heads=heads,
                max_steps=max_steps,
                head_dim=head_dim,
                device=device,
                dtype=dtype,
            )
            for _ in range(num_layers)
        ]
        return cls(slots)

    def get_slot(self, index: int) -> CacheSlot:
        return self.slots[index]

    def reset(self) -> None:
        for slot in self.slots:
            slot.reset()

    clear = reset


__all__ = ["CacheSlot", "KVCache"]
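A small sanity-check sketch for the slot mechanics above (illustrative, not part of this commit; dia2 import path assumed):

    import torch
    from dia2.core.cache import CacheSlot

    slot = CacheSlot.allocate(batch_size=1, heads=2, max_steps=4, head_dim=8,
                              device=torch.device("cpu"), dtype=torch.float32)

    # One decoding step: chunks are (batch, heads, step, head_dim).
    k = torch.randn(1, 2, 1, 8)
    v = torch.randn(1, 2, 1, 8)
    keys, values, attn_mask = slot.write_and_view(k, v)

    assert int(slot.length) == 1
    assert attn_mask.shape == (1, 1, 1, 4)
    assert attn_mask[..., 0].item() == 0.0  # written position is attendable
    assert attn_mask[..., 1].item() < 0     # unwritten tail masked to dtype min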
core/depformer.py
ADDED
@@ -0,0 +1,264 @@
from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig
from .cache import KVCache
from .layers import MultiStreamEmbedding, Mlp, RotaryEmbedding
from .precision import Precision


class ScheduleAttention(nn.Module):
    """Depformer attention that mirrors dia_v2 ScheduleAttention."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dep_cfg = config.model.depformer
        runtime = config.runtime
        self.schedule = runtime.weights_schedule
        self.num_query_heads = dep_cfg.gqa_query_heads
        self.num_kv_heads = dep_cfg.kv_heads
        self.head_dim = dep_cfg.gqa_head_dim
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.apply_rope = dep_cfg.apply_rope
        self.used_ids = sorted(set(self.schedule))
        self.compute_dtype = compute_dtype

        self.in_proj = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    dep_cfg.n_embd,
                    3 * self.num_query_heads * self.head_dim,
                    bias=False,
                )
                for i in self.used_ids
            }
        )
        self.out_proj = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    self.num_query_heads * self.head_dim,
                    dep_cfg.n_embd,
                    bias=False,
                )
                for i in self.used_ids
            }
        )
        eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)

        if self.apply_rope:
            self.rotary = RotaryEmbedding(
                self.head_dim,
                config.model.rope_min_timescale,
                config.model.rope_max_timescale,
            )
            stage_count = max(len(self.schedule), 1)
            self.register_buffer(
                "stage_positions",
                torch.arange(stage_count, dtype=torch.long).view(stage_count, 1),
                persistent=False,
            )
        else:
            self.rotary = None
            self.register_buffer(
                "stage_positions",
                torch.zeros(0, 1, dtype=torch.long),
                persistent=False,
            )

    def forward_incremental(
        self,
        x_t: torch.Tensor,
        stage_index: int,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        bsz, seq, _ = x_t.shape
        if seq != 1:
            raise ValueError("ScheduleAttention expects seq len 1 during decoding")
        orig_dtype = x_t.dtype
        module_index = self.schedule[stage_index]
        proj = self.in_proj[str(module_index)](x_t.to(torch.float32))
        proj = proj.view(bsz, seq, 3, self.num_query_heads, self.head_dim).to(self.compute_dtype)

        q_proj = self.q_norm(proj[:, :, 0])
        k_proj = self.k_norm(proj[:, :, 1])
        v_proj = proj[:, :, 2]

        if self.apply_rope:
            pos_ids = self.stage_positions[stage_index : stage_index + 1]
            if pos_ids.device != x_t.device:
                pos_ids = pos_ids.to(x_t.device)
            q_proj = self.rotary(q_proj, pos_ids)
            k_proj = self.rotary(k_proj, pos_ids)

        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)

        if cache_slot is not None:
            k, v, attn_mask = cache_slot.write_and_view(k, v)
        else:
            attn_mask = None

        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(bsz, seq, self.num_query_heads * self.head_dim)
        out = self.out_proj[str(module_index)](flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot


class DepformerLayer(nn.Module):
    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        dep_cfg = config.model.depformer
        eps = config.model.normalization_layer_epsilon
        self.pre_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
        self.post_norm = nn.RMSNorm(dep_cfg.n_embd, eps=eps, dtype=torch.float32)
        self.self_attention = ScheduleAttention(config, compute_dtype)
        self.mlp = Mlp(
            dep_cfg.n_embd,
            dep_cfg.n_hidden,
            compute_dtype,
            tuple(config.model.depformer.mlp_activations),
        )

    def decode_step(
        self,
        x_t: torch.Tensor,
        stage_index: int,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        residual = x_t
        x_norm = self.pre_norm(x_t)
        sa_out, _ = self.self_attention.forward_incremental(x_norm, stage_index, cache_slot)
        x = residual + sa_out
        residual2 = x
        x_norm2 = self.post_norm(x)
        mlp_out = self.mlp(x_norm2)
        return residual2 + mlp_out, cache_slot


class Depformer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        dep_cfg = config.model.depformer
        data_cfg = config.data
        runtime = config.runtime

        self.num_audio_channels = max(0, data_cfg.channels - 2)
        self.num_depth = max(self.num_audio_channels - 1, 0)
        self.weights_schedule = runtime.weights_schedule

        self.audio_embeds = nn.ModuleList(
            [nn.Embedding(data_cfg.audio_vocab_size, dep_cfg.n_embd) for _ in range(self.num_depth)]
        )
        if dep_cfg.text_embedding:
            self.text_embed = MultiStreamEmbedding(
                data_cfg.text_vocab_size,
                dep_cfg.n_embd,
                pad_id=data_cfg.text_pad_token_id,
                output_dtype=precision.compute,
            )
        else:
            self.text_embed = None

        used_ids = sorted(set(self.weights_schedule))
        self.depformer_in = nn.ModuleDict(
            {
                str(i): nn.Linear(
                    config.model.decoder.n_embd,
                    dep_cfg.n_embd,
                    bias=False,
                )
                for i in used_ids
            }
        )

        self.layers = nn.ModuleList([DepformerLayer(config, precision.compute) for _ in range(dep_cfg.n_layer)])
        self.norm = nn.RMSNorm(dep_cfg.n_embd, eps=config.model.normalization_layer_epsilon)
        self.logits_dtype = precision.logits
        self.logits = nn.ModuleList(
            [
                nn.Linear(dep_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)
                for _ in range(self.num_depth)
            ]
        )
        self.audio_vocab_limit = min(data_cfg.audio_pad_token_id, data_cfg.audio_bos_token_id)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        heads = self.layers[0].self_attention.num_kv_heads
        head_dim = self.layers[0].self_attention.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        prev_audio: torch.Tensor,
        transformer_out: torch.Tensor,
        stage_index: int,
        cache: KVCache,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, KVCache]:
        self._validate_inputs(stage_index, cache)
        return self._forward_stage(stage_index, prev_audio, transformer_out, cache, main_text, second_text)

    def _forward_stage(
        self,
        stage_index: int,
        prev_audio: torch.Tensor,
        transformer_out: torch.Tensor,
        cache: KVCache,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, KVCache]:
        prev_audio = prev_audio.long()
        weight_idx = self.weights_schedule[stage_index]
        token_emb = self.audio_embeds[stage_index](prev_audio[:, None]).to(self.precision.compute)
        if stage_index == 0 and self.text_embed is not None:
            if main_text is None or second_text is None:
                raise ValueError("stage 0 requires text tokens")
            token_emb = token_emb + self.text_embed(main_text[:, None], second_text[:, None])

        dep_in = self.depformer_in[str(weight_idx)](transformer_out.to(torch.float32))
        dep_in = dep_in.to(self.precision.compute)
        dep_in = dep_in + token_emb.to(dep_in.dtype)
        x = dep_in
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, stage_index, slot)

        hidden = self.norm(x)
        logits = self.logits[stage_index](hidden.to(torch.float32))
        logits = logits.to(self.logits_dtype)
        logits = logits.unsqueeze(1)
        logits = logits[..., : self.audio_vocab_limit]
        return logits, cache

    def _validate_inputs(self, stage_index: int, cache: KVCache | None) -> None:
        if stage_index < 0 or stage_index >= self.num_depth:
            raise ValueError(f"stage_index {stage_index} out of range (depth={self.num_depth})")
        if cache is None:
            raise ValueError("depformer cache must be initialized")
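A tiny sketch of the weight-sharing indexing above (schedule values here are illustrative, not from the shipped configs): each depth stage looks up depformer_in / in_proj / out_proj by str(weights_schedule[stage]), so repeated ids mean several stages share one set of projection weights, while audio_embeds and logits stay per-stage.

    weights_schedule = [0, 1, 1, 2, 2, 2, 3]       # illustrative
    used_ids = sorted(set(weights_schedule))        # -> [0, 1, 2, 3]: one module each
    for stage_index, weight_idx in enumerate(weights_schedule):
        print(f"stage {stage_index} -> shared weight key {str(weight_idx)!r}")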
core/layers.py
ADDED
@@ -0,0 +1,209 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union, List

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig


class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, min_timescale: int, max_timescale: int):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("RoPE dimension must be even")
        half_dim = head_dim // 2
        fraction = (2.0 * torch.arange(0, half_dim)) / head_dim
        timescale = min_timescale * (max_timescale / min_timescale) ** fraction
        inv_freq = 1.0 / timescale
        self.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        pos = position_ids.to(self.inv_freq.dtype)
        freqs = torch.einsum("...i,j->...ij", pos, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        while emb.dim() < x.dim():
            emb = emb.unsqueeze(-2)
        cos = emb.cos().to(x.dtype)
        sin = emb.sin().to(x.dtype)
        x1, x2 = torch.chunk(x, 2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).reshape_as(x)


def _get_activation(name: str) -> nn.Module:
    name = name.lower()
    if name in ("silu", "swish", "swiglu"):
        return nn.SiLU()
    if name in ("gelu", "geglu"):
        return nn.GELU()
    if name == "relu":
        return nn.ReLU()
    if name == "linear":
        return nn.Identity()
    raise ValueError(f"Unsupported activation {name}")


@dataclass
class AttentionShape:
    dim: int
    heads: int
    kv_heads: int
    head_dim: int
    rope_min: int
    rope_max: int
    apply_rope: bool


class Attention(nn.Module):
    """Byte-for-byte port of dia_v2 Attention.forward_incremental."""

    def __init__(self, config: DiaConfig, dim: int, compute_dtype: torch.dtype) -> None:
        super().__init__()
        dec = config.model.decoder
        self.num_query_heads = dec.gqa_query_heads
        self.num_kv_heads = dec.kv_heads
        self.head_dim = dec.gqa_head_dim
        self.num_gqa_groups = self.num_query_heads // max(self.num_kv_heads, 1)
        self.compute_dtype = compute_dtype
        self.q_proj = nn.Linear(dim, self.num_query_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_query_heads * self.head_dim, dim, bias=False)
        eps = config.model.normalization_layer_epsilon
        self.q_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=eps, dtype=torch.float32)
        self.rotary = RotaryEmbedding(
            self.head_dim,
            config.model.rope_min_timescale,
            config.model.rope_max_timescale,
        )

    def forward_incremental(
        self,
        x: torch.Tensor,
        pos: Optional[torch.Tensor],
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        B, T, _ = x.shape
        if T != 1:
            raise ValueError("Attention expects sequence length 1 during decoding")
        orig_dtype = x.dtype
        q_proj = self._project_heads(self.q_proj, x, self.num_query_heads)
        k_proj = self._project_heads(self.k_proj, x, self.num_kv_heads)
        v_proj = self._project_heads(self.v_proj, x, self.num_kv_heads)
        q_proj = self.q_norm(q_proj)
        k_proj = self.k_norm(k_proj)
        if pos is not None:
            q_proj = self.rotary(q_proj, pos)
            k_proj = self.rotary(k_proj, pos)
        q = q_proj.transpose(1, 2)
        k = k_proj.transpose(1, 2)
        v = v_proj.transpose(1, 2)
        if cache_slot is not None:
            k_cache, v_cache, attn_mask = cache_slot.write_and_view(k, v)
        else:
            k_cache, v_cache = k, v
            attn_mask = None
        attn = F.scaled_dot_product_attention(
            q,
            k_cache,
            v_cache,
            scale=1.0,
            attn_mask=attn_mask,
            enable_gqa=self.num_gqa_groups > 1,
        )
        attn = attn.transpose(1, 2).contiguous()
        flat = attn.reshape(B, T, self.num_query_heads * self.head_dim)
        out = self.o_proj(flat.to(torch.float32))
        return out.to(orig_dtype), cache_slot

    def _project_heads(self, layer: nn.Linear, x: torch.Tensor, heads: int) -> torch.Tensor:
        proj = layer(x.to(torch.float32))
        B, T, _ = proj.shape
        proj = proj.view(B, T, heads, self.head_dim)
        return proj.to(self.compute_dtype)

    def forward(
        self,
        x: torch.Tensor,
        positions: Optional[torch.Tensor],
        cache=None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        return self.forward_incremental(x, positions, cache)


class MultiStreamEmbedding(nn.Module):
    """Port of dia_v2 MultiStreamEmbed."""

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        pad_id: int,
        *,
        output_dtype: torch.dtype,
        low_rank_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.pad_id = pad_id
        self.dtype = output_dtype
        base_dim = low_rank_dim if low_rank_dim is not None else dim
        self.embedding = nn.Embedding(vocab_size, base_dim)
        self.main_proj = nn.Linear(base_dim, dim, bias=False)
        self.second_proj = nn.Linear(base_dim, dim, bias=False)

    def forward(self, main_inputs: torch.Tensor, second_inputs: torch.Tensor) -> torch.Tensor:
        main_inputs = main_inputs.long()
        second_inputs = second_inputs.long()
        if self.pad_id is not None:
            second_is_pad = second_inputs == self.pad_id
        else:
            second_is_pad = torch.zeros_like(second_inputs, dtype=torch.bool)
        use_second = ~second_is_pad
        emb_main = self.embedding(main_inputs)
        emb_second = self.embedding(second_inputs)
        out_main = self.main_proj(emb_main.to(torch.float32))
        out_second = self.second_proj(emb_second.to(torch.float32))
        zeros = torch.zeros_like(out_second)
        y = out_main + torch.where(use_second.unsqueeze(-1), out_second, zeros)
        target_dtype = self.dtype if self.dtype is not None else y.dtype
        return y.to(target_dtype)


class Mlp(nn.Module):
    """Port of dia_v2 MlpBlock (two-activation gated MLP)."""

    def __init__(
        self,
        dim: int,
        hidden: int,
        compute_dtype: torch.dtype,
        activations: Sequence[str],
    ) -> None:
        super().__init__()
        if len(activations) != 2:
            raise ValueError("Mlp expects two activation functions.")
        self.dtype = compute_dtype
        self.hidden = hidden
        self.branch_count = len(activations)
        self.wi = nn.Linear(dim, self.branch_count * hidden, bias=False)
        self.wo = nn.Linear(hidden, dim, bias=False)
        self.activation_fns = [_get_activation(activations[0]), _get_activation(activations[1])]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        proj = self.wi(x.to(torch.float32))
        proj = proj.view(*x.shape[:-1], self.branch_count, self.hidden).to(self.dtype)
        gate, up = proj.unbind(dim=-2)
        hidden = self.activation_fns[0](gate) * self.activation_fns[1](up)
        out = self.wo(hidden.to(torch.float32))
        return out.to(self.dtype)
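A quick sanity check of the RoPE broadcasting above (illustrative, not part of this commit; dia2 import path assumed). The shapes mirror how Attention.forward_incremental calls the module:

    import torch
    from dia2.core.layers import RotaryEmbedding

    rope = RotaryEmbedding(head_dim=8, min_timescale=1, max_timescale=10000)

    # (batch, seq, heads, head_dim) with per-token position ids.
    x = torch.randn(1, 1, 4, 8)
    pos = torch.tensor([[5]])                 # absolute position of this step
    rotated = rope(x, pos)
    assert rotated.shape == x.shape
    # Position 0 applies the identity rotation (cos=1, sin=0).
    assert torch.allclose(rope(x, torch.tensor([[0]])), x)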
core/model.py
ADDED
@@ -0,0 +1,72 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from ..config import DiaConfig
from .cache import KVCache
from .depformer import Depformer
from .precision import Precision
from .transformer import TransformerDecoder


@dataclass
class DecodeState:
    transformer: KVCache
    depformer: KVCache


class Dia2Model(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        self.transformer = TransformerDecoder(config, precision)
        self.depformer = Depformer(config, precision)
        self._cast_norms_to_compute()

    def init_state(self, batch_size: int, device: torch.device, max_steps: int) -> DecodeState:
        transformer_cache = self.transformer.init_cache(batch_size, device, max_steps)
        depformer_cache = self.depformer.init_cache(batch_size, device, self.depformer.num_depth)
        return DecodeState(transformer_cache, depformer_cache)

    def step_text(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        state: DecodeState,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden, action, cb0, cache = self.transformer.forward_step(tokens, positions, state.transformer)
        state.transformer = cache
        return hidden, action, cb0

    def step_audio_stage(
        self,
        stage_index: int,
        prev_audio: torch.Tensor,
        transformer_hidden: torch.Tensor,
        state: DecodeState,
        main_text: Optional[torch.Tensor],
        second_text: Optional[torch.Tensor],
    ) -> torch.Tensor:
        cache = state.depformer
        logits, new_cache = self.depformer.forward_step(
            prev_audio,
            transformer_hidden,
            stage_index,
            cache,
            main_text,
            second_text,
        )
        state.depformer = new_cache
        return logits

    def _cast_norms_to_compute(self) -> None:
        """Cast RMSNorm weights/biases to the compute dtype to avoid bf16 warnings."""
        def _convert(module: nn.Module) -> None:
            if isinstance(module, nn.RMSNorm):
                module.to(self.precision.compute)

        self.apply(_convert)
core/precision.py
ADDED
@@ -0,0 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class Precision:
    compute: torch.dtype
    logits: torch.dtype


def resolve_precision(kind: str | None, device: torch.device) -> Precision:
    normalized = (kind or "auto").lower()
    if normalized == "auto":
        normalized = "bfloat16" if device.type == "cuda" else "float32"
    if normalized == "bfloat16":
        compute = torch.bfloat16 if device.type == "cuda" else torch.float32
        return Precision(compute=compute, logits=torch.float32)
    if normalized == "float32":
        return Precision(compute=torch.float32, logits=torch.float32)
    raise ValueError(f"Unsupported dtype '{kind}'")
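The dtype fallback behavior, illustrated (not part of this commit; dia2 import path assumed):

    import torch
    from dia2.core.precision import resolve_precision

    cpu_auto = resolve_precision("auto", torch.device("cpu"))
    assert cpu_auto.compute == torch.float32   # auto -> float32 off-GPU

    # bfloat16 is honored only on CUDA; on CPU it silently widens to float32,
    # and logits are always accumulated in float32 either way.
    cpu_bf16 = resolve_precision("bfloat16", torch.device("cpu"))
    assert cpu_bf16.compute == torch.float32
    assert cpu_bf16.logits == torch.float32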
core/transformer.py
ADDED
@@ -0,0 +1,140 @@
from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from ..config import DiaConfig
from .cache import KVCache
from .precision import Precision
from .layers import (
    AttentionShape,
    MultiStreamEmbedding,
    Mlp,
    Attention,
)


class TransformerDecoder(nn.Module):
    """Inference-time port of dia_v2.model.Transformer."""

    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        self.config = config
        self.precision = precision
        data_cfg = config.data
        dec_cfg = config.model.decoder

        self.audio_embeds = nn.ModuleList(
            [
                nn.Embedding(
                    data_cfg.audio_vocab_size,
                    dec_cfg.n_embd,
                )
                for _ in range(max(0, data_cfg.channels - 2))
            ]
        )
        self.text_embed = MultiStreamEmbedding(
            data_cfg.text_vocab_size,
            dec_cfg.n_embd,
            pad_id=data_cfg.text_pad_token_id,
            output_dtype=self.precision.compute,
            low_rank_dim=dec_cfg.low_rank_dim,
        )
        self.layers = nn.ModuleList([DecoderLayer(config, precision) for _ in range(dec_cfg.n_layer)])
        self.norm = nn.RMSNorm(dec_cfg.n_embd, eps=config.model.normalization_layer_epsilon, dtype=torch.float32)

        self.action_head = nn.Linear(dec_cfg.n_embd, data_cfg.action_vocab_size, bias=False)
        self.cb0_head = nn.Linear(dec_cfg.n_embd, data_cfg.audio_vocab_size, bias=False)

    def init_cache(self, batch_size: int, device: torch.device, max_steps: int) -> KVCache:
        heads = self.layers[0].attn.num_kv_heads
        head_dim = self.layers[0].attn.head_dim
        return KVCache.allocate(
            num_layers=len(self.layers),
            batch_size=batch_size,
            heads=heads,
            max_steps=max_steps,
            head_dim=head_dim,
            device=device,
            dtype=self.precision.compute,
        )

    def forward_step(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cache: KVCache,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, KVCache]:
        if cache is None:
            raise ValueError("Transformer cache must be initialized")

        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("forward_step expects sequence length 1")
        num_audio_channels = max(0, C - 2)

        hidden_t = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        for idx in range(num_audio_channels):
            audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
            hidden_t.add_(audio_emb)
        hidden_t = hidden_t.to(self.precision.compute)

        x = hidden_t
        for idx, layer in enumerate(self.layers):
            slot = cache.get_slot(idx)
            x, _ = layer.decode_step(x, positions, slot)

        hidden_norm = self.norm(x)
        action_logits = self.action_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        cb0_logits = self.cb0_head(hidden_norm.to(torch.float32)).to(self.precision.logits)
        return hidden_norm, action_logits, cb0_logits, cache

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        B, C, T1 = tokens.shape
        if T1 != 1:
            raise ValueError("_embed expects sequence length 1")
        num_audio_channels = max(0, C - 2)
        text_hidden = self.text_embed(tokens[:, 0, :], tokens[:, 1, :])
        audio_terms: list[torch.Tensor] = []
        for idx in range(num_audio_channels):
            audio_emb = self.audio_embeds[idx](tokens[:, idx + 2, :])
            audio_terms.append(audio_emb)
        hidden = text_hidden
        for term in audio_terms:
            hidden = hidden + term
        final = hidden.to(self.precision.compute)
        return final


class DecoderLayer(nn.Module):
    def __init__(self, config: DiaConfig, precision: Precision):
        super().__init__()
        dec = config.model.decoder
        eps = config.model.normalization_layer_epsilon
        self.pre_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.attn = Attention(config, dec.n_embd, precision.compute)
        self.post_norm = nn.RMSNorm(dec.n_embd, eps=eps, dtype=torch.float32)
        self.mlp = Mlp(
            dec.n_embd,
            dec.n_hidden,
            precision.compute,
            tuple(config.model.linear.mlp_activations),
        )

    def decode_step(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        cache_slot,
    ) -> Tuple[torch.Tensor, object]:
        residual = x
        x_norm = self.pre_norm(x)
        attn_out, _ = self.attn(x_norm, pos, cache_slot)
        x = residual + attn_out
        residual2 = x
        x_norm2 = self.post_norm(x)
        mlp_out = self.mlp(x_norm2)
        return residual2 + mlp_out, cache_slot

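A minimal sketch of the single-step decode contract (a loaded `config`/`precision`/`device` are assumed; token channel 0/1 are the two text streams, the rest are audio codebooks):

    # Hypothetical smoke test of the one-token-at-a-time API.
    decoder = TransformerDecoder(config, precision).to(device)
    cache = decoder.init_cache(batch_size=1, device=device, max_steps=256)
    tokens = torch.zeros(1, config.data.channels, 1, dtype=torch.long, device=device)
    positions = torch.zeros(1, 1, dtype=torch.long, device=device)
    hidden, action_logits, cb0_logits, cache = decoder.forward_step(tokens, positions, cache)
    # hidden: [1, 1, n_embd]; both heads are evaluated in float32 for stability.
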
engine.py
ADDED
@@ -0,0 +1,230 @@
from __future__ import annotations

from pathlib import Path
from typing import Optional, Sequence

from .assets import resolve_assets
from .runtime.context import RuntimeContext, build_runtime
from .runtime.generator import (
    build_initial_state,
    decode_audio,
    run_generation_loop,
    warmup_with_prefix,
)
from .runtime.script_parser import parse_script
from .audio.grid import undelay_frames, write_wav
from .runtime.voice_clone import build_prefix_plan
from .generation import (
    GenerationConfig,
    GenerationResult,
    merge_generation_config,
    normalize_script,
)
from .runtime.logger import RuntimeLogger


class Dia2:
    def __init__(
        self,
        *,
        repo: Optional[str] = None,
        config_path: Optional[str | Path] = None,
        weights_path: Optional[str | Path] = None,
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
        device: str = "cuda",
        dtype: str = "auto",
        default_config: Optional[GenerationConfig] = None,
    ) -> None:
        bundle = resolve_assets(
            repo=repo,
            config_path=config_path,
            weights_path=weights_path,
        )
        self._config_path = bundle.config_path
        self._weights_path = bundle.weights_path
        self._tokenizer_id = (str(tokenizer_id) if tokenizer_id else None) or bundle.tokenizer_id
        self._repo_id = bundle.repo_id
        self._mimi_id = mimi_id or bundle.mimi_id
        self.device = device
        self._dtype_pref = dtype or "auto"
        self.default_config = default_config or GenerationConfig()
        self._runtime: Optional[RuntimeContext] = None

    @classmethod
    def from_repo(
        cls,
        repo: str,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        return cls(repo=repo, device=device, dtype=dtype, tokenizer_id=tokenizer_id, mimi_id=mimi_id)

    @classmethod
    def from_local(
        cls,
        config_path: str | Path,
        weights_path: str | Path,
        *,
        device: str = "cuda",
        dtype: str = "auto",
        tokenizer_id: Optional[str | Path] = None,
        mimi_id: Optional[str] = None,
    ) -> "Dia2":
        return cls(
            config_path=config_path,
            weights_path=weights_path,
            tokenizer_id=tokenizer_id,
            device=device,
            dtype=dtype,
            mimi_id=mimi_id,
        )

    def set_device(self, device: str, *, dtype: Optional[str] = None) -> None:
        desired_dtype = dtype or self._dtype_pref
        if self.device == device and desired_dtype == self._dtype_pref:
            return
        self.device = device
        self._dtype_pref = desired_dtype
        self._runtime = None

    def close(self) -> None:
        self._runtime = None

    def _ensure_runtime(self) -> RuntimeContext:
        if self._runtime is None:
            self._runtime = self._build_runtime()
        return self._runtime

    def generate(
        self,
        script: str | Sequence[str],
        *,
        config: Optional[GenerationConfig] = None,
        output_wav: Optional[str | Path] = None,
        prefix_speaker_1: Optional[str] = None,
        prefix_speaker_2: Optional[str] = None,
        include_prefix: Optional[bool] = None,
        verbose: bool = False,
        **overrides,
    ):
        runtime = self._ensure_runtime()
        logger = RuntimeLogger(verbose)
        merged_overrides = dict(overrides)
        if prefix_speaker_1 is not None:
            merged_overrides["prefix_speaker_1"] = prefix_speaker_1
        if prefix_speaker_2 is not None:
            merged_overrides["prefix_speaker_2"] = prefix_speaker_2
        if include_prefix is not None:
            merged_overrides["include_prefix"] = include_prefix
        merged = merge_generation_config(base=config or self.default_config, overrides=merged_overrides)
        max_context = runtime.config.runtime.max_context_steps
        text = normalize_script(script)
        prefix_plan = build_prefix_plan(runtime, merged.prefix)
        entries = []
        if prefix_plan is not None:
            entries.extend(prefix_plan.entries)
        entries.extend(parse_script([text], runtime.tokenizer, runtime.constants, runtime.frame_rate))
        runtime.machine.initial_padding = merged.initial_padding
        logger.event(
            f"starting generation: max_context={max_context} cfg_scale={merged.cfg_scale:.2f} "
            f"device={self.device} dtype={self._dtype_pref}"
        )
        state = runtime.machine.new_state(entries)
        cfg_active = merged.cfg_scale != 1.0
        if cfg_active:
            logger.event(f"classifier-free guidance enabled (scale={merged.cfg_scale:.2f})")
        else:
            logger.event("classifier-free guidance disabled (scale=1.0)")
        gen_state = build_initial_state(
            runtime,
            prefix=prefix_plan,
        )
        include_prefix_audio = bool(prefix_plan and merged.prefix and merged.prefix.include_audio)
        start_step = 0
        if prefix_plan is not None:
            logger.event(f"warming up with prefix ({prefix_plan.aligned_frames} frames)")
            start_step = warmup_with_prefix(runtime, prefix_plan, state, gen_state)
            if include_prefix_audio:
                logger.event("prefix audio will be kept in output")
            else:
                logger.event("prefix audio trimmed from output")
        first_word_frame, audio_buf = run_generation_loop(
            runtime,
            state=state,
            generation=gen_state,
            config=merged,
            start_step=start_step,
            logger=logger,
        )
        aligned = undelay_frames(audio_buf[0], runtime.audio_delays, runtime.constants.audio_pad).unsqueeze(0)
        crop = 0 if include_prefix_audio else max(first_word_frame, 0)
        if crop > 0 and crop < aligned.shape[-1]:
            aligned = aligned[:, :, crop:]
        elif crop >= aligned.shape[-1]:
            crop = 0
        logger.event(f"decoding {aligned.shape[-1]} Mimi frames")
        waveform = decode_audio(runtime, aligned)
        if output_wav is not None:
            write_wav(str(output_wav), waveform.detach().cpu().numpy(), runtime.mimi.sample_rate)
            duration = waveform.shape[-1] / max(runtime.mimi.sample_rate, 1)
            logger.event(f"saved {output_wav} ({duration:.2f}s)")
        frame_rate = max(runtime.frame_rate, 1.0)
        prefix_entry_count = len(prefix_plan.entries) if prefix_plan is not None else 0
        transcript_entries = state.transcript
        if prefix_plan is not None and not include_prefix_audio:
            if len(transcript_entries) > prefix_entry_count:
                transcript_entries = transcript_entries[prefix_entry_count:]
            else:
                transcript_entries = []
        timestamps = []
        for word, step in transcript_entries:
            adj = step - crop
            if adj < 0:
                continue
            timestamps.append((word, adj / frame_rate))
        logger.event(f"generation finished in {logger.elapsed():.2f}s")
        return GenerationResult(aligned, waveform, runtime.mimi.sample_rate, timestamps)

    def save_wav(self, script: str | Sequence[str], path: str | Path, **kwargs):
        return self.generate(script, output_wav=path, **kwargs)

    @property
    def sample_rate(self) -> int:
        return self._ensure_runtime().mimi.sample_rate

    @property
    def tokenizer_id(self) -> Optional[str]:
        if self._tokenizer_id:
            return self._tokenizer_id
        if self._runtime is not None:
            return getattr(self._runtime.tokenizer, "name_or_path", None)
        return self._repo_id

    @property
    def dtype(self) -> str:
        return self._dtype_pref

    @property
    def max_context_steps(self) -> int:
        return self._ensure_runtime().config.runtime.max_context_steps

    @property
    def repo(self) -> Optional[str]:
        return self._repo_id

    def _build_runtime(self) -> RuntimeContext:
        runtime, tokenizer_ref, mimi_ref = build_runtime(
            config_path=self._config_path,
            weights_path=self._weights_path,
            tokenizer_id=self._tokenizer_id,
            repo_id=self._repo_id,
            mimi_id=self._mimi_id,
            device=self.device,
            dtype_pref=self._dtype_pref,
        )
        self._tokenizer_id = tokenizer_ref
        self._mimi_id = mimi_ref
        return runtime

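For orientation, a minimal end-to-end sketch of the `Dia2` facade. The repo id and the `[S1]`/`[S2]`-tagged script text are placeholders (the speaker tags mirror the ids resolved in runtime/context.py below, not a documented script grammar):

    # Hypothetical usage; assumes the uploaded folder is importable as `dia2`.
    from dia2 import Dia2

    tts = Dia2.from_repo("user/dia2-checkpoint", device="cuda", dtype="auto")
    result = tts.generate("[S1] Hello there. [S2] Hi!", verbose=True)
    print(result.sample_rate, result.timestamps[:3])  # per-word second offsets
    tts.save_wav("[S1] Saved straight to disk.", "out.wav")
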
generation.py
ADDED
@@ -0,0 +1,158 @@
from __future__ import annotations

import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Mapping, Optional, Sequence, Tuple

import torch


@dataclass(frozen=True)
class SamplingConfig:
    temperature: float = 0.8
    top_k: int = 50


def _default_text_sampling() -> SamplingConfig:
    return SamplingConfig(temperature=0.6, top_k=50)


def _default_audio_sampling() -> SamplingConfig:
    return SamplingConfig(temperature=0.8, top_k=50)


@dataclass(frozen=True)
class PrefixConfig:
    speaker_1: Optional[str] = None
    speaker_2: Optional[str] = None
    include_audio: bool = False


@dataclass(frozen=True)
class GenerationConfig:
    text: SamplingConfig = field(default_factory=_default_text_sampling)
    audio: SamplingConfig = field(default_factory=_default_audio_sampling)
    cfg_scale: float = 2.0
    cfg_filter_k: int = 50
    initial_padding: int = 2
    prefix: Optional["PrefixConfig"] = None
    use_cuda_graph: bool = False


@dataclass(frozen=True)
class GenerationResult:
    audio_tokens: torch.Tensor
    waveform: torch.Tensor
    sample_rate: int
    timestamps: List[Tuple[str, float]]


def normalize_script(script: str | Sequence[str]) -> str:
    if isinstance(script, str):
        return script.strip()
    return "\n".join(line.strip() for line in script)


def load_script_text(path: str | Path) -> str:
    if path == "-":
        return sys.stdin.read().strip()
    path_obj = Path(path)
    if path_obj.exists():
        return path_obj.read_text().strip()
    return str(path).strip()


def validate_generation_params(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> tuple[float, int, float]:
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k <= 0:
        raise ValueError("top_k must be positive")
    if cfg_scale <= 0:
        raise ValueError("cfg_scale must be positive")
    return temperature, top_k, cfg_scale


def build_generation_config(
    *,
    temperature: float,
    top_k: int,
    cfg_scale: float,
) -> GenerationConfig:
    sampling = SamplingConfig(temperature=temperature, top_k=top_k)
    return GenerationConfig(
        text=sampling,
        audio=sampling,
        cfg_scale=cfg_scale,
    )


def merge_generation_config(
    *,
    base: GenerationConfig,
    overrides: Mapping[str, object],
) -> GenerationConfig:
    clean_overrides = {k: v for k, v in overrides.items() if v is not None}
    text_temp = clean_overrides.pop("temp_text", None)
    text_topk = clean_overrides.pop("topk_text", None)
    audio_temp = clean_overrides.pop("temp_audio", None)
    audio_topk = clean_overrides.pop("topk_audio", None)
    prefix_speaker_1 = clean_overrides.pop("prefix_speaker_1", None)
    prefix_speaker_2 = clean_overrides.pop("prefix_speaker_2", None)
    include_prefix = clean_overrides.pop("include_prefix", None)

    text_sampling = base.text
    if text_temp is not None or text_topk is not None:
        text_sampling = SamplingConfig(
            temperature=text_temp if text_temp is not None else text_sampling.temperature,
            top_k=text_topk if text_topk is not None else text_sampling.top_k,
        )

    audio_sampling = base.audio
    if audio_temp is not None or audio_topk is not None:
        audio_sampling = SamplingConfig(
            temperature=audio_temp if audio_temp is not None else audio_sampling.temperature,
            top_k=audio_topk if audio_topk is not None else audio_sampling.top_k,
        )

    prefix_cfg = base.prefix
    if (
        prefix_speaker_1 is not None
        or prefix_speaker_2 is not None
        or include_prefix is not None
        or prefix_cfg is not None
    ):
        prefix_cfg = prefix_cfg or PrefixConfig()
        prefix_cfg = PrefixConfig(
            speaker_1=prefix_speaker_1 if prefix_speaker_1 is not None else prefix_cfg.speaker_1,
            speaker_2=prefix_speaker_2 if prefix_speaker_2 is not None else prefix_cfg.speaker_2,
            include_audio=include_prefix if include_prefix is not None else prefix_cfg.include_audio,
        )

    return GenerationConfig(
        text=text_sampling,
        audio=audio_sampling,
        cfg_scale=clean_overrides.pop("cfg_scale", base.cfg_scale),
        cfg_filter_k=clean_overrides.pop("cfg_filter_k", base.cfg_filter_k),
        initial_padding=clean_overrides.pop("initial_padding", base.initial_padding),
        prefix=prefix_cfg,
        use_cuda_graph=clean_overrides.pop("use_cuda_graph", base.use_cuda_graph),
    )


__all__ = [
    "SamplingConfig",
    "GenerationConfig",
    "GenerationResult",
    "PrefixConfig",
    "normalize_script",
    "load_script_text",
    "validate_generation_params",
    "build_generation_config",
    "merge_generation_config",
]

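A short sketch of how per-stream overrides flow through `merge_generation_config` (the override values are arbitrary):

    base = GenerationConfig()  # text temp 0.6, audio temp 0.8, cfg_scale 2.0
    merged = merge_generation_config(
        base=base,
        overrides={"temp_text": 0.4, "topk_audio": 32, "cfg_scale": 1.5},
    )
    assert merged.text.temperature == 0.4 and merged.text.top_k == 50
    assert merged.audio.top_k == 32 and merged.audio.temperature == 0.8
    assert merged.cfg_scale == 1.5  # untouched fields fall back to the base config
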
runtime/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .state_machine import Entry, StateMachine, TokenIds

__all__ = [
    "Entry",
    "StateMachine",
    "TokenIds",
]

runtime/audio_io.py
ADDED
@@ -0,0 +1,69 @@
from __future__ import annotations

from pathlib import Path
from typing import Union

import numpy as np
import sphn
import torch
import torch.nn.functional as F

from ..audio import MimiCodec

PathLike = Union[str, Path]


def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
    """Read an audio file, convert to mono float32, and resample to target_sr."""
    path = str(path)
    try:
        audio, sr = sphn.read_wav(path)
    except Exception:
        import soundfile as sf  # Local fallback

        audio, sr = sf.read(path, dtype="float32", always_2d=False)
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    if sr != target_sr:
        if hasattr(sphn, "resample_audio"):
            audio = sphn.resample_audio(audio, sr, target_sr).astype(np.float32)
        else:
            audio = _resample_linear(audio, sr, target_sr)
    return audio


def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert mono PCM samples into shape [1, 1, T] tensor."""
    tensor = torch.from_numpy(audio).to(device)
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)
    return tensor


def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
    """Encode PCM audio into Mimi codebook tokens [C, T]."""
    waveform = audio_to_tensor(audio, mimi.device)
    with torch.inference_mode():
        codes, *_ = mimi.encode(waveform, return_dict=False)
    if isinstance(codes, (tuple, list)):
        codes = codes[0]
    # Mimi.encode returns [B, num_codebooks, T]; select batch 0.
    codes = codes[0].to(torch.long)
    return codes


def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    if src_sr == dst_sr:
        return audio.astype(np.float32)
    length = audio.shape[0]
    new_length = max(1, int(round(length * dst_sr / src_sr)))
    tensor = torch.from_numpy(audio.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        resampled = F.interpolate(tensor, size=new_length, mode="linear", align_corners=False)
    return resampled.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32)


__all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]

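A small sketch of the load-and-encode path (the file name is a placeholder; `mimi` is assumed to be a loaded MimiCodec with `.sample_rate` and `.device` set):

    # Read a clip, resample it to the codec rate, and tokenize it.
    pcm = load_mono_audio("voice_prompt.wav", target_sr=mimi.sample_rate)
    codes = encode_audio_tokens(mimi, pcm)
    print(codes.shape)  # [num_codebooks, T] long tensor
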
runtime/context.py
ADDED
@@ -0,0 +1,138 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import warnings

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from ..audio import MimiCodec, DEFAULT_MIMI_MODEL_ID
from .state_machine import StateMachine, TokenIds


@dataclass
class RuntimeContext:
    config: DiaConfig
    model: Dia2Model
    precision: Precision
    tokenizer: PreTrainedTokenizerBase
    mimi: MimiCodec
    device: torch.device
    machine: StateMachine
    transformer_step: callable
    depformer_step: callable
    constants: TokenIds
    audio_delays: list[int]
    audio_delay_tensor: torch.Tensor
    frame_rate: float


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    device_obj = torch.device(device)
    if device_obj.type == "cuda":
        cuda_matmul = torch.backends.cuda.matmul
        if hasattr(cuda_matmul, "fp32_precision"):
            cuda_matmul.fp32_precision = "tf32"
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="Please use the new API settings",
                )
                torch.backends.cuda.matmul.allow_tf32 = True
        else:  # pragma: no cover - compatibility with older PyTorch
            torch.backends.cuda.matmul.allow_tf32 = True

        # Handle cuDNN conv TF32 settings (check if conv attribute exists first)
        if hasattr(torch.backends.cudnn, "conv"):
            cudnn_conv = torch.backends.cudnn.conv
            if hasattr(cudnn_conv, "fp32_precision"):
                cudnn_conv.fp32_precision = "tf32"
                with warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore",
                        message="Please use the new API settings",
                    )
                    torch.backends.cudnn.allow_tf32 = True
            else:
                torch.backends.cudnn.allow_tf32 = True
        else:
            # For older PyTorch versions without the conv attribute
            torch.backends.cudnn.allow_tf32 = True
    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)
    model = Dia2Model(config, precision)
    state = load_file(str(weights_path))
    model.load_state_dict(state)
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    if tokenizer_ref is None:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data
    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        bos=getattr(tokenizer, "bos_token_id", 1) or 1,
        zero=data_cfg.text_zero_token_id,
        spk1=tokenizer.convert_tokens_to_ids("[S1]") if "[S1]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
        spk2=tokenizer.convert_tokens_to_ids("[S2]") if "[S2]" in tokenizer.get_vocab() else data_cfg.text_new_word_token_id,
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )
    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )
    audio_delays = list(data_cfg.delay_pattern)
    audio_delay_tensor = torch.tensor(audio_delays, device=device_obj, dtype=torch.long) if audio_delays else torch.empty(0, dtype=torch.long, device=device_obj)
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref


__all__ = [
    "RuntimeContext",
    "build_runtime",
]

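A hedged sketch of driving `build_runtime` directly, outside the `Dia2` facade (the paths and repo id are placeholders for a locally downloaded checkpoint):

    runtime, tokenizer_ref, mimi_ref = build_runtime(
        config_path="config.json",
        weights_path="model.safetensors",
        tokenizer_id=None,   # falls back to config.assets.tokenizer / repo_id
        repo_id="user/dia2-checkpoint",
        mimi_id=None,        # falls back to config.assets.mimi / default Mimi
        device="cuda",
        dtype_pref="auto",
    )
    print(runtime.frame_rate, tokenizer_ref, mimi_ref)
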
runtime/generator.py
ADDED
@@ -0,0 +1,420 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from ..core.cache import KVCache
from ..core.model import DecodeState
from ..generation import GenerationConfig
from ..audio.grid import delay_frames, mask_audio_logits, undelay_frames
from .context import RuntimeContext
from .state_machine import State, TokenIds
from .guidance import apply_classifier_guidance, sample_audio_logits
from .sampler import sample_token
from .voice_clone import PrefixPlan
from .logger import RuntimeLogger

_GRAPH_CUBLAS_READY = False


def _ensure_graph_cublas_ready(device: torch.device) -> None:
    global _GRAPH_CUBLAS_READY
    if _GRAPH_CUBLAS_READY or device.type != "cuda":
        return
    tmp = torch.empty((1, 1), device=device, dtype=torch.float32)
    torch.matmul(tmp, tmp)
    torch.cuda.synchronize()
    _GRAPH_CUBLAS_READY = True


@dataclass
class GenerationState:
    decode: DecodeState
    step_tokens: torch.Tensor
    audio_buf: torch.Tensor

    def trim_audio(self, limit: int, pad_token: int, ungenerated: int) -> torch.Tensor:
        trimmed = self.audio_buf[:, :, :limit]
        pad = torch.full_like(trimmed, pad_token)
        trimmed = torch.where(trimmed == ungenerated, pad, trimmed)
        self.audio_buf = trimmed
        return trimmed

    @property
    def transformer_cache(self) -> KVCache:
        return self.decode.transformer

    @transformer_cache.setter
    def transformer_cache(self, cache: KVCache) -> None:
        self.decode.transformer = cache

    @property
    def depformer_cache(self) -> KVCache:
        return self.decode.depformer

    @depformer_cache.setter
    def depformer_cache(self, cache: KVCache) -> None:
        self.decode.depformer = cache

    def reset_dep_cache(self) -> None:
        self.decode.depformer.reset()


@dataclass
class NetworkBuffers:
    text: torch.Tensor
    cb0: torch.Tensor
    dep: list[torch.Tensor]


def _allocate_network_buffers(runtime: RuntimeContext, branches: int) -> NetworkBuffers:
    device = runtime.device
    logits_dtype = runtime.precision.logits
    data_cfg = runtime.config.data
    text_logits = torch.empty((branches, 1, data_cfg.action_vocab_size), dtype=logits_dtype, device=device)
    cb0_logits = torch.empty((branches, 1, data_cfg.audio_vocab_size), dtype=logits_dtype, device=device)
    dep_vocab = runtime.model.depformer.audio_vocab_limit or data_cfg.audio_vocab_size
    dep_logits = [
        torch.empty((branches, 1, 1, dep_vocab), dtype=logits_dtype, device=device)
        for _ in range(runtime.model.depformer.num_depth)
    ]
    return NetworkBuffers(text=text_logits, cb0=cb0_logits, dep=dep_logits)


def build_initial_state(
    runtime: RuntimeContext,
    *,
    prefix: PrefixPlan | None = None,
) -> GenerationState:
    dep_q = runtime.model.depformer.num_audio_channels
    channels = 2 + dep_q
    branches = 2
    token_ids = runtime.constants
    step_tokens = torch.full(
        (branches, channels, 1),
        token_ids.pad,
        dtype=torch.long,
        device=runtime.device,
    )
    step_tokens[0, 0, 0] = token_ids.bos
    step_tokens[0, 1, 0] = token_ids.pad
    step_tokens[1, 0, 0] = token_ids.zero
    step_tokens[1, 1, 0] = token_ids.pad
    prefix_len = 0
    if prefix is not None:
        delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad)
        prefix_len = delayed.shape[1]
    limit = runtime.config.runtime.max_context_steps
    total_steps = max(limit + prefix_len + 1, limit)
    decode_state = runtime.model.init_state(branches, runtime.device, total_steps)
    audio_buf = torch.full(
        (branches, dep_q, total_steps),
        token_ids.ungenerated,
        dtype=torch.long,
        device=runtime.device,
    )
    if prefix is not None:
        delayed = delay_frames(prefix.aligned_tokens, runtime.audio_delays, token_ids.audio_pad).to(runtime.device)
        audio_buf[0, :, : delayed.shape[1]] = delayed
        if branches > 1:
            audio_buf[1:, :, : delayed.shape[1]] = delayed
    return GenerationState(decode_state, step_tokens, audio_buf)


def _fill_audio_channels(
    step_tokens: torch.Tensor,
    audio_buf: torch.Tensor,
    delays: torch.Tensor,
    step: int,
    bos_token: int,
) -> None:
    channels = delays.numel()
    if channels == 0:
        return
    target = step_tokens[:, 2 : 2 + channels, 0]
    if step < audio_buf.shape[-1]:
        target.copy_(audio_buf[:, :channels, step])
    else:
        target.fill_(bos_token)
    mask = delays > step
    if mask.any().item():
        target[:, mask] = bos_token


def _execute_transformer_step(
    step_tokens: torch.Tensor,
    positions_view: torch.Tensor,
    generation: GenerationState,
    transformer_step,
    buffers: NetworkBuffers,
) -> torch.Tensor:
    hidden_t, text_logits_t, cb0_logits_t, present = transformer_step(
        step_tokens,
        positions_view,
        generation.transformer_cache,
    )
    buffers.text.copy_(text_logits_t)
    buffers.cb0.copy_(cb0_logits_t)
    generation.transformer_cache = present
    return hidden_t


def _execute_depformer_stage(
    stage_index: int,
    prev_audio: torch.Tensor,
    hidden_t: torch.Tensor,
    generation: GenerationState,
    depformer_step,
    main_tokens: Optional[torch.Tensor],
    second_tokens: Optional[torch.Tensor],
    buffers: NetworkBuffers,
) -> None:
    logits_stage, dep_present = depformer_step(
        prev_audio=prev_audio,
        transformer_out=hidden_t,
        stage_index=stage_index,
        cache=generation.depformer_cache,
        main_text=main_tokens if stage_index == 0 else None,
        second_text=second_tokens if stage_index == 0 else None,
    )
    target = buffers.dep[stage_index]
    if logits_stage.shape != target.shape:
        raise RuntimeError(
            f"depformer logits shape mismatch: {logits_stage.shape} vs {target.shape}"
        )
    target.copy_(logits_stage)
    generation.depformer_cache = dep_present


def run_generation_loop(
    runtime: RuntimeContext,
    *,
    state: State,
    generation: GenerationState,
    config: GenerationConfig,
    start_step: int = 0,
    logger: RuntimeLogger | None = None,
) -> tuple[Optional[int], torch.Tensor]:
    step_tokens = generation.step_tokens
    audio_buf = generation.audio_buf
    branches = step_tokens.shape[0]
    max_context = runtime.config.runtime.max_context_steps
    if max_context <= 0:
        raise ValueError("Runtime configuration must specify a positive max_context_steps")
    positions = torch.empty(1, 1, dtype=torch.long, device=runtime.device)
    main_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
    aux_tokens = torch.empty(branches, dtype=torch.long, device=runtime.device)
    cfg_active = config.cfg_scale != 1.0
    token_ids = runtime.constants
    delay_tensor = runtime.audio_delay_tensor
    max_delay = int(delay_tensor.max().item()) if delay_tensor.numel() else 0
    flush_tail = max_delay + getattr(runtime.machine, "max_padding", 0)
    first_word_frame: Optional[int] = None
    eos_cutoff: Optional[int] = None
    last_step = start_step - 1
    use_graph = bool(config.use_cuda_graph and runtime.device.type == "cuda")
    transformer_step = runtime.transformer_step
    depformer_step = runtime.depformer_step
    buffers = _allocate_network_buffers(runtime, branches)
    positions_view = positions.expand(branches, -1)
    transformer_capture = None
    dep_captures: list[dict] | None = None
    if use_graph:
        _ensure_graph_cublas_ready(runtime.device)
    processed_steps = 0
    report_interval = 12
    with torch.inference_mode():
        for offset in range(max_context):
            t = start_step + offset
            if eos_cutoff is not None and t >= eos_cutoff:
                break
            if t + 1 >= audio_buf.shape[-1]:
                break
            generation.reset_dep_cache()
            positions.fill_(t)
            _fill_audio_channels(step_tokens, audio_buf, delay_tensor, t, token_ids.audio_bos)
            if branches > 1:
                step_tokens[1:, 0, 0] = token_ids.zero
                step_tokens[1:, 1, 0] = token_ids.pad
            if use_graph:
                if transformer_capture is None:
                    torch.cuda.synchronize()
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph):
                        hidden_ref = _execute_transformer_step(
                            step_tokens,
                            positions_view,
                            generation,
                            transformer_step,
                            buffers,
                        )
                    transformer_capture = (graph, hidden_ref)
                    if runtime.model.depformer.num_depth > 0:
                        dep_captures = []
                        for idx in range(runtime.model.depformer.num_depth):
                            capture = {
                                "graph": torch.cuda.CUDAGraph(),
                                "captured": False,
                                "prev_audio": torch.empty((branches,), dtype=torch.long, device=runtime.device),
                                "main_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
                                "second_tokens": torch.empty((branches,), dtype=torch.long, device=runtime.device) if idx == 0 else None,
                            }
                            dep_captures.append(capture)
                else:
                    transformer_capture[0].replay()
                hidden_t = transformer_capture[1]
            else:
                hidden_t = _execute_transformer_step(
                    step_tokens,
                    positions_view,
                    generation,
                    transformer_step,
                    buffers,
                )

            guided_text = apply_classifier_guidance(buffers.text, cfg_active, config.cfg_scale, config.cfg_filter_k)
            if guided_text.shape[0] > 1:
                guided_text = guided_text[:1]
            text_token = sample_token(
                guided_text,
                temp=config.text.temperature,
                top_k=config.text.top_k,
            ).item()

            main_token, aux_token, _ = runtime.machine.process(t, state, text_token)
            second_token = aux_token if aux_token != -1 else token_ids.pad
            if first_word_frame is None and main_token == token_ids.new_word:
                first_word_frame = t - config.initial_padding
            step_tokens[:, 0, 0] = main_token
            step_tokens[:, 1, 0] = second_token

            guided_cb0 = apply_classifier_guidance(buffers.cb0, cfg_active, config.cfg_scale, config.cfg_filter_k)
            if guided_cb0.shape[0] > 1:
                guided_cb0 = guided_cb0[:1]
            masked_cb0 = mask_audio_logits(guided_cb0, token_ids.audio_pad, token_ids.audio_bos)
            codebook_token = sample_audio_logits(masked_cb0, config.audio.temperature, config.audio.top_k)
            audio_buf[:, 0, t + 1] = codebook_token

            prev_audio = codebook_token.expand(branches)
            main_tokens.fill_(main_token)
            aux_tokens.fill_(second_token)
            for stage in range(runtime.model.depformer.num_depth):
                if use_graph and dep_captures is not None:
                    capture = dep_captures[stage]
                    capture["prev_audio"].copy_(prev_audio)
                    if capture["main_tokens"] is not None and stage == 0:
                        capture["main_tokens"].copy_(main_tokens)
                        capture["second_tokens"].copy_(aux_tokens)
                    if not capture["captured"]:
                        torch.cuda.synchronize()
                        with torch.cuda.graph(capture["graph"]):
                            _execute_depformer_stage(
                                stage_index=stage,
                                prev_audio=capture["prev_audio"],
                                hidden_t=hidden_t,
                                generation=generation,
                                depformer_step=depformer_step,
                                main_tokens=capture["main_tokens"],
                                second_tokens=capture["second_tokens"],
                                buffers=buffers,
                            )
                        capture["captured"] = True
                    else:
                        capture["graph"].replay()
                else:
                    _execute_depformer_stage(
                        stage_index=stage,
                        prev_audio=prev_audio,
                        hidden_t=hidden_t,
                        generation=generation,
                        depformer_step=depformer_step,
                        main_tokens=main_tokens,
                        second_tokens=aux_tokens,
                        buffers=buffers,
                    )
                dep_logits = apply_classifier_guidance(buffers.dep[stage], cfg_active, config.cfg_scale, config.cfg_filter_k)
                if dep_logits.shape[0] > 1:
                    dep_logits = dep_logits[:1]
                stage_token = sample_audio_logits(
                    dep_logits,
                    config.audio.temperature,
                    config.audio.top_k,
                )
                audio_buf[:, stage + 1, t + 1] = stage_token
                prev_audio = stage_token.expand(branches)
            last_step = t
            if eos_cutoff is None and state.end_step is not None:
                eos_cutoff = state.end_step + flush_tail
            processed_steps = offset + 1
            if logger and processed_steps % report_interval == 0:
                logger.progress(processed_steps, max_context)

    if logger and processed_steps and processed_steps % report_interval != 0:
        logger.progress(processed_steps, max_context)

    if first_word_frame is None:
        first_word_frame = start_step
    if last_step < start_step:
        limit = min(start_step + 1, audio_buf.shape[-1])
    else:
        limit = min(last_step + 2, audio_buf.shape[-1])
    trimmed = generation.trim_audio(limit, token_ids.audio_pad, token_ids.ungenerated)
    return first_word_frame, trimmed


def decode_audio(runtime: RuntimeContext, tokens: torch.Tensor) -> torch.Tensor:
    if tokens.shape[-1] == 0:
        return torch.zeros(0, device=runtime.device)
    with torch.inference_mode():
        pcm = runtime.mimi.decode(tokens.to(runtime.device))
        return pcm[0, 0]


def warmup_with_prefix(
    runtime: RuntimeContext,
    plan: PrefixPlan,
    state: State,
    generation: GenerationState,
) -> int:
    step_tokens = generation.step_tokens
    model_state = generation.decode
    branches = step_tokens.shape[0]
    device = runtime.device
    tokens = plan.aligned_tokens.to(device)
    new_word_steps = set(plan.new_word_steps)
    positions = torch.empty(1, 1, dtype=torch.long, device=device)

    with torch.inference_mode():
        for t in range(plan.aligned_frames):
            positions.fill_(t)
            channels = tokens.shape[0]
            for cb in range(channels):
                delay = runtime.audio_delays[cb] if cb < len(runtime.audio_delays) else 0
                idx = t - delay
                value = tokens[cb, idx] if idx >= 0 else runtime.constants.audio_bos
                step_tokens[:, 2 + cb, 0] = value
            hidden, text_logits, cb0_logits, present = runtime.model.transformer.forward_step(
                step_tokens,
                positions.expand(branches, -1),
                model_state.transformer,
            )
            model_state.transformer = present

            forced = runtime.constants.new_word if t in new_word_steps else runtime.constants.pad
            main_token, aux_token, _ = runtime.machine.process(t, state, forced, is_forced=True)
            second_token = runtime.constants.pad if aux_token == -1 else aux_token
            step_tokens[0, 0, 0] = main_token
            step_tokens[0, 1, 0] = second_token
            if branches > 1:
                step_tokens[1:, 0, 0] = runtime.constants.zero
                step_tokens[1:, 1, 0] = runtime.constants.pad

    return max(plan.aligned_frames - 1, 0)


__all__ = [
    "build_initial_state",
    "run_generation_loop",
    "decode_audio",
    "warmup_with_prefix",
    "GenerationState",
]

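Both the generation loop and `warmup_with_prefix` read frame t through a per-codebook delay pattern. A toy, standalone illustration of what delaying does to a token grid (made-up values, not the project's `delay_frames` helper):

    import torch

    # Toy delay pattern: codebook 0 undelayed, codebooks 1-2 shifted right.
    grid = torch.arange(12).reshape(3, 4)            # [C=3, T=4] token grid
    delays, pad = [0, 1, 2], -1
    delayed = torch.full((3, 4 + max(delays)), pad, dtype=torch.long)
    for cb, d in enumerate(delays):
        delayed[cb, d : d + 4] = grid[cb]
    # At step t the model sees grid[cb, t - delays[cb]] (and BOS before that),
    # which is exactly the indexing warmup_with_prefix performs per codebook.
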
runtime/guidance.py
ADDED
@@ -0,0 +1,38 @@
from __future__ import annotations

import torch

from .sampler import sample_token


def apply_classifier_guidance(
    logits: torch.Tensor,
    cfg_active: bool,
    scale: float,
    top_k: int,
) -> torch.Tensor:
    if not cfg_active:
        return logits
    conditional = logits[0:1]
    unconditional = logits[1:2]
    cond32 = conditional.to(torch.float32)
    uncond32 = unconditional.to(torch.float32)
    guided = torch.lerp(uncond32, cond32, scale)
    if top_k > 0 and guided.shape[-1] > 0:
        k = min(top_k, guided.shape[-1])
        # sorted=True so the last of the k values is the true threshold.
        threshold = torch.topk(guided, k=k, dim=-1, sorted=True).values[..., -1:]
        mask = guided >= threshold
        neg_inf = torch.full_like(cond32, float("-inf"))
        cond32 = torch.where(mask, cond32, neg_inf)
    return cond32.to(conditional.dtype)


def sample_audio_logits(logits: torch.Tensor, temp: float, top_k: int) -> torch.Tensor:
    """Sample a single audio token (shape [1]) from logits."""
    return (
        sample_token(
            logits,
            temp=temp,
            top_k=top_k,
        ).view(1)
    )

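`torch.lerp(uncond, cond, scale)` expands to `uncond + scale * (cond - uncond)`, so scale > 1 extrapolates past the conditional logits; the function then keeps the *conditional* logits for sampling and only uses the guided scores to choose which ids survive the top-k filter. A tiny numeric check of the lerp arithmetic:

    import torch

    cond = torch.tensor([[2.0, 0.0]])
    uncond = torch.tensor([[1.0, 0.0]])
    guided = torch.lerp(uncond, cond, 2.0)
    print(guided)  # tensor([[3., 0.]]) == uncond + 2.0 * (cond - uncond)
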
runtime/logger.py
ADDED
@@ -0,0 +1,33 @@
from __future__ import annotations

import time
from typing import Optional  # needed for the progress() annotation below


class RuntimeLogger:
    def __init__(self, enabled: bool) -> None:
        self.enabled = enabled
        self.start_time = time.perf_counter()
        self.last_time = self.start_time
        self.last_step = 0

    def event(self, message: str) -> None:
        if self.enabled:
            print(f"[dia2] {message}")

    def progress(self, step: int, total: Optional[int] = None) -> None:
        if not self.enabled:
            return
        now = time.perf_counter()
        delta_t = max(now - self.last_time, 1e-6)
        delta_steps = max(step - self.last_step, 1)
        speed = delta_steps / delta_t
        if total is None:
            self.event(f"step {step} :: {speed:.1f} toks/s")
        else:
            self.event(f"step {step}/{total} :: {speed:.1f} toks/s")
        self.last_time = now
        self.last_step = step

    def elapsed(self) -> float:
        return time.perf_counter() - self.start_time


__all__ = ["RuntimeLogger"]

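Usage is intentionally minimal (the printed rate below is illustrative):

    logger = RuntimeLogger(enabled=True)
    logger.event("loaded model")      # prints: [dia2] loaded model
    logger.progress(24, total=480)    # prints e.g.: [dia2] step 24/480 :: 118.3 toks/s
    print(f"{logger.elapsed():.2f}s total")
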
runtime/sampler.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations

import torch


def sample_token(
    logits: torch.Tensor,
    *,
    temp: float,
    top_k: int = 0,
) -> torch.Tensor:
    logits32 = logits.to(torch.float32)
    if temp <= 0.0:
        return torch.argmax(logits32, dim=-1, keepdim=True)
    probs = torch.softmax(logits32 / max(temp, 1e-6), dim=-1)
    probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
    probs = torch.clamp_min(probs, 0.0)
    flat = probs.reshape(-1, probs.shape[-1])
    norm = flat.sum(dim=-1, keepdim=True)
    zero_mask = norm <= 0
    norm = norm.clamp_min(1e-12)
    flat = flat / norm
    if zero_mask.any():
        filler = torch.zeros_like(flat)
        filler[..., 0] = 1.0
        mask = zero_mask.expand_as(flat)
        flat = torch.where(mask, filler, flat)
    vocab = flat.shape[-1]
    if top_k > 0 and top_k < vocab:
        topv, indices = torch.topk(flat, top_k, dim=-1)
        topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        draws = torch.multinomial(topv, num_samples=1)
        picks = torch.gather(indices, dim=-1, index=draws)
    else:
        picks = torch.multinomial(flat, num_samples=1)
    picks = picks.reshape(*probs.shape[:-1], 1)
    return picks
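Illustrative call shapes for `sample_token` (the vocabulary size is an assumption): `temp=0` takes the argmax, `temp>0` renormalizes and draws from the top-k.

import torch

logits = torch.randn(4, 1024)                     # e.g. one row per codebook
greedy = sample_token(logits, temp=0.0)           # argmax path, shape [4, 1]
drawn = sample_token(logits, temp=0.9, top_k=64)  # top-k sampling, shape [4, 1]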
runtime/script_parser.py
ADDED
@@ -0,0 +1,69 @@
from __future__ import annotations

import re
from typing import List, Optional, Sequence

from .state_machine import Entry


def parse_script(
    script: Sequence[str],
    tokenizer,
    constants,
    frame_rate: float,
) -> List[Entry]:
    entries: List[Entry] = []
    speaker_tokens = [constants.spk1, constants.spk2]
    padding_between = 1
    event_re = re.compile(r"(?:<break\s+time=\"([0-9]+(?:\.[0-9]*)?)s\"\s*/?>)|(?:\s+)")
    last_speaker_idx = [None]

    def add_entry(idx: int, word: str, *, pending: Optional[int], first_content: List[bool]):
        tokens: List[int]
        if pending is not None:
            prefix = "[S1]" if pending == constants.spk1 else "[S2]"
            tokens = tokenizer.encode(f"{prefix} {word}", add_special_tokens=False)
        else:
            tokens = tokenizer.encode(word, add_special_tokens=False)
        if first_content[0]:
            if speaker_tokens:
                speaker_idx = idx % len(speaker_tokens)
                speaker_token = speaker_tokens[speaker_idx]
                if speaker_token is not None and last_speaker_idx[0] != speaker_idx:
                    if not tokens or tokens[0] != speaker_token:
                        tokens.insert(0, speaker_token)
                    last_speaker_idx[0] = speaker_idx
            first_content[0] = False
        padding = max(0, padding_between + len(tokens) - 1)
        entries.append(Entry(tokens=tokens, text=word, padding=padding))

    for idx, line in enumerate(script):
        normalized = line.replace("’", "'").replace(":", " ")
        remaining = normalized
        first_content = [True]
        pending_speaker: Optional[int] = None
        while remaining:
            match = event_re.search(remaining)
            if match is None:
                segment = remaining
                remaining = ""
            else:
                segment = remaining[: match.start()]
                remaining = remaining[match.end() :]
            if segment:
                for raw_word in segment.split():
                    if raw_word in ("[S1]", "[S2]"):
                        pending_speaker = (
                            constants.spk1 if raw_word == "[S1]" else constants.spk2
                        )
                        continue
                    add_entry(idx, raw_word, pending=pending_speaker, first_content=first_content)
                    pending_speaker = None
            if match and match.group(1):
                seconds = float(match.group(1))
                padding = int(round(seconds * frame_rate))
                if padding > 0:
                    entries.append(Entry(tokens=[], text="", padding=padding))
            if remaining:
                continue
    return entries
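Illustrative input for `parse_script` (assumes a built RuntimeContext named `runtime`, mirroring how the context is used elsewhere in this upload; the frame rate value is an example). Speaker tags switch voices, and the `<break/>` tag becomes a silent padding Entry.

entries = parse_script(
    ['[S1] Hello there. <break time="0.5s"/> [S2] Hi!'],
    tokenizer=runtime.tokenizer,
    constants=runtime.constants,
    frame_rate=runtime.frame_rate,  # e.g. 12.5 frames/s -> a 6-frame break
)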
runtime/state_machine.py
ADDED
@@ -0,0 +1,170 @@
from __future__ import annotations

from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Iterable, List, Sequence, Tuple


@dataclass
class TokenIds:
    card: int
    new_word: int
    pad: int
    bos: int
    zero: int
    spk1: int
    spk2: int
    audio_pad: int
    audio_bos: int
    ungenerated: int = -2


@dataclass
class Entry:
    tokens: List[int]
    text: str
    padding: int = 0


@dataclass
class State:
    entries: Deque[Entry]
    padding_budget: int
    forced_padding: int
    pending_tokens: Deque[int] = field(default_factory=deque)
    lookahead_tokens: Deque[int] = field(default_factory=deque)
    end_step: int | None = None
    consumption_times: List[int] = field(default_factory=list)
    transcript: List[Tuple[str, int]] = field(default_factory=list)

    def peek_tokens(self, count: int) -> List[int]:
        """Return tokens from upcoming entries (used for second-stream lookahead)."""
        assert count > 0
        for entry in self.entries:
            if entry.tokens:
                count -= 1
                if count == 0:
                    return entry.tokens
        return []


class StateMachine:
    def __init__(
        self,
        token_ids: TokenIds,
        *,
        second_stream_ahead: int = 0,
        max_padding: int = 6,
        initial_padding: int = 0,
    ) -> None:
        self.token_ids = token_ids
        self.second_stream_ahead = second_stream_ahead
        self.max_padding = max_padding
        self.initial_padding = initial_padding

    def new_state(self, entries: Iterable[Entry]) -> State:
        return State(
            entries=deque(entries),
            padding_budget=self.initial_padding,
            forced_padding=self.initial_padding,
        )

    def process(
        self,
        step: int,
        state: State,
        token: int,
        is_forced: bool = False,
    ) -> Tuple[int, int, bool]:
        token = self._sanitize_token(token)
        token = self._enforce_token_constraints(state, token, is_forced)
        token, consumed_new_word = self._handle_new_word(step, state, token)
        output_token = self._select_output_token(state, token)
        final_main, final_second = self._maybe_multiplex_second_stream(
            state, output_token
        )
        return final_main, final_second, consumed_new_word

    def _sanitize_token(self, token: int) -> int:
        if token == 1:
            token = self.token_ids.new_word
        elif token == 0:
            token = self.token_ids.pad
        if token not in (self.token_ids.new_word, self.token_ids.pad):
            return self.token_ids.pad
        return token

    def _enforce_token_constraints(
        self, state: State, token: int, is_forced: bool
    ) -> int:
        if state.pending_tokens:
            return self.token_ids.pad
        if is_forced:
            return token
        if state.forced_padding > 0:
            if token != self.token_ids.pad:
                token = self.token_ids.pad
            return token
        if state.padding_budget <= 0 and token != self.token_ids.new_word:
            return self.token_ids.new_word
        return token

    def _handle_new_word(
        self, step: int, state: State, token: int
    ) -> Tuple[int, bool]:
        if token != self.token_ids.new_word:
            return token, False
        if state.entries:
            entry = state.entries.popleft()
            state.consumption_times.append(step)
            if entry.tokens:
                state.transcript.append((entry.text, step))
                state.pending_tokens.extend(entry.tokens)
                if self.second_stream_ahead:
                    state.lookahead_tokens.extend(
                        state.peek_tokens(self.second_stream_ahead)
                    )
                state.padding_budget = self.max_padding
            else:
                token = self.token_ids.pad
            state.forced_padding = entry.padding
            return token, True
        token = self.token_ids.pad
        if self.second_stream_ahead and state.end_step is None:
            token = self.token_ids.new_word
        if state.end_step is None:
            state.end_step = step
        return token, False

    def _select_output_token(self, state: State, token: int) -> int:
        if token == self.token_ids.pad:
            if state.padding_budget > 0:
                state.padding_budget -= 1
            if state.forced_padding > 0:
                state.forced_padding -= 1
            if state.pending_tokens:
                return state.pending_tokens.popleft()
            return self.token_ids.pad
        if token == self.token_ids.new_word:
            return self.token_ids.new_word
        if token == self.token_ids.zero:
            return token
        raise RuntimeError(f"Invalid token {token}")

    def _maybe_multiplex_second_stream(
        self, state: State, output: int
    ) -> Tuple[int, int]:
        if not self.second_stream_ahead:
            return output, output
        second = -1
        if output == self.token_ids.new_word:
            second = self.token_ids.new_word
            if state.pending_tokens:
                output = state.pending_tokens.popleft()
            else:
                output = self.token_ids.pad
        elif state.lookahead_tokens:
            second = state.lookahead_tokens.popleft()
        else:
            second = self.token_ids.pad
        return output, second
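A driving sketch for `StateMachine` (the token ids below are made up for illustration; real values come from the tokenizer and runtime constants):

ids = TokenIds(card=8, new_word=3, pad=4, bos=1, zero=0,
               spk1=5, spk2=6, audio_pad=2048, audio_bos=2049)
sm = StateMachine(ids, second_stream_ahead=2, max_padding=6)
state = sm.new_state([Entry(tokens=[101, 102], text="hello")])
# Feed the model's text-stream prediction for each frame; receive the tokens
# to commit to the main and lookahead text streams.
main, second, consumed = sm.process(step=0, state=state, token=ids.new_word)
assert consumed and main == 101 and second == ids.new_word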
runtime/voice_clone.py
ADDED
@@ -0,0 +1,190 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, TYPE_CHECKING

import numpy as np
import torch

from ..generation import PrefixConfig
from .audio_io import encode_audio_tokens, load_mono_audio
from .state_machine import Entry

if TYPE_CHECKING:  # pragma: no cover
    from .context import RuntimeContext


@dataclass
class WhisperWord:
    text: str
    start: float
    end: float


@dataclass
class PrefixPlan:
    entries: List[Entry]
    new_word_steps: List[int]
    aligned_tokens: torch.Tensor
    aligned_frames: int


def build_prefix_plan(
    runtime: "RuntimeContext",
    prefix: Optional[PrefixConfig],
    *,
    transcribe_fn: Optional[Callable[[str, torch.device], List[WhisperWord]]] = None,
    load_audio_fn: Optional[Callable[[str, int], np.ndarray]] = None,
    encode_fn: Optional[Callable[[np.ndarray], torch.Tensor]] = None,
) -> Optional[PrefixPlan]:
    if prefix is None:
        return None
    if not prefix.speaker_1:
        if prefix.speaker_2:
            raise ValueError("speaker_2 requires speaker_1 to be provided")
        return None

    transcribe = transcribe_fn or (lambda path, device: transcribe_words(path, device))
    load_audio = load_audio_fn or (lambda path, sr: load_mono_audio(path, sr))
    encode_audio = encode_fn or (lambda audio: encode_audio_tokens(runtime.mimi, audio))

    entries1, steps1, tokens1 = _process_prefix_audio(
        runtime=runtime,
        audio_path=prefix.speaker_1,
        speaker_token=runtime.constants.spk1,
        transcribe=transcribe,
        load_audio=load_audio,
        encode_audio=encode_audio,
    )
    offset = 3  # Match legacy BOS/PAD offset
    entries = list(entries1)
    new_word_steps = [step + offset for step in steps1]
    audio_tokens = tokens1.to(runtime.device)

    if prefix.speaker_2:
        entries2, steps2, tokens2 = _process_prefix_audio(
            runtime=runtime,
            audio_path=prefix.speaker_2,
            speaker_token=runtime.constants.spk2,
            transcribe=transcribe,
            load_audio=load_audio,
            encode_audio=encode_audio,
        )
        spk1_frames = audio_tokens.shape[-1]
        new_word_steps.extend(step + spk1_frames for step in steps2)
        entries.extend(entries2)
        audio_tokens = torch.cat([audio_tokens, tokens2.to(runtime.device)], dim=1)

    return PrefixPlan(
        entries=entries,
        new_word_steps=new_word_steps,
        aligned_tokens=audio_tokens,
        aligned_frames=audio_tokens.shape[-1],
    )


def _process_prefix_audio(
    runtime: "RuntimeContext",
    audio_path: str,
    speaker_token: int,
    *,
    transcribe: Callable[[str, torch.device], List[WhisperWord]],
    load_audio: Callable[[str, int], np.ndarray],
    encode_audio: Callable[[np.ndarray], torch.Tensor],
) -> tuple[List[Entry], List[int], torch.Tensor]:
    words = transcribe(audio_path, runtime.device)
    entries, steps = words_to_entries(
        words=words,
        tokenizer=runtime.tokenizer,
        speaker_token=speaker_token,
        frame_rate=runtime.frame_rate,
    )
    audio = load_audio(audio_path, runtime.mimi.sample_rate)
    tokens = encode_audio(audio)
    return entries, steps, tokens


def transcribe_words(
    audio_path: str,
    device: torch.device,
    language: Optional[str] = None,
) -> List[WhisperWord]:
    import whisper_timestamped as wts  # Imported lazily

    model = wts.load_model("openai/whisper-large-v3", device=str(device))
    result = wts.transcribe(model, audio_path, language=language)

    words: List[WhisperWord] = []
    for segment in result.get("segments", []):
        for word in segment.get("words", []):
            text = (word.get("text") or word.get("word") or "").strip()
            if not text:
                continue
            words.append(
                WhisperWord(
                    text=text,
                    start=float(word.get("start", 0.0)),
                    end=float(word.get("end", 0.0)),
                )
            )
    return words


def words_to_entries(
    *,
    words: Sequence[WhisperWord],
    tokenizer,
    speaker_token: int,
    frame_rate: float,
) -> tuple[List[Entry], List[int]]:
    entries: List[Entry] = []
    new_word_steps: List[int] = []
    if not words:
        return entries, new_word_steps

    convert = getattr(tokenizer, "convert_tokens_to_ids", None)
    speaker_prefix: Optional[str] = None
    if callable(convert):
        s1_id = convert("[S1]")
        s2_id = convert("[S2]")
        if speaker_token == s1_id:
            speaker_prefix = "[S1]"
        elif speaker_token == s2_id:
            speaker_prefix = "[S2]"
    pending_prefix: Optional[str] = speaker_prefix
    current_pos = 0

    for idx, word in enumerate(words):
        tokens = _encode_word(word.text, tokenizer, pending_prefix)
        pending_prefix = None
        start_frame = max(current_pos + 1, int(round(word.start * frame_rate)))
        end_frame = start_frame + len(tokens)
        new_word_steps.append(start_frame - 1)

        if idx < len(words) - 1:
            next_start = int(round(words[idx + 1].start * frame_rate))
            next_word_start = max(end_frame + 1, next_start)
        else:
            end_time = int(round(words[-1].end * frame_rate))
            next_word_start = max(end_frame + 1, end_time)

        padding = max(0, next_word_start - start_frame - 1)
        entries.append(Entry(tokens=tokens, text=word.text, padding=padding))
        current_pos = end_frame

    return entries, new_word_steps


def _encode_word(text: str, tokenizer, prefix: Optional[str]) -> List[int]:
    if prefix:
        return tokenizer.encode(f"{prefix} {text}", add_special_tokens=False)
    return tokenizer.encode(text, add_special_tokens=False)


__all__ = [
    "PrefixPlan",
    "WhisperWord",
    "build_prefix_plan",
    "transcribe_words",
    "words_to_entries",
]
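An end-to-end sketch for voice cloning (illustrative: `runtime` is a built RuntimeContext, the wav paths are placeholders, and the PrefixConfig keyword names are assumed from the attribute accesses above):

plan = build_prefix_plan(
    runtime,
    PrefixConfig(speaker_1="ref_speaker1.wav", speaker_2="ref_speaker2.wav"),
)
if plan is not None:
    # aligned_tokens prime the audio stream; new_word_steps mark forced
    # word boundaries inside the prefix.
    print(plan.aligned_frames, len(plan.entries))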