+
+## Acknowledgements
+
+1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
+2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
+3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
+4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
+5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
+
+## Citations
+
+``` bibtex
+@article{du2024cosyvoice,
+ title={CosyVoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
+ author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
+ journal={arXiv preprint arXiv:2407.05407},
+ year={2024}
+}
+
+@article{du2024cosyvoice2,
+ title={CosyVoice 2: Scalable streaming speech synthesis with large language models},
+ author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
+ journal={arXiv preprint arXiv:2412.10117},
+ year={2024}
+}
+
+@article{du2025cosyvoice,
+ title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
+ author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
+ journal={arXiv preprint arXiv:2505.17589},
+ year={2025}
+}
+
+@inproceedings{lyu2025build,
+ title={Build LLM-Based Zero-Shot Streaming TTS System with CosyVoice},
+ author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
+ booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+ pages={1--2},
+ year={2025},
+ organization={IEEE}
+}
+```
+
+## Disclaimer
+The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
diff --git a/CosyVoice/asset/cross_lingual_prompt.wav b/CosyVoice/asset/cross_lingual_prompt.wav
new file mode 100644
index 0000000000000000000000000000000000000000..28780d144d635dcad6b2c64b98e09819b252e9c2
--- /dev/null
+++ b/CosyVoice/asset/cross_lingual_prompt.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:353a7715c2e4811f4045658b29d1ce67ecad5120e09de10ce890f1763aab486c
+size 606404
diff --git a/CosyVoice/asset/zero_shot_prompt.wav b/CosyVoice/asset/zero_shot_prompt.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e8746429bce4bd98c864bd0e166e64f3600ebd58
--- /dev/null
+++ b/CosyVoice/asset/zero_shot_prompt.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd199eb7109fd6ce9943cb297e3cf350c1073af014063dfadbdc100230526243
+size 111496
diff --git a/CosyVoice/cosyvoice/__init__.py b/CosyVoice/cosyvoice/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/CosyVoice/cosyvoice/bin/average_model.py b/CosyVoice/cosyvoice/bin/average_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7140c12ef4c0868019b62f2ebe32df99687f66e
--- /dev/null
+++ b/CosyVoice/cosyvoice/bin/average_model.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2020 Mobvoi Inc (Di Wu)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import glob
+
+import yaml
+import torch
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='average model')
+ parser.add_argument('--dst_model', required=True, help='averaged model')
+ parser.add_argument('--src_path',
+ required=True,
+ help='src model path for average')
+ parser.add_argument('--val_best',
+ action="store_true",
+ help='select the checkpoints with the best validation loss')
+ parser.add_argument('--num',
+ default=5,
+ type=int,
+ help='number of checkpoints to average')
+
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ val_scores = []
+ if args.val_best:
+ yamls = glob.glob('{}/*.yaml'.format(args.src_path))
+ yamls = [
+ f for f in yamls
+ if not (os.path.basename(f).startswith('train')
+ or os.path.basename(f).startswith('init'))
+ ]
+ for y in yamls:
+ with open(y, 'r') as f:
+ dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
+ loss = float(dic_yaml['loss_dict']['loss'])
+ epoch = int(dic_yaml['epoch'])
+ step = int(dic_yaml['step'])
+ tag = dic_yaml['tag']
+ val_scores += [[epoch, step, loss, tag]]
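+ # keep the args.num checkpoints with the lowest validation loss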
+ sorted_val_scores = sorted(val_scores,
+ key=lambda x: x[2],
+ reverse=False)
+ print("best val (epoch, step, loss, tag) = " +
+ str(sorted_val_scores[:args.num]))
+ path_list = [
+ args.src_path + '/epoch_{}_whole.pt'.format(score[0])
+ for score in sorted_val_scores[:args.num]
+ ]
+ print(path_list)
+ avg = {}
+ num = args.num
+ assert num == len(path_list)
+ for path in path_list:
+ print('Processing {}'.format(path))
+ states = torch.load(path, map_location=torch.device('cpu'))
+ for k in states.keys():
+ if k not in ['step', 'epoch']:
+ if k not in avg.keys():
+ avg[k] = states[k].clone()
+ else:
+ avg[k] += states[k]
+ # average
+ for k in avg.keys():
+ if avg[k] is not None:
+ # pytorch 1.6 use true_divide instead of /=
+ avg[k] = torch.true_divide(avg[k], num)
+ print('Saving to {}'.format(args.dst_model))
+ torch.save(avg, args.dst_model)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/CosyVoice/cosyvoice/bin/export_jit.py b/CosyVoice/cosyvoice/bin/export_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eedc1a031eb0e6513b14d950fbb8871356af8ac
--- /dev/null
+++ b/CosyVoice/cosyvoice/bin/export_jit.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import torch
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/CosyVoice-300M',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def get_optimized_script(model, preserved_attrs=[]):
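+ # script the module, freeze it (preserving any attrs/methods that must remain callable, e.g. forward_chunk), then apply inference optimizations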
+ script = torch.jit.script(model)
+ if preserved_attrs != []:
+ script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+ else:
+ script = torch.jit.freeze(script)
+ script = torch.jit.optimize_for_inference(script)
+ return script
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
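+ # use a static fusion strategy and disable the profiling executor before scripting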
+ torch._C._jit_set_fusion_strategy([('STATIC', 1)])
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_set_profiling_executor(False)
+
+ try:
+ model = CosyVoice(args.model_dir)
+ except Exception:
+ try:
+ model = CosyVoice2(args.model_dir)
+ except Exception:
+ raise TypeError('no valid model_type!')
+
+ if not isinstance(model, CosyVoice2):
+ # 1. export llm text_encoder
+ llm_text_encoder = model.model.llm.text_encoder
+ script = get_optimized_script(llm_text_encoder)
+ script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
+ script = get_optimized_script(llm_text_encoder.half())
+ script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+ logging.info('successfully export llm_text_encoder')
+
+ # 2. export llm llm
+ llm_llm = model.model.llm.llm
+ script = get_optimized_script(llm_llm, ['forward_chunk'])
+ script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
+ script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
+ script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+ logging.info('successfully export llm_llm')
+
+ # 3. export flow encoder
+ flow_encoder = model.model.flow.encoder
+ script = get_optimized_script(flow_encoder)
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+ script = get_optimized_script(flow_encoder.half())
+ script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+ logging.info('successfully export flow_encoder')
+ else:
+ # 3. export flow encoder
+ flow_encoder = model.model.flow.encoder
+ script = get_optimized_script(flow_encoder)
+ script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+ script = get_optimized_script(flow_encoder.half())
+ script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+ logging.info('successfully export flow_encoder')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/CosyVoice/cosyvoice/bin/export_onnx.py b/CosyVoice/cosyvoice/bin/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd9f00974d5c45e866f2172a5b14b6181d8ed461
--- /dev/null
+++ b/CosyVoice/cosyvoice/bin/export_onnx.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
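+ # dummy inputs: x/mu/cond are (batch, channels, time), mask is (batch, 1, time), t is (batch,), spks is (batch, channels)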
+ x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+ mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ t = torch.rand((batch_size), dtype=torch.float32, device=device)
+ spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+ cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+ return x, mask, mu, t, spks, cond
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='export your model for deployment')
+ parser.add_argument('--model_dir',
+ type=str,
+ default='pretrained_models/CosyVoice-300M',
+ help='local path')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+@torch.no_grad()
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ try:
+ model = CosyVoice(args.model_dir)
+ except Exception:
+ try:
+ model = CosyVoice2(args.model_dir)
+ except Exception:
+ raise TypeError('no valid model_type!')
+
+ # 1. export flow decoder estimator
+ estimator = model.model.flow.decoder.estimator
+ estimator.eval()
+
+ device = model.model.device
+ batch_size, seq_len = 2, 256
+ out_channels = model.model.flow.decoder.estimator.out_channels
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
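+ # only the time axis is dynamic; t and spks have no time dimension and keep static shapes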
+ torch.onnx.export(
+ estimator,
+ (x, mask, mu, t, spks, cond),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+ output_names=['estimator_out'],
+ dynamic_axes={
+ 'x': {2: 'seq_len'},
+ 'mask': {2: 'seq_len'},
+ 'mu': {2: 'seq_len'},
+ 'cond': {2: 'seq_len'},
+ 'estimator_out': {2: 'seq_len'},
+ }
+ )
+
+ # 2. test computation consistency
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ sess_options=option, providers=providers)
+
+ for _ in tqdm(range(10)):
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
+ ort_inputs = {
+ 'x': x.cpu().numpy(),
+ 'mask': mask.cpu().numpy(),
+ 'mu': mu.cpu().numpy(),
+ 't': t.cpu().numpy(),
+ 'spks': spks.cpu().numpy(),
+ 'cond': cond.cpu().numpy()
+ }
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+ logging.info('successfully export estimator')
+
+
+if __name__ == "__main__":
+ main()
diff --git a/CosyVoice/cosyvoice/bin/inference_deprecated.py b/CosyVoice/cosyvoice/bin/inference_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d45cc784ef68b619cdbd66c42e29ecc18a3e14e
--- /dev/null
+++ b/CosyVoice/cosyvoice/bin/inference_deprecated.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.dataset.dataset import Dataset
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='inference with your model')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt utt2data file')
+ parser.add_argument('--tts_text', required=True, help='tts input file')
+ parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+ parser.add_argument('--llm_model', required=True, help='llm model file')
+ parser.add_argument('--flow_model', required=True, help='flow model file')
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+ parser.add_argument('--gpu',
+ type=int,
+ default=-1,
+ help='gpu id for this rank, -1 for cpu')
+ parser.add_argument('--mode',
+ default='sft',
+ choices=['sft', 'zero_shot'],
+ help='inference mode')
+ parser.add_argument('--result_dir', required=True, help='directory to save generated wavs and wav.scp')
+ args = parser.parse_args()
+ print(args)
+ return args
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+ # Init cosyvoice models from configs
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+ device = torch.device('cuda' if use_cuda else 'cpu')
+ try:
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
+ model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
+ except Exception:
+ try:
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f)
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+ except Exception:
+ raise TypeError('no valid model_type!')
+
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+ tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+ sample_rate = configs['sample_rate']
+ del configs
+ os.makedirs(args.result_dir, exist_ok=True)
+ fn = os.path.join(args.result_dir, 'wav.scp')
+ f = open(fn, 'w')
+ with torch.no_grad():
+ for _, batch in tqdm(enumerate(test_data_loader)):
+ utts = batch["utts"]
+ assert len(utts) == 1, "inference mode only supports batch size 1"
+ text_token = batch["text_token"].to(device)
+ text_token_len = batch["text_token_len"].to(device)
+ tts_index = batch["tts_index"]
+ tts_text_token = batch["tts_text_token"].to(device)
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
+ speech_token = batch["speech_token"].to(device)
+ speech_token_len = batch["speech_token_len"].to(device)
+ speech_feat = batch["speech_feat"].to(device)
+ speech_feat_len = batch["speech_feat_len"].to(device)
+ utt_embedding = batch["utt_embedding"].to(device)
+ spk_embedding = batch["spk_embedding"].to(device)
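+ # sft conditions on the speaker embedding only; zero_shot also passes the prompt text, prompt speech tokens and prompt features with the utterance embedding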
+ if args.mode == 'sft':
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+ else:
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+ tts_speeches = []
+ for model_output in model.tts(**model_input):
+ tts_speeches.append(model_output['tts_speech'])
+ tts_speeches = torch.concat(tts_speeches, dim=1)
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
+ f.write('{} {}\n'.format(tts_key, tts_fn))
+ f.flush()
+ f.close()
+ logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+ logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
+ main()
diff --git a/CosyVoice/cosyvoice/bin/train.py b/CosyVoice/cosyvoice/bin/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e4016f5688507935ff5bd2a60dcd2a6a5295410
--- /dev/null
+++ b/CosyVoice/cosyvoice/bin/train.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import argparse
+import datetime
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from copy import deepcopy
+import os
+import torch
+import torch.distributed as dist
+import deepspeed
+
+from hyperpyyaml import load_hyperpyyaml
+
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from cosyvoice.utils.losses import DPOLoss
+from cosyvoice.utils.executor import Executor
+from cosyvoice.utils.train_utils import (
+ init_distributed,
+ init_dataset_and_dataloader,
+ init_optimizer_and_scheduler,
+ init_summarywriter, save_model,
+ wrap_cuda_model, check_modify_and_save_config)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description='training your network')
+ parser.add_argument('--train_engine',
+ default='torch_ddp',
+ choices=['torch_ddp', 'deepspeed'],
+ help='Engine for parallel training')
+ parser.add_argument('--model', required=True, help='model which will be trained')
+ parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument('--train_data', required=True, help='train data file')
+ parser.add_argument('--cv_data', required=True, help='cv data file')
+ parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
+ parser.add_argument('--checkpoint', help='checkpoint model')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--tensorboard_dir',
+ default='tensorboard',
+ help='tensorboard log dir')
+ parser.add_argument('--ddp.dist_backend',
+ dest='dist_backend',
+ default='nccl',
+ choices=['nccl', 'gloo'],
+ help='distributed backend')
+ parser.add_argument('--num_workers',
+ default=0,
+ type=int,
+ help='num of subprocess workers for reading')
+ parser.add_argument('--prefetch',
+ default=100,
+ type=int,
+ help='prefetch number')
+ parser.add_argument('--pin_memory',
+ action='store_true',
+ default=False,
+ help='Use pinned memory buffers for data loading')
+ parser.add_argument('--use_amp',
+ action='store_true',
+ default=False,
+ help='Use automatic mixed precision training')
+ parser.add_argument('--dpo',
+ action='store_true',
+ default=False,
+ help='Use Direct Preference Optimization')
+ parser.add_argument('--deepspeed.save_states',
+ dest='save_states',
+ default='model_only',
+ choices=['model_only', 'model+optimizer'],
+ help='save model/optimizer states')
+ parser.add_argument('--timeout',
+ default=60,
+ type=int,
+ help='timeout (in seconds) of cosyvoice_join.')
+ parser = deepspeed.add_config_arguments(parser)
+ args = parser.parse_args()
+ return args
+
+
+@record
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ # GAN training has some special initialization logic
+ gan = True if args.model == 'hifigan' else False
+
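+ # set config sections for models we are not training to None so hyperpyyaml skips instantiating them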
+ override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
+ if gan is True:
+ override_dict.pop('hift')
+ try:
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
+ except Exception:
+ with open(args.config, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides=override_dict)
+ if gan is True:
+ configs['train_conf'] = configs['train_conf_gan']
+ configs['train_conf'].update(vars(args))
+
+ # Init env for ddp
+ init_distributed(args)
+
+ # Get dataset & dataloader
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
+ init_dataset_and_dataloader(args, configs, gan, args.dpo)
+
+ # Do some sanity checks and save config to args.model_dir
+ configs = check_modify_and_save_config(args, configs)
+
+ # Tensorboard summary
+ writer = init_summarywriter(args)
+
+ # load checkpoint
+ if args.dpo is True:
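+ # for DPO, use the model's forward_dpo as its forward so the executor gets DPO-specific outputs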
+ configs[args.model].forward = configs[args.model].forward_dpo
+ model = configs[args.model]
+ start_step, start_epoch = 0, -1
+ if args.checkpoint is not None:
+ if os.path.exists(args.checkpoint):
+ state_dict = torch.load(args.checkpoint, map_location='cpu')
+ model.load_state_dict(state_dict, strict=False)
+ if 'step' in state_dict:
+ start_step = state_dict['step']
+ if 'epoch' in state_dict:
+ start_epoch = state_dict['epoch']
+ else:
+ logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
+
+ # Dispatch model from cpu to gpu
+ model = wrap_cuda_model(args, model)
+
+ # Get optimizer & scheduler
+ model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
+ scheduler.set_step(start_step)
+ if scheduler_d is not None:
+ scheduler_d.set_step(start_step)
+
+ # Save init checkpoints
+ info_dict = deepcopy(configs['train_conf'])
+ info_dict['step'] = start_step
+ info_dict['epoch'] = start_epoch
+ save_model(model, 'init', info_dict)
+
+ # DPO related
+ if args.dpo is True:
+ ref_model = deepcopy(configs[args.model])
+ state_dict = torch.load(args.ref_model, map_location='cpu')
+ ref_model.load_state_dict(state_dict, strict=False)
+ dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+ # NOTE wrapping ref_model in DDP may be unnecessary because its parameters are never updated
+ ref_model = wrap_cuda_model(args, ref_model)
+ else:
+ ref_model, dpo_loss = None, None
+
+ # Get executor
+ executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
+ executor.step = start_step
+
+ # Init scaler, used for pytorch amp mixed precision training
+ scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
+ print('start step {} start epoch {}'.format(start_step, start_epoch))
+
+ # Start training loop
+ for epoch in range(start_epoch + 1, info_dict['max_epoch']):
+ executor.epoch = epoch
+ train_dataset.set_epoch(epoch)
+ dist.barrier()
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
+ if gan is True:
+ executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+ writer, info_dict, scaler, group_join)
+ else:
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
+ dist.destroy_process_group(group_join)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/CosyVoice/cosyvoice/cli/__init__.py b/CosyVoice/cosyvoice/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/CosyVoice/cosyvoice/cli/cosyvoice.py b/CosyVoice/cosyvoice/cli/cosyvoice.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc443bed44c651a47492fc7e2142e3a88fb47627
--- /dev/null
+++ b/CosyVoice/cosyvoice/cli/cosyvoice.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+from typing import Generator
+from tqdm import tqdm
+from hyperpyyaml import load_hyperpyyaml
+from modelscope import snapshot_download
+import torch
+from cosyvoice.cli.frontend import CosyVoiceFrontEnd
+from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.class_utils import get_model_type
+
+
+class CosyVoice:
+
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
+ self.instruct = True if '-Instruct' in model_dir else False
+ self.model_dir = model_dir
+ self.fp16 = fp16
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
+ if not os.path.exists(hyper_yaml_path):
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
+ with open(hyper_yaml_path, 'r') as f:
+ configs = load_hyperpyyaml(f)
+ assert get_model_type(configs) != CosyVoice2Model, 'do not use {} for CosyVoice initialization!'.format(model_dir)
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ '{}/campplus.onnx'.format(model_dir),
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
+ '{}/spk2info.pt'.format(model_dir),
+ configs['allowed_special'])
+ self.sample_rate = configs['sample_rate']
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+ load_jit, load_trt, fp16 = False, False, False
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/hift.pt'.format(model_dir))
+ if load_jit:
+ self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+ '{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+ '{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+ if load_trt:
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+ trt_concurrent,
+ self.fp16)
+ del configs
+
+ def list_available_spks(self):
+ spks = list(self.frontend.spk2info.keys())
+ return spks
+
+ def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
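+ # run the zero-shot frontend on the prompt once and cache the result under zero_shot_spk_id for reuse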
+ assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
+ model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_speech_16k, self.sample_rate, '')
+ del model_input['text']
+ del model_input['text_len']
+ self.frontend.spk2info[zero_shot_spk_id] = model_input
+ return True
+
+ def save_spkinfo(self):
+ torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
+
+ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+ model_input = self.frontend.frontend_sft(i, spk_id)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+ if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
+ logging.warning('synthesis text {} is much shorter than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
+ assert isinstance(self.model, CosyVoiceModel), 'inference_instruct is only implemented for CosyVoice!'
+ if self.instruct is False:
+ raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+ def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
+ model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k, self.sample_rate)
+ start_time = time.time()
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
+
+
+class CosyVoice2(CosyVoice):
+
+ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
+ self.instruct = True if '-Instruct' in model_dir else False
+ self.model_dir = model_dir
+ self.fp16 = fp16
+ if not os.path.exists(model_dir):
+ model_dir = snapshot_download(model_dir)
+ hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
+ if not os.path.exists(hyper_yaml_path):
+ raise ValueError('{} not found!'.format(hyper_yaml_path))
+ with open(hyper_yaml_path, 'r') as f:
+ configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
+ assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
+ configs['feat_extractor'],
+ '{}/campplus.onnx'.format(model_dir),
+ '{}/speech_tokenizer_v2.onnx'.format(model_dir),
+ '{}/spk2info.pt'.format(model_dir),
+ configs['allowed_special'])
+ self.sample_rate = configs['sample_rate']
+ if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
+ load_jit, load_trt, fp16 = False, False, False
+ logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
+ self.model.load('{}/llm.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
+ '{}/hift.pt'.format(model_dir))
+ if load_vllm:
+ self.model.load_vllm('{}/vllm'.format(model_dir))
+ if load_jit:
+ self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
+ if load_trt:
+ self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+ trt_concurrent,
+ self.fp16)
+ del configs
+
+ def inference_instruct(self, *args, **kwargs):
+ raise NotImplementedError('inference_instruct is not implemented for CosyVoice2!')
+
+ def inference_instruct2(self, tts_text, instruct_text, prompt_speech_16k, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
+ assert isinstance(self.model, CosyVoice2Model), 'inference_instruct2 is only implemented for CosyVoice2!'
+ for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
+ model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id)
+ start_time = time.time()
+ logging.info('synthesis text {}'.format(i))
+ for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
+ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
+ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+ yield model_output
+ start_time = time.time()
diff --git a/CosyVoice/cosyvoice/cli/frontend.py b/CosyVoice/cosyvoice/cli/frontend.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98b0d612e244bbf58a3a1d9312857055c2133ac
--- /dev/null
+++ b/CosyVoice/cosyvoice/cli/frontend.py
@@ -0,0 +1,215 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from typing import Generator
+import json
+import onnxruntime
+import torch
+import numpy as np
+import whisper
+from typing import Callable
+import torchaudio.compliance.kaldi as kaldi
+import torchaudio
+import os
+import re
+import inflect
+try:
+ import ttsfrd
+ use_ttsfrd = True
+except ImportError:
+ print("failed to import ttsfrd, use wetext instead")
+ from wetext import Normalizer as ZhNormalizer
+ from wetext import Normalizer as EnNormalizer
+ use_ttsfrd = False
+from cosyvoice.utils.file_utils import logging
+from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
+
+
+class CosyVoiceFrontEnd:
+
+ def __init__(self,
+ get_tokenizer: Callable,
+ feat_extractor: Callable,
+ campplus_model: str,
+ speech_tokenizer_model: str,
+ spk2info: str = '',
+ allowed_special: str = 'all'):
+ self.tokenizer = get_tokenizer()
+ self.feat_extractor = feat_extractor
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+ providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+ "CPUExecutionProvider"])
+ if os.path.exists(spk2info):
+ self.spk2info = torch.load(spk2info, map_location=self.device)
+ else:
+ self.spk2info = {}
+ self.allowed_special = allowed_special
+ self.use_ttsfrd = use_ttsfrd
+ if self.use_ttsfrd:
+ self.frd = ttsfrd.TtsFrontendEngine()
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+ 'failed to initialize ttsfrd resource'
+ self.frd.set_lang_type('pinyinvg')
+ else:
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False)
+ self.en_tn_model = EnNormalizer()
+ self.inflect_parser = inflect.engine()
+
+ def _extract_text_token(self, text):
+ if isinstance(text, Generator):
+ logging.info('get tts_text generator, will return _extract_text_token_generator!')
+ # NOTE add a dummy text_token_len for compatibility
+ return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
+ else:
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
+ return text_token, text_token_len
+
+ def _extract_text_token_generator(self, text_generator):
+ for text in text_generator:
+ text_token, _ = self._extract_text_token(text)
+ for i in range(text_token.shape[1]):
+ yield text_token[:, i: i + 1]
+
+ def _extract_speech_token(self, speech):
+ assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
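+ # compute 128-bin whisper log-mel features and run the ONNX speech tokenizer to get discrete speech tokens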
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
+ speech_token = self.speech_tokenizer_session.run(None,
+ {self.speech_tokenizer_session.get_inputs()[0].name:
+ feat.detach().cpu().numpy(),
+ self.speech_tokenizer_session.get_inputs()[1].name:
+ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_token, speech_token_len
+
+ def _extract_spk_embedding(self, speech):
+ feat = kaldi.fbank(speech,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000)
+ feat = feat - feat.mean(dim=0, keepdim=True)
+ embedding = self.campplus_session.run(None,
+ {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+ embedding = torch.tensor([embedding]).to(self.device)
+ return embedding
+
+ def _extract_speech_feat(self, speech):
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
+ speech_feat = speech_feat.unsqueeze(dim=0)
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
+ return speech_feat, speech_feat_len
+
+ def text_normalize(self, text, split=True, text_frontend=True):
+ if isinstance(text, Generator):
+ logging.info('get tts_text generator, will skip text_normalize!')
+ return [text]
+ if text_frontend is False or text == '':
+ return [text] if split is True else text
+ text = text.strip()
+ if self.use_ttsfrd:
+ texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
+ text = ''.join(texts)
+ else:
+ if contains_chinese(text):
+ text = self.zh_tn_model.normalize(text)
+ text = text.replace("\n", "")
+ text = replace_blank(text)
+ text = replace_corner_mark(text)
+ text = text.replace(".", "。")
+ text = text.replace(" - ", ",")
+ text = remove_bracket(text)
+ text = re.sub(r'[,,、]+$', '。', text)
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ else:
+ text = self.en_tn_model.normalize(text)
+ text = spell_out_number(text, self.inflect_parser)
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+ token_min_n=60, merge_len=20, comma_split=False))
+ texts = [i for i in texts if not is_only_punctuation(i)]
+ return texts if split is True else text
+
+ def frontend_sft(self, tts_text, spk_id):
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ embedding = self.spk2info[spk_id]['embedding']
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
+ return model_input
+
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
+ if zero_shot_spk_id == '':
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
+ prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)(prompt_speech_16k)
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_resample)
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
+ if resample_rate == 24000:
+ # cosyvoice2: force speech_feat length to be exactly 2 * speech_token length (2 feat frames per token)
+ token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
+ speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
+ speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
+ model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
+ else:
+ model_input = self.spk2info[zero_shot_spk_id]
+ model_input['text'] = tts_text_token
+ model_input['text_len'] = tts_text_token_len
+ return model_input
+
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id):
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k, resample_rate, zero_shot_spk_id)
+ # in cross lingual mode, we remove prompt in llm
+ del model_input['prompt_text']
+ del model_input['prompt_text_len']
+ del model_input['llm_prompt_speech_token']
+ del model_input['llm_prompt_speech_token_len']
+ return model_input
+
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
+ model_input = self.frontend_sft(tts_text, spk_id)
+ # in instruct mode, we remove the spk_embedding from the llm input to avoid information leakage
+ del model_input['llm_embedding']
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '
+
+