Spaces:
Running
Running
tianfengping.tfp
commited on
Commit
·
7ae3e9e
1
Parent(s):
5edc5bc
move model download
Browse files
app.py
CHANGED
|
@@ -140,6 +140,11 @@ os.makedirs("./tmp", exist_ok=True)
|
|
| 140 |
|
| 141 |
def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
|
| 142 |
# import pdb;pdb.set_trace()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
if not ref_audio and not ref_text:
|
| 144 |
ref_text = text_prompt.get(speaker, "")
|
| 145 |
speaker_audio_name = audio_prompt.get(speaker)
|
|
@@ -183,7 +188,7 @@ def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_t
|
|
| 183 |
else:
|
| 184 |
emotion_info = torch.load("./emotion_info.pt")["male005"][key]
|
| 185 |
|
| 186 |
-
sample_rate, full_audio =
|
| 187 |
tts_text,
|
| 188 |
prompt_text = ref_text,
|
| 189 |
# speaker=speaker,
|
|
@@ -210,6 +215,10 @@ def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_t
|
|
| 210 |
|
| 211 |
def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
|
| 212 |
# import pdb;pdb.set_trace()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
if not ref_audio and not ref_text:
|
| 214 |
ref_text = text_prompt.get(speaker, "")
|
| 215 |
speaker_audio_name = audio_prompt.get(speaker)
|
|
@@ -252,7 +261,7 @@ def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
|
|
| 252 |
else:
|
| 253 |
emotion_info = torch.load("./emotion_info.pt")["male005"][key]
|
| 254 |
|
| 255 |
-
sample_rate, full_audio =
|
| 256 |
tts_text,
|
| 257 |
prompt_text = ref_text,
|
| 258 |
# speaker=speaker,
|
|
@@ -780,10 +789,21 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
| 780 |
outputs=tts_v2_output
|
| 781 |
)
|
| 782 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
if __name__ == "__main__":
|
| 784 |
demo.launch(
|
| 785 |
server_name="0.0.0.0",
|
| 786 |
server_port=10163,
|
| 787 |
-
share=
|
| 788 |
favicon_path=logo_path2
|
| 789 |
)
|
|
|
|
| 140 |
|
| 141 |
def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
|
| 142 |
# import pdb;pdb.set_trace()
|
| 143 |
+
global tts_speakerminus_global
|
| 144 |
+
if 'tts_speakerminus_global' not in globals():
|
| 145 |
+
print("Loading CosyVoice (speakerminus) model...")
|
| 146 |
+
tts_speakerminus_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path)
|
| 147 |
+
|
| 148 |
if not ref_audio and not ref_text:
|
| 149 |
ref_text = text_prompt.get(speaker, "")
|
| 150 |
speaker_audio_name = audio_prompt.get(speaker)
|
|
|
|
| 188 |
else:
|
| 189 |
emotion_info = torch.load("./emotion_info.pt")["male005"][key]
|
| 190 |
|
| 191 |
+
sample_rate, full_audio = inference_zero_shot.inference_zero_shot(
|
| 192 |
tts_text,
|
| 193 |
prompt_text = ref_text,
|
| 194 |
# speaker=speaker,
|
|
|
|
| 215 |
|
| 216 |
def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
|
| 217 |
# import pdb;pdb.set_trace()
|
| 218 |
+
global tts_sft_global
|
| 219 |
+
if 'tts_sft_global' not in globals():
|
| 220 |
+
print("Loading CosyVoice (SFT enhanced) model...")
|
| 221 |
+
tts_sft_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path_enhenced)
|
| 222 |
if not ref_audio and not ref_text:
|
| 223 |
ref_text = text_prompt.get(speaker, "")
|
| 224 |
speaker_audio_name = audio_prompt.get(speaker)
|
|
|
|
| 261 |
else:
|
| 262 |
emotion_info = torch.load("./emotion_info.pt")["male005"][key]
|
| 263 |
|
| 264 |
+
sample_rate, full_audio = tts_sft_global.inference_zero_shot(
|
| 265 |
tts_text,
|
| 266 |
prompt_text = ref_text,
|
| 267 |
# speaker=speaker,
|
|
|
|
| 789 |
outputs=tts_v2_output
|
| 790 |
)
|
| 791 |
|
| 792 |
+
def preload_models():
|
| 793 |
+
"""Pre-download models to cache (non-blocking for launch)"""
|
| 794 |
+
import threading
|
| 795 |
+
def _download():
|
| 796 |
+
print("Pre-downloading models to cache...")
|
| 797 |
+
snapshot_download(repo_id=model_repo_id, repo_type="model")
|
| 798 |
+
print("Model pre-download completed.")
|
| 799 |
+
threading.Thread(target=_download, daemon=True).start()
|
| 800 |
+
|
| 801 |
+
preload_models()
|
| 802 |
+
|
| 803 |
if __name__ == "__main__":
|
| 804 |
demo.launch(
|
| 805 |
server_name="0.0.0.0",
|
| 806 |
server_port=10163,
|
| 807 |
+
share=False,
|
| 808 |
favicon_path=logo_path2
|
| 809 |
)
|