import torch
from torch.nn import functional as F
from .dit import DiffusionTransformer
from .adp import UNet1d
from .sampling import sample
import math
from model.base import BaseModule
import pdb

target_length = 1536


def pad_and_create_mask(matrix, target_length):
    """Right-pad a (B, C, T) tensor along T and return it with a (1, target_length) validity mask."""
    T = matrix.shape[2]
    if T > target_length:
        raise ValueError("The third dimension length %s should not exceed %s" % (T, target_length))
    padding_size = target_length - T
    padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0)
    mask = torch.ones((1, target_length))
    mask[:, T:] = 0  # zero out the padded positions
    return padded_matrix.to(matrix.device), mask.to(matrix.device)
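
# Usage sketch (hypothetical shapes, for illustration only): pad a
# (B, 80, T) mel tensor up to target_length and build the matching mask.
#
#     mel = torch.randn(1, 80, 1200)                  # T = 1200 < 1536
#     padded, mel_mask = pad_and_create_mask(mel, target_length)
#     # padded.shape   == (1, 80, 1536)
#     # mel_mask.shape == (1, 1536): 1 for real frames, 0 for padding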


class Stable_Diffusion(BaseModule):
    def __init__(self):
        super(Stable_Diffusion, self).__init__()
        self.diffusion = DiffusionTransformer(
            io_channels=80,
            # input_concat_dim=80,
            embed_dim=768,
            # cond_token_dim=target_length,
            depth=24,
            num_heads=24,
            project_cond_tokens=False,
            transformer_type="continuous_transformer",
        )
        # Alternative UNet1d backbone, kept commented out for reference:
        # self.diffusion = UNet1d(
        #     in_channels=80,
        #     channels=256,
        #     resnet_groups=16,
        #     kernel_multiplier_downsample=2,
        #     multipliers=[4, 4, 4, 5, 5],
        #     factors=[1, 2, 2, 4],  # strided convolutions shorten inputs of inconsistent length
        #     num_blocks=[2, 2, 2, 2],
        #     attentions=[1, 3, 3, 3, 3],
        #     attention_heads=16,
        #     attention_multiplier=4,
        #     use_nearest_upsample=False,
        #     use_skip_scale=True,
        #     use_context_time=True,
        # )
        # Low-discrepancy (Sobol) engine for drawing diffusion timesteps t in [0, 1).
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

    def forward(self, mu, mask, n_timesteps):
        # pdb.set_trace()
        mask = mask.squeeze(1)
        # noise = torch.randn_like(mu).to(mu.device)
        # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
        # extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask}
        extra_args = {"mask": mask}
        # Sampling starts from the prior mu rather than pure Gaussian noise; the
        # fourth argument is eta=0, i.e. a deterministic trajectory (assuming a
        # stable-audio-tools-style sample(model, x, steps, eta, **extra_args)).
        fakes = sample(self.diffusion, mu, n_timesteps, 0, **extra_args)
        return fakes

    def compute_loss(self, x0, mask, mu):
        # pdb.set_trace()
        # Draw quasi-random timesteps t in [0, 1) and map them onto the
        # continuous cosine schedule: alpha_t = cos(pi*t/2), sigma_t = sin(pi*t/2).
        t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
        alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
        alphas = alphas[:, None, None]
        sigmas = sigmas[:, None, None]
        noise = torch.randn_like(x0)
        noised_inputs = x0 * alphas + noise * sigmas
        # v-style target with the prior mu standing in for the noise term.
        targets = mu * alphas - x0 * sigmas
        mask = mask.squeeze(1)
        # mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
        # output = self.diffusion(noised_inputs, t, cross_attn_cond=mu,
        #                         cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1)
        output = self.diffusion(noised_inputs, t, mask=mask, cfg_dropout_prob=0.1)
        return self.mse_loss(output, targets, mask), output
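
    # Schedule reference (a sketch, not executed): with t ~ U[0, 1) the cosine
    # parameterization gives
    #     x_t = cos(pi*t/2) * x0 + sin(pi*t/2) * eps,
    # and the standard v-objective target would be
    #     v_t = cos(pi*t/2) * eps - sin(pi*t/2) * x0.
    # compute_loss() substitutes mu for eps in the target, which matches
    # forward() starting its sampling trajectory from mu instead of noise.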

    def mse_loss(self, output, targets, mask):
        mse_loss = F.mse_loss(output, targets, reduction='none')
        # Broadcast a (B, T) mask up to the (B, C, T) per-element loss tensor.
        if mask.ndim == 2 and mse_loss.ndim == 3:
            mask = mask.unsqueeze(1)
        if mask.shape[1] != mse_loss.shape[1]:
            mask = mask.repeat(1, mse_loss.shape[1], 1)
        # Cast to bool so masked indexing works whether the mask arrives as a
        # 0/1 float tensor or a boolean tensor.
        mse_loss = mse_loss[mask.bool()]
        mse_loss = mse_loss.mean()
        return mse_loss
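
# Minimal smoke-test sketch (hypothetical shapes; assumes the relative imports
# above resolve inside the package, modeling an 80-bin mel spectrogram):
#
#     model = Stable_Diffusion()
#     x0 = torch.randn(2, 80, 256)      # (batch, mel bins, frames)
#     mu = torch.randn(2, 80, 256)      # conditioning prior, same shape as x0
#     mask = torch.ones(2, 1, 256)      # all frames valid
#     loss, _ = model.compute_loss(x0, mask, mu)
#     fakes = model.forward(mu, mask, n_timesteps=50)
#     print(loss.item(), fakes.shape)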