"""Bootstrap the VibeVoice Gradio demo for a ZeroGPU-style environment.

Steps, in order:

1. Clone the VibeVoice repository (skipped when the directory exists).
2. ``chdir`` into it and install the package in editable mode.
3. Patch ``demo/gradio_demo.py`` so the model is no longer loaded in
   ``__init__`` but lazily inside ``generate_podcast_streaming``, which is
   decorated with ``@spaces.GPU``.
4. Launch the demo with the same interpreter that ran the install.

The patch step is idempotent, so the script can be re-run against an
already-patched checkout.
"""

import os
import re
import subprocess
import sys
import textwrap
import traceback
from pathlib import Path

REPO_URL = "https://github.com/vibevoice-community/VibeVoice.git"
repo_dir = "VibeVoice"
demo_script_path = Path("demo/gradio_demo.py")
model_id = "microsoft/VibeVoice-1.5B"

# Marker comments written into the demo script; also used to detect a
# previous run so that patching is not attempted twice.
_DEFER_MARKER = "# Patched: Defer model loading"
_LAZY_MARKER = "# Patched: Lazy-load model and processor on the GPU worker"

# A line consisting solely of ``self.load_model()`` at any indentation.
_INIT_CALL_RE = re.compile(
    r"^(?P<indent>[ \t]*)self\.load_model\(\)[ \t]*$",
    re.MULTILINE,
)

# The (multi-line) signature of ``generate_podcast_streaming`` immediately
# followed by its opening ``try:``. Whitespace is matched loosely so that
# alignment differences in the upstream file do not break the patch.
_METHOD_HEADER_RE = re.compile(
    r"^(?P<indent>[ \t]*)"
    r"(?P<signature>def generate_podcast_streaming\(self,[^)]*\)"
    r"\s*->\s*Iterator\[tuple\]:)"
    r"[ \t]*\n(?:[ \t]*\n)*"
    r"(?P<body_indent>[ \t]*)try:",
    re.MULTILINE,
)

# Code inserted at the top of the method body (relative indentation only;
# the method's own body indent is applied when the patch is made).
_LAZY_LOAD_BLOCK = _LAZY_MARKER + "\n" + """\
if self.model is None or self.processor is None:
    print("Loading processor & model for the first time on GPU worker...")
    self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
    self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        self.model_path,
        torch_dtype=torch.bfloat16,  # Use 16-bit precision for quality
        device_map="auto",
    )
    self.model.eval()
    self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
        self.model.model.noise_scheduler.config,
        algorithm_type='sde-dpmsolver++',
        beta_schedule='squaredcos_cap_v2'
    )
    self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
    print("Model and processor loaded successfully on GPU worker.")
"""


class PatchError(ValueError):
    """Raised when an expected code snippet is missing from the demo script."""


def _run_or_exit(command, failure_prefix):
    """Run *command* quietly; on failure print its stderr and exit with 1."""
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as exc:
        print(f"{failure_prefix}: {exc.stderr}")
        sys.exit(1)
    except OSError as exc:
        # E.g. the executable (git) is not installed / not on PATH.
        print(f"{failure_prefix}: {exc}")
        sys.exit(1)


def clone_repository():
    """Clone the VibeVoice repository into ``repo_dir`` unless it exists."""
    if os.path.exists(repo_dir):
        print("Repository already exists. Skipping clone.")
        return
    print("Cloning the VibeVoice repository...")
    # Pass the target directory explicitly so it always matches repo_dir.
    _run_or_exit(["git", "clone", REPO_URL, repo_dir], "Error cloning repository")
    print("Repository cloned successfully.")


def install_package():
    """Install the package in the current directory in editable mode."""
    print("Installing the VibeVoice package in editable mode...")
    _run_or_exit(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        "Error installing package",
    )
    print("Package installed successfully.")


def _deferred_init_replacement(match):
    """Build the replacement for the ``self.load_model()`` line."""
    indent = match["indent"]
    return (
        f"{indent}# self.load_model()  {_DEFER_MARKER}\n"
        f"{indent}self.model = None\n"
        f"{indent}self.processor = None"
    )


