"""Gradio demo for the pretrained IDF image denoiser.

Loads an IDF checkpoint at import time and serves a small Blocks UI:
upload a noisy image, optionally tweak the iteration settings, and get
the denoised image back.
"""

import gradio as gr
import torch
import spaces
from omegaconf import OmegaConf
import torchvision.transforms.functional as TF

from idf.models.lit_a_denoising import LitADenoising

# --- Model Loading ---

# Load Pretrained IDF Model
config = OmegaConf.load("configs/models/idfnet.yaml")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_pth = "pretrained_models/idf_g_15.ckpt"

# Load the model. The denoiser/loss sub-configs from the YAML are forwarded
# as constructor kwargs to the LightningModule.
model = LitADenoising.load_from_checkpoint(
    model_pth,
    strict=True,
    map_location=device,
    denoiser_config=config.params.denoiser_config,
    loss_config=config.params.loss_config,
)
model.to(device)
# eval() returns the module itself, so Lightning's freeze() can be chained.
model.eval().freeze()

# Default for the "Maximum Denoising Iteration" slider; also used by the
# bundled examples and by the Clear button's reset.
DEFAULT_MAX_ITER = 10


# --- Main Inference Function ---
@spaces.GPU()
def infer(
    img,
    adaptive_iter,
    max_iter=None,
):
    """Denoise one image with the pretrained IDF model.

    Args:
        img: Noisy input image as delivered by ``gr.Image(type="pil")``,
            or ``None`` when nothing has been uploaded.
        adaptive_iter: Value of the "Dynamic Iteration Control" checkbox;
            forwarded to the model as its second positional argument.
        max_iter: Value of the "Maximum Denoising Iteration" slider;
            forwarded to the model as its third positional argument.
            ``None`` is passed through unchanged.

    Returns:
        The denoised image as a PIL image (model output clamped to
        [0, 1] before conversion), or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # An iteration count should be an integer; coerce in case the UI hands
    # over a float such as 10.0.
    # NOTE(review): assumes the model expects an int here — confirm.
    if max_iter is not None:
        max_iter = int(max_iter)

    # PIL -> float tensor, then add a leading batch dimension.
    x = TF.to_tensor(img).unsqueeze(0).to(device)
    # Pure inference: make sure no autograd graph is recorded even if some
    # tensor inside the model still requires grad.
    with torch.no_grad():
        x = model(x, adaptive_iter, max_iter)
    x = torch.clamp(x, 0.0, 1.0)
    return TF.to_pil_image(x.squeeze(0))


# --- Examples and UI Layout ---

# (label, image path) for each bundled noisy example image.
_EXAMPLE_IMAGES = [
    ("📸DSLR", "assets/demo/noisy/DSLR.png"),
    ("📱Smartphone", "assets/demo/noisy/Smartphone.png"),
    ("🧂Salt & Pepper", "assets/demo/noisy/Salt_Pepper.png"),
    ("🌫️Gaussian", "assets/demo/noisy/Gaussian.png"),
    ("🌀Spatial Gaussian", "assets/demo/noisy/Spatial_Gaussian.png"),
    ("🎲Poisson", "assets/demo/noisy/Poisson.png"),
    ("✨Speckle", "assets/demo/noisy/Speckle.png"),
    ("🧪Mixture", "assets/demo/noisy/Mixture.png"),
]
# One row per example, matching inputs=[input_image, adaptive_iter, max_iter].
examples = [[path, False, DEFAULT_MAX_ITER] for _, path in _EXAMPLE_IMAGES]
example_labels = [label for label, _ in _EXAMPLE_IMAGES]

css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
#edit_text{
    margin-top: -62px !important
}
.badge-row {
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 12px;
    flex-wrap: nowrap; /* force single line; change to wrap if needed */
    margin-top: 8px;
}
.badge-row a {
    text-decoration: none;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # NOTE(review): the original header markup appears to have been lost;
        # only this description text survived — restore badges/title if needed.
        gr.HTML(
            """
            <p>
            Even though IDF is trained with extremely limited data (e.g.,
            single-level Gaussian noise), it generalizes effectively to diverse
            unseen noise types and levels with only ~0.04M parameters.
            </p>
            """
        )
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Noisy Image", show_label=True, type="pil")
            result = gr.Image(label="Denoised Image", show_label=True, type="pil")
        with gr.Row():
            clear_button = gr.Button("Clear", variant="secondary")
            run_button = gr.Button("Denoise!", variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            adaptive_iter = gr.Checkbox(value=False, label="Dynamic Iteration Control")
            with gr.Row():
                max_iter = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=DEFAULT_MAX_ITER,
                    step=1,
                    label="Maximum Denoising Iteration",
                )

        gr.Examples(
            examples=examples,
            inputs=[input_image, adaptive_iter, max_iter],
            outputs=[result],
            fn=infer,
            cache_examples=True,
            cache_mode='lazy',
            label="Examples (Noise Types)",
            example_labels=example_labels,
        )

    # Define the event triggers for the run button.
    run_button.click(
        fn=infer,
        inputs=[input_image, adaptive_iter, max_iter],
        outputs=[result],
    )

    # Clear resets the inputs AND the previous result so no stale output
    # remains on screen.
    clear_button.click(
        fn=lambda: [None, None, False, DEFAULT_MAX_ITER],
        outputs=[input_image, result, adaptive_iter, max_iter],
    )

if __name__ == "__main__":
    demo.launch()