# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchaudio
import numpy as np
import time

from .valle_ar_trainer import ValleARTrainer, make_pad_mask
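

# Trainer for the non-autoregressive (NAR) stage of VALL-E: each training step
# predicts the tokens of one residual codec quantization layer (see the
# ``target_quantization_layer`` logging in ``_train_step``).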
class ValleNARTrainer(ValleARTrainer):
    def __init__(self, args=None, cfg=None):
        super().__init__(args, cfg)
        print("simple NAR")
        # Per-layer top-k accuracy buffers, keyed by target quantization layer;
        # only used by the commented-out debugging code in _train_step.
        self.top1_accuracies = {
            1: [],
            2: [],
            3: [],
            4: [],
            5: [],
            6: [],
            7: [],
        }
        self.top5_accuracies = {
            1: [],
            2: [],
            3: [],
            4: [],
            5: [],
            6: [],
            7: [],
        }
        self.top10_accuracies = {
            1: [],
            2: [],
            3: [],
            4: [],
            5: [],
            6: [],
            7: [],
        }

    def _build_model(self):
        from .valle_nar import ValleNAR

        return ValleNAR(**self.cfg.model)

    def _train_step(self, batch):
        """Run one NAR training step.

        Expects ``batch`` to be a dict with:
            speech: [B, T]
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]

        Returns the training loss.
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
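
        # Tokenize the waveform with the codec (inference only, no gradients).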
        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # Extract discrete codes from SpeechTokenizer (16 kHz input)
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
                # RVQ_1 = codes[:1, :, :]  # content info, can be considered as semantic tokens
                # RVQ_supplement = codes[1:, :, :]  # timbre info, completes the info lost by the first quantizer
                # Decode by concatenating semantic tokens (RVQ_1) and the supplementary timbre tokens:
                # wav = self.codec_encoder.decode(vq_id)
                # torchaudio.save('a.wav', wav[0].cpu(), 16000)
                # Decode from RVQ tokens of quantizers i through j:
                # wav = model.decode(codes[i : (j + 1)], st=i)
            else:
                # using EnCodec (24 kHz input)
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # torchaudio.save('a.wav', recovered_audio[0], 16000)
            # vq_id: [8, B, T//320]
            batch["speech"] = vq_id
            batch["speech_len"] = batch["speech_len"] // 320  # our codec downsamples 320x
            assert batch["speech_len"].max() <= batch["speech"].shape[-1]
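
        # Masks are 1 at valid positions and 0 at padding (make_pad_mask marks
        # the padded frames, so we invert it).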
        phone_mask = 1 - make_pad_mask(
            batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
        ).to(torch.long)
        speech_mask = 1 - make_pad_mask(
            batch["speech_len"], max_len=batch["speech"].size(-1)
        ).to(torch.long)
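
        # Re-seed numpy differently on each process, presumably so the randomly
        # chosen target quantization layer differs across data-parallel workers.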
        np.random.seed(int(time.time()) - 5 * self.accelerator.process_index)
        if hasattr(self.cfg.train, "dropout"):
            dropout = self.cfg.train.dropout
        else:
            dropout = 0.0

        out = self.model(
            phone_ids=batch["phone_ids"],
            phone_mask=phone_mask,
            target_ids=batch["speech"],
            target_mask=speech_mask,
            dropout=dropout,
        )
        loss = out.loss

        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top1 acc": out.top1_acc},
            step=self.step,
        )
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top5 acc": out.top5_acc},
            step=self.step,
        )
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top10 acc": out.top10_acc},
            step=self.step,
        )
        # if hasattr(out, 'top1_acc'):
        #     idx = out.target_quantization_layer
        #     self.top1_accuracies[idx].append(out.top1_acc)
        #     self.top5_accuracies[idx].append(out.top5_acc)
        #     self.top10_accuracies[idx].append(out.top10_acc)
        #     if len(self.top1_accuracies[idx]) >= 160:
        #         breakpoint()
        # if self.accelerator.is_main_process:
        #     print(loss)
        return loss

    def _test_step(self, batch):
        """Synthesize one sample for manual inspection.

        Expects ``batch`` to be a dict with:
            speech: [B, T]
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]

        Saves the codec reconstruction of the ground truth to ``gt.wav`` and
        the model output to ``a.wav``.
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # Extract discrete codes from SpeechTokenizer (16 kHz input)
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
                # Decode by concatenating semantic tokens (RVQ_1) and the supplementary timbre tokens:
                # wav = self.codec_encoder.decode(vq_id)
                # torchaudio.save('a.wav', wav[0].cpu(), 16000)
            else:
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )
            # recovered_audio = self.codec_encoder.decode([(vq_id.transpose(0, 1), None)])
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # torchaudio.save('a.wav', recovered_audio[0], 16000)
            # vq_id: [8, B, T//200]
            # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # recovered_audio.shape: torch.Size([1, 1, 50200])
            batch["speech"] = vq_id

            # Save the codec reconstruction of the ground-truth tokens for reference.
            if self.cfg.use_speechtokenizer:
                recovered_audio = self.codec_encoder.decode(vq_id)
            else:
                recovered_audio = self.codec_encoder.decode(
                    [(vq_id.transpose(0, 1), None)]
                )
            torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)
            self.model.eval()
            out_vq_ids = self.model.sample_hf(
                phone_ids=batch["phone_ids"][:1],
                prompt_ids=batch["speech"][:, :1, :150],
                first_stage_ids=batch["speech"][0, :1, 150:],
            )
            # breakpoint()
            # out_vq_ids = torch.cat([batch['speech'][:, :225], out_vq_ids], dim=1)

            # reconstruct audio from the predicted tokens
            if self.cfg.use_speechtokenizer:
                recovered_audio = self.codec_encoder.decode(out_vq_ids)
            else:
                recovered_audio = self.codec_encoder.decode(
                    [(out_vq_ids.transpose(0, 1)[:1], None)]
                )
            torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
            breakpoint()  # debugging stop: inspect the saved gt.wav / a.wav