def _lazy_header_replacement(match):
    """Build the decorated method header followed by the lazy-load block."""
    indent = match["indent"]
    body_indent = match["body_indent"]
    return (
        f"{indent}@spaces.GPU(duration=120)\n"
        f"{indent}{match['signature']}\n"
        f"{textwrap.indent(_LAZY_LOAD_BLOCK, body_indent)}"
        "\n"
        f"{body_indent}try:"
    )


def patch_demo_source(source):
    """Return *source* patched for ZeroGPU lazy model loading.

    Three changes are applied, each skipped if already present:

    * ``import spaces`` is prepended.
    * The first bare ``self.load_model()`` line is commented out and
      replaced by ``self.model = None`` / ``self.processor = None``.
    * ``generate_podcast_streaming`` gets an ``@spaces.GPU(duration=120)``
      decorator and a block that loads the processor and model on first use.

    Args:
        source: Full text of the demo script.

    Returns:
        The patched text (identical to *source* when already patched).

    Raises:
        PatchError: If a snippet to patch cannot be found and no marker of
            a previous patch is present.
    """
    patched = source

    if "import spaces" not in patched:
        patched = "import spaces\n" + patched

    # Check the marker first so a re-run never patches (or fails) twice.
    if _DEFER_MARKER in patched:
        print("__init__ is already patched. Skipping.")
    else:
        patched, replaced = _INIT_CALL_RE.subn(
            _deferred_init_replacement, patched, count=1
        )
        if not replaced:
            raise PatchError(
                "Could not find 'self.load_model()' to patch. "
                "Startup patch failed."
            )
        print("Successfully patched __init__ to prevent startup model load.")

    if _LAZY_MARKER in patched:
        print("Generation method is already patched. Skipping.")
    else:
        patched, replaced = _METHOD_HEADER_RE.subn(
            _lazy_header_replacement, patched, count=1
        )
        if not replaced:
            raise PatchError(
                "Could not find the method definition for "
                "'generate_podcast_streaming' to patch. The signature may "
                "have changed upstream. Please check the demo script."
            )
        print("Successfully patched generation method for lazy loading.")

    return patched


def patch_demo_script(script_path):
    """Patch the demo script at *script_path* in place; exit with 1 on failure.

    The file is only rewritten when every patch step succeeded, so a failed
    run never leaves a half-patched script behind.
    """
    print(f"Refactoring {script_path} for ZeroGPU lazy loading...")
    try:
        # Explicit UTF-8: do not depend on the platform's locale encoding.
        original = script_path.read_text(encoding="utf-8")
        patched = patch_demo_source(original)
        if patched != original:
            script_path.write_text(patched, encoding="utf-8")
        print("Script patching complete.")
    except PatchError as exc:
        print(f"\033[91mError: {exc}\033[0m")
        sys.exit(1)
    except (OSError, UnicodeError) as exc:
        print(f"An error occurred while modifying the script: {exc}")
        traceback.print_exc()
        sys.exit(1)


def launch_demo():
    """Run the Gradio demo and return its exit code.

    The demo is started with ``sys.executable`` so it runs under the same
    interpreter (and environment) the package was just installed into.
    """
    command = [
        sys.executable,
        str(demo_script_path),
        "--model_path",
        model_id,
        "--share",
    ]
    print(f"Launching Gradio demo with command: {' '.join(command)}")
    return subprocess.run(command).returncode


def main():
    """Clone, install, patch and launch the VibeVoice demo."""
    clone_repository()

    os.chdir(repo_dir)
    print(f"Changed directory to: {os.getcwd()}")

    install_package()
    patch_demo_script(demo_script_path)

    exit_code = launch_demo()
    if exit_code:
        # Surface a failed demo run to the caller instead of hiding it.
        sys.exit(exit_code)


if __name__ == "__main__":
    main()