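"""Lightning module for adaptive-iteration denoising (LitADenoising).

Extends LitDenoising with affine data normalization and an inference path
that exposes adaptive iteration, an iteration cap, and an optional alpha
schedule on the underlying denoiser.
"""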
from typing import Any, Mapping, Optional

import numpy as np
import torch
from torchvision.transforms.functional import center_crop

from idf.models.lit_denoising import LitDenoising
from idf.utils.common import instantiate_from_config
from idf.utils.metrics import calculate_psnr_pt, calculate_ssim_pt
from idf.utils.misc import const_like

class LitADenoising(LitDenoising):
    def __init__(
        self,
        data_config: Mapping[str, Any],
        denoiser_config: Mapping[str, Any],
        loss_config: Mapping[str, Any],
        optimizer_config: Mapping[str, Any],
        scheduler_config: Optional[Mapping[str, Any]] = None,
        misc_config: Optional[Mapping[str, Any]] = None,
    ):
        super().__init__(data_config, denoiser_config, loss_config,
                         optimizer_config, scheduler_config, misc_config)

        self.model = instantiate_from_config(denoiser_config)
        
        self.misc_config = misc_config
        # misc_config defaults to None, so guard before reading it.
        if self.misc_config is not None and self.misc_config.get("compile", False):
            self.model = torch.compile(self.model)

        self.loss = instantiate_from_config(loss_config)
        self.optimizer_config = optimizer_config
        self.scheduler_config = scheduler_config
        self.data_config = data_config

        self.val_dataset_names = list(self.data_config.validate.keys())
        
        # Affine data normalization: x_norm = x * scale + bias maps the raw
        # statistics (raw_mean, raw_std) onto the targets (mu_data, sigma_data).
        self.data_scale = np.float32(data_config.norm.sigma_data) / np.float32(data_config.norm.raw_std)
        self.data_bias = np.float32(data_config.norm.mu_data) - np.float32(data_config.norm.raw_mean) * self.data_scale
        
        self.save_hyperparameters()

    def forward(self, noisy, adaptive_iter=False, max_iter=None, alpha_schedule=None):
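        """Denoise a raw-space batch: normalize, run the denoiser, map back."""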
        x = self.normalize(noisy)
        pred = self.model(x, adaptive_iter=adaptive_iter, max_iter=max_iter, alpha_schedule=alpha_schedule)
        pred = self.normalize(pred, reverse=True)
        return pred
        
    def normalize(self, x, reverse=False):
        """Apply the affine normalization, or invert it when reverse=True."""
        if not reverse:
            # x_norm = x * scale + bias
            if self.data_scale is not None:
                x = x * const_like(x, self.data_scale).reshape(1, -1, 1, 1)
            if self.data_bias is not None:
                x = x + const_like(x, self.data_bias).reshape(1, -1, 1, 1)
        else:
            # x = (x_norm - bias) / scale
            if self.data_bias is not None:
                x = x - const_like(x, self.data_bias).reshape(1, -1, 1, 1)
            if self.data_scale is not None:
                x = x / const_like(x, self.data_scale).reshape(1, -1, 1, 1)

        return x
        
    @torch.no_grad()
    def get_input(self, batch, config):
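        """Pull the input/target tensors from the batch and normalize both."""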
        x = batch[config.input_key]
        y = batch[config.target_key]

        x = self.normalize(x)
        y = self.normalize(y)
        return x, y
   
    def training_step(self, batch, batch_idx):
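        """One supervised training step, computed in normalized space."""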
        x, y = self.get_input(batch, self.data_config.train)
        self.log("bs", self.global_batch_size, prog_bar=True, logger=False)
        self.log('lr', self.get_lr(), prog_bar=True, logger=False)

        losses = dict()
        pred = self.model(x)
        losses['train/loss'] = self.loss(pred, y)
        losses['train/total'] = sum(losses.values())
        self.log_dict(losses, prog_bar=True)
        return losses['train/total']
    
    def on_validation_start(self):
        self.sampled_images = []
        self.sample_steps_val = 50
        print(f"[Inference Settings] {self.misc_config.adaptive_iteration=}, {self.misc_config.max_iteration=}")

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        val_name = self.val_dataset_names[dataloader_idx]
        val_config = self.data_config.validate[val_name]
        self._validation_step(batch, batch_idx, val_config, suffix=f"_{val_name}")
    
    def _validation_step(self, batch, batch_idx, val_config, suffix=""):
        x, y = self.get_input(batch, val_config)  # already normalized
        assert x.shape[0] == 1

        # get_input returns normalized tensors, so run the denoiser directly;
        # forward() would normalize the input a second time.
        pred = self.model(x, adaptive_iter=self.misc_config.adaptive_iteration,
                          max_iter=self.misc_config.max_iteration,
                          alpha_schedule=self.misc_config.get('alpha_schedule'))

        pred = self.normalize(pred, reverse=True)
        pred = torch.clamp(pred, 0.0, 1.0)

        x = self.normalize(x, reverse=True)
        y = self.normalize(y, reverse=True)
        
        # Evaluate metrics.
        losses = {}
        losses[f'val{suffix}/psnr'] = calculate_psnr_pt(y, pred, 0, test_y_channel=False).mean()
        losses[f'val{suffix}/ssim'] = calculate_ssim_pt(y, pred, 0, test_y_channel=False).mean()

        self.log_dict(losses, sync_dist=True, prog_bar=True, add_dataloader_idx=False)
        
        if batch_idx % 500 == 0:
            # Keep center crops of (input, target, prediction) for logging.
            for img in (x, y, pred):
                self.sampled_images.append(center_crop(img, (256, 256))[0].cpu())
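
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed anywhere in this file). It
# assumes OmegaConf-style configs and a `target`-based factory convention for
# instantiate_from_config; the dotted `target` paths and the optimizer keys
# below are hypothetical placeholders.
#
#   from omegaconf import OmegaConf
#
#   data_config = OmegaConf.create({
#       "norm": {"sigma_data": 0.5, "raw_std": 0.25,
#                "mu_data": 0.0, "raw_mean": 0.5},
#       "train": {"input_key": "noisy", "target_key": "clean"},
#       "validate": {"sidd": {"input_key": "noisy", "target_key": "clean"}},
#   })
#   denoiser_config = OmegaConf.create({"target": "idf.models.denoiser.Denoiser"})
#   loss_config = OmegaConf.create({"target": "torch.nn.MSELoss"})
#   optimizer_config = OmegaConf.create({"lr": 1e-4})
#
#   lit = LitADenoising(data_config, denoiser_config,
#                       loss_config, optimizer_config)
#   denoised = lit(torch.randn(1, 3, 64, 64))  # raw-space in, raw-space out
# ---------------------------------------------------------------------------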