# IDF / idf/models/lit_a_denoising.py
# Uploaded by dongjin-kim ("Upload 47 files", commit 207cadb, verified).
from typing import Mapping, Any
import torch
from idf.utils.common import instantiate_from_config
from idf.utils.metrics import calculate_psnr_pt, calculate_ssim_pt
from torchvision.transforms.functional import center_crop
from idf.models.lit_denoising import LitDenoising
from idf.utils.misc import const_like
import numpy as np
class LitADenoising(LitDenoising):
    """Lightning module wrapping a denoiser with affine data normalization.

    Inputs and targets are mapped into a normalized space before reaching the
    denoiser, using statistics from ``data_config.norm``; predictions are mapped
    back to the raw scale for metrics and image logging. Validation forwards the
    adaptive-iteration settings from ``misc_config`` to the denoiser.
    """

    def __init__(
        self,
        data_config: Mapping[str, Any],
        denoiser_config: Mapping[str, Any],
        loss_config: Mapping[str, Any],
        optimizer_config: Mapping[str, Any],
        scheduler_config: Mapping[str, Any] = None,
        misc_config: Mapping[str, Any] = None,
    ):
        """Build the denoiser, loss and normalization constants.

        Args:
            data_config: Data settings; must provide ``validate`` (mapping of
                validation-set name -> per-set config) and ``norm`` with
                ``sigma_data``, ``raw_std``, ``mu_data`` and ``raw_mean``.
            denoiser_config: Config passed to ``instantiate_from_config`` for the model.
            loss_config: Config passed to ``instantiate_from_config`` for the loss.
            optimizer_config: Stored for optimizer construction.
            scheduler_config: Optional; stored for scheduler construction.
            misc_config: Optional misc settings (``compile``, ``adaptive_iteration``,
                ``max_iteration``, optional ``alpha_schedule``).
        """
        super().__init__(data_config, denoiser_config, loss_config,
                         optimizer_config, scheduler_config, misc_config,)
        # NOTE(review): the parent receives the same configs and may already
        # build a model/loss; the assignments below replace whatever it set.
        self.model = instantiate_from_config(denoiser_config)
        self.misc_config = misc_config
        # misc_config defaults to None, so guard before reading the flag.
        if misc_config is not None and misc_config.get('compile', False):
            self.model = torch.compile(self.model)
        self.loss = instantiate_from_config(loss_config)
        self.optimizer_config = optimizer_config
        self.scheduler_config = scheduler_config
        self.data_config = data_config
        # Order matches the dataloader_idx used in validation_step.
        self.val_dataset_names = list(self.data_config.validate.keys())
        # Data normalization: x_norm = x * scale + bias
        #                            = (x - raw_mean) * sigma_data / raw_std + mu_data
        norm = data_config.norm
        self.data_scale = np.float32(norm.sigma_data) / np.float32(norm.raw_std)
        self.data_bias = np.float32(norm.mu_data) - np.float32(norm.raw_mean) * self.data_scale
        self.save_hyperparameters()

    def forward(self, noisy, adaptive_iter=False, max_iter=None, alpha_schedule=None):
        """Denoise a raw-scale input and return a raw-scale prediction.

        The input is normalized, passed to the denoiser with the given
        iteration settings, and the output is de-normalized. Do not pass
        already-normalized tensors here; call ``self.model`` instead.
        """
        x = self.normalize(noisy)
        pred = self.model(x, adaptive_iter=adaptive_iter, max_iter=max_iter, alpha_schedule=alpha_schedule)
        pred = self.normalize(pred, reverse=True)
        return pred

    def normalize(self, x, reverse=False):
        """Apply (or, with ``reverse=True``, undo) the affine data normalization.

        Constants are broadcast as ``(1, C, 1, 1)`` via ``reshape(1, -1, 1, 1)``;
        this presumably expects NCHW tensors (TODO confirm with callers).
        """
        if not reverse:
            if self.data_scale is not None:
                x = x * const_like(x, self.data_scale).reshape(1, -1, 1, 1)
            if self.data_bias is not None:
                x = x + const_like(x, self.data_bias).reshape(1, -1, 1, 1)
        else:
            # Inverse in reverse order: remove the bias first, then the scale.
            # Each step is guarded by its own constant.
            if self.data_bias is not None:
                x = x - const_like(x, self.data_bias).reshape(1, -1, 1, 1)
            if self.data_scale is not None:
                x = x / const_like(x, self.data_scale).reshape(1, -1, 1, 1)
        return x

    @torch.no_grad()
    def get_input(self, batch, config):
        """Return the normalized ``(input, target)`` pair selected by ``config`` keys."""
        x = batch[config.input_key]
        y = batch[config.target_key]
        x = self.normalize(x)
        y = self.normalize(y)
        return x, y

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss in the normalized space."""
        x, y = self.get_input(batch, self.data_config.train)
        # global_batch_size / get_lr are provided outside this class (parent).
        self.log("bs", self.global_batch_size, prog_bar=True, logger=False)
        self.log('lr', self.get_lr(), prog_bar=True, logger=False)
        losses = dict()
        pred = self.model(x)
        losses['train/loss'] = self.loss(pred, y)
        losses['train/total'] = sum(losses.values())
        self.log_dict(losses, prog_bar=True)
        return losses['train/total']

    def on_validation_start(self):
        """Reset the per-run image buffer and report the inference settings."""
        self.sampled_images = []
        self.sample_steps_val = 50
        print(f"[Inference Settings] {self.misc_config.adaptive_iteration=}, {self.misc_config.max_iteration=}")

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Dispatch to ``_validation_step`` with the config of the current val set."""
        val_name = self.val_dataset_names[dataloader_idx]
        val_config = self.data_config.validate[val_name]
        self._validation_step(batch, batch_idx, val_config, suffix=f"_{val_name}")

    def _validation_step(self, batch, batch_idx, val_config, suffix=""):
        """Run the denoiser on one validation image and log PSNR/SSIM.

        Every 500th batch, 256x256 center crops of input, target and prediction
        (raw scale) are appended to ``self.sampled_images``.
        """
        x, y = self.get_input(batch, val_config)
        # Only element 0 is cropped for logging below; enforce single-image batches.
        assert x.shape[0] == 1
        # ``x`` is already normalized by get_input, so call the denoiser directly.
        # Going through ``self(...)`` would normalize it a second time (and
        # de-normalize the output twice), unlike training_step.
        pred = self.model(x, adaptive_iter=self.misc_config.adaptive_iteration,
                          max_iter=self.misc_config.max_iteration,
                          alpha_schedule=self.misc_config.get('alpha_schedule'))
        pred = self.normalize(pred, reverse=True)
        pred = torch.clamp(pred, 0.0, 1.0)
        x = self.normalize(x, reverse=True)
        y = self.normalize(y, reverse=True)
        # Evaluate metrics (crop_border=0, RGB rather than Y channel).
        losses = {}
        losses[f'val{suffix}/psnr'] = calculate_psnr_pt(y, pred, 0, test_y_channel=False).mean()
        losses[f'val{suffix}/ssim'] = calculate_ssim_pt(y, pred, 0, test_y_channel=False).mean()
        self.log_dict(losses, sync_dist=True, prog_bar=True, add_dataloader_idx=False)
        if batch_idx % 500 == 0:
            for img in (x, y, pred):
                self.sampled_images.append(center_crop(img, (256, 256))[0].cpu())