# Scraped from a Hugging Face Space file viewer (Space status: Running on Zero).
# File size: 9,890 Bytes (207cadb)
# https://github.com/XPixelGroup/BasicSR
import torch
import functools
from torch import nn as nn
from torch.nn import functional as F
# Reduction modes supported by the loss modules below; this list is
# interpolated into their "Unsupported reduction mode" error messages.
_reduction_modes = ['none', 'mean', 'sum']
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor. For 'none' the input is returned
            unchanged.

    Raises:
        ValueError: If ``reduction`` is not a recognised mode.
    """
    # Dispatch on the string directly rather than through the private
    # ``F._Reduction.get_enum`` helper, which is not public PyTorch API.
    if reduction == 'none':
        return loss
    # 'elementwise_mean' is the legacy alias that mapped to the same enum
    # value as 'mean'; keep accepting it for backward compatibility.
    if reduction in ('mean', 'elementwise_mean'):
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(f'{reduction} is not a valid value for reduction')
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Must have the same number of
            dimensions as ``loss``; when it has a dimension 1, that dimension
            must be of size 1 or match ``loss``. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values. With a weight and ``reduction='mean'`` the result
            is the weighted sum divided by the total weight (there is no
            guard against a zero weight sum).

    Raises:
        ValueError: If ``reduction`` is not a supported mode.
    """
    # if weight is specified, apply element-wise weight
    if weight is not None:
        assert weight.dim() == loss.dim()
        # The size check on dimension 1 only applies when that dimension
        # exists; lower-rank weights are treated as purely element-wise.
        if weight.dim() > 1:
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # if weight is not specified or reduction is sum, just reduce the loss
    if weight is None or reduction == 'sum':
        return reduce_loss(loss, reduction)

    # if reduction is mean, then compute mean over weight region
    if reduction == 'mean':
        if weight.dim() < 2 or weight.size(1) > 1:
            normalizer = weight.sum()
        else:
            # A size-1 dimension 1 is broadcast across loss.size(1) entries,
            # so every weight counts that many times.
            normalizer = weight.sum() * loss.size(1)
        return loss.sum() / normalizer

    # Reject unknown modes here too, instead of silently returning the
    # unreduced tensor (the no-weight path already raises for them).
    if reduction != 'none':
        raise ValueError(f'{reduction} is not a valid value for reduction')
    return loss
