tianfengping.tfp commited on
Commit
7ae3e9e
·
1 Parent(s): 5edc5bc

move model download

Browse files
Files changed (1) hide show
  1. app.py +23 -3
app.py CHANGED
@@ -140,6 +140,11 @@ os.makedirs("./tmp", exist_ok=True)
140
 
141
  def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
142
  # import pdb;pdb.set_trace()
 
 
 
 
 
143
  if not ref_audio and not ref_text:
144
  ref_text = text_prompt.get(speaker, "")
145
  speaker_audio_name = audio_prompt.get(speaker)
@@ -183,7 +188,7 @@ def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_t
183
  else:
184
  emotion_info = torch.load("./emotion_info.pt")["male005"][key]
185
 
186
- sample_rate, full_audio = tts_sft.inference_zero_shot(
187
  tts_text,
188
  prompt_text = ref_text,
189
  # speaker=speaker,
@@ -210,6 +215,10 @@ def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_t
210
 
211
  def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
212
  # import pdb;pdb.set_trace()
 
 
 
 
213
  if not ref_audio and not ref_text:
214
  ref_text = text_prompt.get(speaker, "")
215
  speaker_audio_name = audio_prompt.get(speaker)
@@ -252,7 +261,7 @@ def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
252
  else:
253
  emotion_info = torch.load("./emotion_info.pt")["male005"][key]
254
 
255
- sample_rate, full_audio = tts_sft.inference_zero_shot(
256
  tts_text,
257
  prompt_text = ref_text,
258
  # speaker=speaker,
@@ -780,10 +789,21 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
780
  outputs=tts_v2_output
781
  )
782
 
 
 
 
 
 
 
 
 
 
 
 
783
  if __name__ == "__main__":
784
  demo.launch(
785
  server_name="0.0.0.0",
786
  server_port=10163,
787
- share=True,
788
  favicon_path=logo_path2
789
  )
 
140
 
141
  def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
142
  # import pdb;pdb.set_trace()
143
+ global tts_speakerminus_global
144
+ if 'tts_speakerminus_global' not in globals():
145
+ print("Loading CosyVoice (speakerminus) model...")
146
+ tts_speakerminus_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path)
147
+
148
  if not ref_audio and not ref_text:
149
  ref_text = text_prompt.get(speaker, "")
150
  speaker_audio_name = audio_prompt.get(speaker)
 
188
  else:
189
  emotion_info = torch.load("./emotion_info.pt")["male005"][key]
190
 
191
+ sample_rate, full_audio = inference_zero_shot.inference_zero_shot(
192
  tts_text,
193
  prompt_text = ref_text,
194
  # speaker=speaker,
 
215
 
216
  def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
217
  # import pdb;pdb.set_trace()
218
+ global tts_sft_global
219
+ if 'tts_sft_global' not in globals():
220
+ print("Loading CosyVoice (SFT enhanced) model...")
221
+ tts_sft_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path_enhenced)
222
  if not ref_audio and not ref_text:
223
  ref_text = text_prompt.get(speaker, "")
224
  speaker_audio_name = audio_prompt.get(speaker)
 
261
  else:
262
  emotion_info = torch.load("./emotion_info.pt")["male005"][key]
263
 
264
+ sample_rate, full_audio = tts_sft_global.inference_zero_shot(
265
  tts_text,
266
  prompt_text = ref_text,
267
  # speaker=speaker,
 
789
  outputs=tts_v2_output
790
  )
791
 
792
+ def preload_models():
793
+ """Pre-download models to cache (non-blocking for launch)"""
794
+ import threading
795
+ def _download():
796
+ print("Pre-downloading models to cache...")
797
+ snapshot_download(repo_id=model_repo_id, repo_type="model")
798
+ print("Model pre-download completed.")
799
+ threading.Thread(target=_download, daemon=True).start()
800
+
801
+ preload_models()
802
+
803
  if __name__ == "__main__":
804
  demo.launch(
805
  server_name="0.0.0.0",
806
  server_port=10163,
807
+ share=False,
808
  favicon_path=logo_path2
809
  )