Spaces:
Build error
Build error
| import math | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from skimage.measure.simple_metrics import compare_psnr | |
| from torchvision import models | |
def weights_init_kaiming(m):
    """Initialise the parameters of a single module in place.

    The module is dispatched on its class *name* (substring match), so it
    can be handed to ``nn.Module.apply``:

    * name contains ``Conv`` or ``Linear``: the weight is drawn with
      Kaiming-normal initialisation (``a=0``, ``mode='fan_in'``); the bias
      is left untouched.
    * name contains ``BatchNorm``: the weight is drawn from a zero-mean
      normal with std ``sqrt(2/9/64)`` and clamped to ``[-0.025, 0.025]``;
      the bias is set to zero.
    * anything else is left untouched.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        # The underscore-suffixed initialisers are the supported in-place
        # API; the un-suffixed names are deprecated aliases.
        nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif 'BatchNorm' in classname:
        # BatchNorm layers built with affine=False carry no weight/bias
        # (both are None), so each one is guarded separately.
        if m.weight is not None:
            # NOTE(review): 2/9/64 looks like the He variance for a 3x3
            # kernel with 64 channels -- confirm against the model this
            # initialiser was written for.
            std = math.sqrt(2. / 9. / 64.)
            m.weight.data.normal_(mean=0, std=std).clamp_(-0.025, 0.025)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)
class VGG19_PercepLoss(nn.Module):
    """Perceptual loss: mean squared error between VGG-19 feature maps.

    The feature extractor is ``models.vgg19(...).features`` with all of its
    parameters frozen (``requires_grad_(False)``).
    """

    # Name -> position (as the string key used by ``nn.Sequential``) of each
    # convolution inside torchvision's ``vgg19().features``. Lets ``forward``
    # honour its ``layer`` argument instead of only working for the default.
    _CONV_INDEX = {
        'conv1_1': '0', 'conv1_2': '2',
        'conv2_1': '5', 'conv2_2': '7',
        'conv3_1': '10', 'conv3_2': '12', 'conv3_3': '14', 'conv3_4': '16',
        'conv4_1': '19', 'conv4_2': '21', 'conv4_3': '23', 'conv4_4': '25',
        'conv5_1': '28', 'conv5_2': '30', 'conv5_3': '32', 'conv5_4': '34',
    }

    def __init__(self, _pretrained_=True):
        """Build the frozen VGG-19 feature extractor.

        Args:
            _pretrained_: forwarded to ``models.vgg19(pretrained=...)``.
        """
        super(VGG19_PercepLoss, self).__init__()
        self.vgg = models.vgg19(pretrained=_pretrained_).features
        for param in self.vgg.parameters():
            param.requires_grad_(False)

    def get_features(self, image, layers=None):
        """Run ``image`` through the network and collect selected outputs.

        Args:
            image: input passed to the first layer of ``self.vgg``.
            layers: mapping from sub-module key (e.g. ``'30'``) to the label
                under which that layer's output is returned. Defaults to
                ``{'30': 'conv5_2'}``.

        Returns:
            dict mapping each label to the tensor captured at that layer.
        """
        if layers is None:
            layers = {'30': 'conv5_2'}  # may add other layers
        features = {}
        x = image
        # NOTE(review): torchvision builds VGG with in-place ReLUs, so a
        # tensor captured right after a conv is presumably overwritten by the
        # following ReLU while this loop keeps running. The loop is left
        # running over every layer so the captured values stay what this
        # module has always produced -- confirm before adding an early exit.
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
        return features

    def forward(self, pred, true, layer='conv5_2'):
        """Return the mean squared difference of the features at ``layer``.

        Args:
            pred: predicted image batch.
            true: reference image batch.
            layer: name of the convolution whose features are compared; one
                of the keys of ``_CONV_INDEX``.

        Raises:
            KeyError: if ``layer`` is not a known VGG-19 convolution name.
        """
        # Request exactly the layer the caller asked for; previously only the
        # default layer was ever extracted, so any other name failed.
        wanted = {self._CONV_INDEX[layer]: layer}
        true_f = self.get_features(true, wanted)
        pred_f = self.get_features(pred, wanted)
        return torch.mean((true_f[layer] - pred_f[layer]) ** 2)
def batch_PSNR(img, imclean, data_range):
    """Return the mean peak signal-to-noise ratio (dB) over a batch.

    Both tensors are detached, moved to the CPU and converted to float32
    NumPy arrays. For each entry along the first axis the PSNR is
    ``10 * log10(data_range ** 2 / mse)``, and the per-entry values are
    averaged. An entry identical to its reference has zero error and
    therefore an infinite PSNR.

    The value is computed directly with NumPy rather than through
    ``compare_psnr``, which is no longer shipped by current scikit-image.

    Args:
        img: batch of images under test (first axis is the batch).
        imclean: batch of reference images, same shape as ``img``.
        data_range: distance between the largest and smallest possible
            pixel value (e.g. ``1.0`` for images scaled to [0, 1]).

    Returns:
        The batch-average PSNR as a float.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    test = img.detach().cpu().numpy().astype(np.float32)
    clean = imclean.detach().cpu().numpy().astype(np.float32)
    if test.shape != clean.shape:
        # Without this check NumPy broadcasting could silently compare
        # mismatched batches.
        raise ValueError('Input images must have the same dimensions.')
    batch = test.shape[0]
    if batch == 0:
        raise ValueError('batch_PSNR needs at least one image.')

    # One row per image; the squared error is accumulated in float64.
    diff = (clean - test).reshape(batch, -1)
    mse = np.mean(np.square(diff), axis=1, dtype=np.float64)
    with np.errstate(divide='ignore'):
        # mse == 0 yields +inf, which is the intended result for a
        # perfect reconstruction.
        psnr = 10 * np.log10((data_range ** 2) / mse)
    return float(np.mean(psnr))
def data_augmentation(image, mode):
    """Apply one of eight flip/rotate variants to a three-axis array.

    The array is transposed with ``(1, 2, 0)``, transformed on its first two
    axes, and transposed back with ``(2, 0, 1)``; presumably the input is
    laid out as (channels, height, width).

    Modes:
        0 - unchanged
        1 - flip up/down
        2 - rotate 90 degrees counter-clockwise
        3 - rotate 90 degrees, then flip up/down
        4 - rotate 180 degrees
        5 - rotate 180 degrees, then flip up/down
        6 - rotate 270 degrees
        7 - rotate 270 degrees, then flip up/down

    Any other ``mode`` leaves the data unchanged.

    Args:
        image: array with exactly three axes.
        mode: integer selecting the transform.

    Returns:
        The transformed array, in the same axis order as the input.
    """
    # (mode, number of quarter turns, flip up/down afterwards)
    recipes = (
        (1, 0, True),
        (2, 1, False),
        (3, 1, True),
        (4, 2, False),
        (5, 2, True),
        (6, 3, False),
        (7, 3, True),
    )
    hwc = np.transpose(image, (1, 2, 0))
    for code, quarter_turns, flip in recipes:
        if mode == code:
            if quarter_turns:
                hwc = np.rot90(hwc, k=quarter_turns)
            if flip:
                hwc = np.flipud(hwc)
            break
    return np.transpose(hwc, (2, 0, 1))