def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have the signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    element-wise loss without any reduction. This decorator will add weight
    and reduction arguments to the function. The decorated function will have
    the signature like `loss_func(pred, target, weight=None, reduction='mean',
    **kwargs)`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([[[[0, 2, 3]]]])
    >>> target = torch.Tensor([[[[1, 1, 1]]]])
    >>> weight = torch.Tensor([[[[1, 0, 1]]]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[[[1., 1., 2.]]]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # Extra keyword arguments go to the wrapped function only; weighting
        # and reduction are applied afterwards to its element-wise result.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
def get_local_weights(residual, ksize):
    """Get local weights for generating the artifact map of LDL.

    It is only called by the `get_refined_artifact_map` function.

    Args:
        residual (Tensor): Residual between predicted and ground truth images.
            Indexed as a 4-D tensor here (windows are taken over dims 2 and 3).
        ksize (Int): size of the local window. NOTE(review): the symmetric
            padding of ``(ksize - 1) // 2`` only keeps the spatial size
            unchanged for odd ``ksize`` -- confirm callers always pass an odd
            value.

    Returns:
        Tensor: weight for each pixel to be discriminated as an artifact pixel
    """
    pad = (ksize - 1) // 2
    residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')

    # Slide a ksize x ksize window with stride 1 over dims 2 and 3; the two
    # trailing dimensions then hold the contents of each window.
    windows = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)

    # Unbiased variance inside each window. Reducing without keepdim gives the
    # same shape as the former keepdim=True followed by two squeeze(-1) calls.
    return torch.var(windows, dim=(-1, -2), unbiased=True)
def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
    """Calculate the artifact map of LDL
    (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)

    Args:
        img_gt (Tensor): ground truth images.
        img_output (Tensor): output images given by the optimizing model.
        img_ema (Tensor): output images given by the ema model.
        ksize (Int): size of the local window.

    Returns:
        overall_weight: weight for each pixel to be discriminated as an artifact pixel
        (calculated based on both local and global observations).
    """
    # Absolute residuals summed over dimension 1 (presumably the channel axis
    # of (N, C, H, W) images -- confirm against callers).
    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
    residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)

    # Neither torch.var nor get_local_weights modifies its input, so the
    # residual is passed directly instead of being cloned first.
    patch_level_weight = torch.var(residual_sr, dim=(-1, -2, -3), keepdim=True) ** (1 / 5)
    pixel_level_weight = get_local_weights(residual_sr, ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    # Zero the weight wherever the optimizing model's residual is smaller
    # than the EMA model's residual.
    overall_weight[residual_sr < residual_ema] = 0
    return overall_weight
@weighted_loss
def l1_loss(pred, target):
    """Element-wise L1 (absolute error) loss.

    Only the unreduced loss is computed here; the ``weighted_loss`` decorator
    adds the ``weight`` and ``reduction`` arguments.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.

    Returns:
        Tensor: Unreduced absolute error between ``pred`` and ``target``.
    """
    return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
    """Element-wise MSE (squared error) loss.

    Only the unreduced loss is computed here; the ``weighted_loss`` decorator
    adds the ``weight`` and ``reduction`` arguments.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.

    Returns:
        Tensor: Unreduced squared error between ``pred`` and ``target``.
    """
    return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier loss, ``sqrt((pred - target)**2 + eps)``.

    Only the unreduced loss is computed here; the ``weighted_loss`` decorator
    adds the ``weight`` and ``reduction`` arguments.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.
        eps (float): Constant added under the square root. Default: 1e-12.

    Returns:
        Tensor: Unreduced Charbonnier loss.
    """
    return torch.sqrt((pred - target)**2 + eps)
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Loss weight for L1 loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Validate against the same list the error message reports, so the
        # check and the message cannot drift apart.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.

        Returns:
            Tensor: The L1 loss scaled by ``loss_weight``. Extra keyword
                arguments are accepted but not used.
        """
        return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Loss weight for MSE loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        # Validate against the same list the error message reports, so the
        # check and the message cannot drift apart.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.

        Returns:
            Tensor: The MSE loss scaled by ``loss_weight``. Extra keyword
                arguments are accepted but not used.
        """
        return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (one variant of Robust L1Loss, a differentiable
    variant of L1Loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
        Super-Resolution".

    Args:
        loss_weight (float): Loss weight for Charbonnier loss. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero. Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super().__init__()
        # Validate against the same list the error message reports, so the
        # check and the message cannot drift apart.
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.

        Returns:
            Tensor: The Charbonnier loss scaled by ``loss_weight``. Extra
                keyword arguments are accepted but not used.
        """
        return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
class WeightedTVLoss(L1Loss):
    """Weighted TV loss.

    Computed as the L1 loss between ``pred`` and a copy of itself shifted by
    one position along dim 2, plus the same along dim 3.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
        reduction (str): Specifies the reduction to apply to the output.
            Supported choices are 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is neither 'mean' nor 'sum'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        # 'none' is rejected: the two difference maps have different shapes
        # and could not be added together unreduced.
        if reduction not in ['mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super().__init__(loss_weight=loss_weight, reduction=reduction)

    def forward(self, pred, weight=None):
        """
        Args:
            pred (Tensor): Predicted tensor, indexed here as a 4-D tensor.
            weight (Tensor, optional): Element-wise weights, sliced the same
                way as ``pred``. Default: None.

        Returns:
            Tensor: Sum of the dim-3 and dim-2 difference losses; each term is
                already scaled by ``loss_weight`` in the parent ``forward``.
        """
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            # Crop the weights to the shape of the corresponding difference.
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        # Differences between neighbours along dim 2 (y) and dim 3 (x).
        y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        return x_diff + y_diff