# Spaces: Running on Zero
| import gradio as gr | |
| import torch | |
| import spaces | |
| from omegaconf import OmegaConf | |
| import torch | |
| import torchvision.transforms.functional as TF | |
| from idf.models.lit_a_denoising import LitADenoising | |
# --- Model Loading ---
# Everything in this section runs once, at import time: read the network
# configuration, choose a device, restore the pretrained IDF checkpoint and
# switch the module to evaluation mode with frozen parameters.
config = OmegaConf.load("configs/models/idfnet.yaml")

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_pth = "pretrained_models/idf_g_15.ckpt"

# The denoiser and loss sub-configs from the YAML file are handed to
# load_from_checkpoint as extra keyword arguments next to the checkpoint path.
model = LitADenoising.load_from_checkpoint(
    model_pth,
    strict=True,
    map_location=device,
    denoiser_config=config.params.denoiser_config,
    loss_config=config.params.loss_config,
)
model.to(device)
model.eval().freeze()
# --- Main Inference Function ---
# NOTE(review): `spaces` is imported at the top of the file but no
# `@spaces.GPU` decorator is visible on this function. On ZeroGPU hardware GPU
# work is normally expected to run inside a decorated function - confirm
# whether the decorator is missing or applied elsewhere.
def infer(
    img,
    adaptive_iter,
    max_iter=None,
):
    """Denoise a single image with the module-level ``model``.

    Args:
        img: Image accepted by ``TF.to_tensor`` (the UI passes a PIL image),
            or ``None`` when nothing has been uploaded.
        adaptive_iter: Forwarded unchanged as the model's second positional
            argument; the UI binds the "Dynamic Iteration Control" checkbox
            to it.
        max_iter: Forwarded as the model's third positional argument; the UI
            binds the "Maximum Denoising Iteration" slider to it. A non-None
            value is converted to ``int`` first, ``None`` is passed through.

    Returns:
        The model output clamped to [0, 1] and converted to a PIL image, or
        ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # UI sliders can deliver whole numbers as floats (10.0); normalise so the
    # model always receives an integer count.
    if max_iter is not None:
        max_iter = int(max_iter)

    # Add a leading dimension and move the tensor to the model's device.
    batch = TF.to_tensor(img).unsqueeze(0).to(device)

    # Pure inference: make sure autograd records nothing for this call.
    with torch.no_grad():
        denoised = model(batch, adaptive_iter, max_iter)

    # Keep values inside the range expected by the PIL conversion.
    denoised = torch.clamp(denoised, 0.0, 1.0)
    return TF.to_pil_image(denoised.squeeze(0))
# --- Examples and UI Layout ---
# `examples` is not referenced anywhere else in the visible file (the
# gr.Examples call supplies its own list); it is kept so the module's names
# stay unchanged.
examples = []

# Stylesheet handed to gr.Blocks: centres the main column, caps its width and
# lays the badge links out on a single centred row.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
#edit_text{
    margin-top: -62px !important
}
.badge-row {
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 12px;
    flex-wrap: nowrap; /* force single line; change to wrap if needed */
    margin-top: 8px;
}
.badge-row a { text-decoration: none; }
"""
# Build the Gradio UI: header, input/output images, action buttons, advanced
# settings, cached examples and the event wiring.
# NOTE(review): the nesting below was reconstructed from a paste that had lost
# its indentation. In particular the output image is placed next to (not
# inside) the input column so both images sit side by side - confirm against
# the intended layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static header: logo, title, badge links and a short description.
        gr.HTML(
            """
            <div align='center'>
            <img src="https://dongjinkim9.github.io/projects/idf/assets/idf_logo.png" alt="IDF Logo" width="200" style="display: block; margin: 0 auto;">
            <h1>IDF: <span style="color: red;">I</span>terative <span style="color: red;">D</span>ynamic <span style="color: red;">F</span>iltering Networks <br>for Generalizable Image Denoising</h1>
            <div class="badge-row">
            <a href="https://arxiv.org/abs/2508.19649" target="_blank">
            <img src="https://img.shields.io/badge/Arxiv-📄Paper-8A2BE2" alt="Arxiv" />
            </a>
            <a href="https://dongjinkim9.github.io/projects/idf" target="_blank">
            <img src="https://img.shields.io/badge/Project-📖Page-8A2BE2" alt="Project Page" />
            </a>
            <a href="https://github.com/dongjinkim9/IDF" target="_blank">
            <img src="https://img.shields.io/badge/Github-💻Code-8A2BE2" alt="Github Code" />
            </a>
            </div>
            </div>
            <br>
            <blockquote>
            <p><i>Even though IDF is trained with extremely limited data (e.g., a single-level Gaussian noise),
            it generalizes effectively to diverse unseen noise types and levels with only ~0.04M parameters.</i></p>
            </blockquote>
            """
        )

        # Input on the left, denoised result on the right.
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Noisy Image", show_label=True, type="pil")
            result = gr.Image(label="Denoised Image", show_label=True, type="pil")

        with gr.Row():
            clear_button = gr.Button("Clear", variant="secondary")
            run_button = gr.Button("Denoise!", variant="primary")

        with gr.Accordion("Advanced Settings", open=False):
            adaptive_iter = gr.Checkbox(value=False, label="Dynamic Iteration Control")
            with gr.Row():
                max_iter = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=10,
                    step=1,
                    label="Maximum Denoising Iteration",
                )

        # One example per noise type; each row is
        # [image path, adaptive_iter, max_iter] matching `inputs` below.
        # Outputs are cached lazily, i.e. computed by `infer` on first use.
        gr.Examples(
            examples=[
                ["assets/demo/noisy/DSLR.png", False, 10],
                ["assets/demo/noisy/Smartphone.png", False, 10],
                ["assets/demo/noisy/Salt_Pepper.png", False, 10],
                ["assets/demo/noisy/Gaussian.png", False, 10],
                ["assets/demo/noisy/Spatial_Gaussian.png", False, 10],
                ["assets/demo/noisy/Poisson.png", False, 10],
                ["assets/demo/noisy/Speckle.png", False, 10],
                ["assets/demo/noisy/Mixture.png", False, 10],
            ],
            inputs=[
                input_image,
                adaptive_iter,
                max_iter,
            ],
            outputs=[result],
            fn=infer,
            cache_examples=True,
            cache_mode='lazy',
            label="Examples (Noise Types)",
            example_labels=["📸DSLR","📱Smartphone","🧂Salt & Pepper","🌫️Gaussian","🌀Spatial Gaussian","🎲Poisson","✨Speckle", "🧪Mixture"],
        )

    # Define the event triggers for the run button.
    run_button.click(
        fn=infer,
        inputs=[
            input_image,
            adaptive_iter,
            max_iter,
        ],
        outputs=[result],
    )

    # Define the clear button functionality: reset the input image and the
    # advanced settings to their defaults, and also drop the previous result
    # so no stale output stays on screen.
    clear_button.click(
        fn=lambda: [None, None, False, 10],
        outputs=[input_image, result, adaptive_iter, max_iter],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()