# IDF / app.py
# Last commit by dongjin-kim: "Update app.py" (7ea993e, verified)
import gradio as gr
import torch
import spaces
from omegaconf import OmegaConf
import torch
import torchvision.transforms.functional as TF
from idf.models.lit_a_denoising import LitADenoising
# --- Model Loading ---
# The pretrained IDF denoiser is built once at import time; `infer` below reuses
# this single module-level instance for every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
config = OmegaConf.load("configs/models/idfnet.yaml")
model_pth = "pretrained_models/idf_g_15.ckpt"

# The denoiser/loss sub-configs from the YAML are passed through as constructor
# arguments. strict=True requires the checkpoint keys to match the model exactly.
model = LitADenoising.load_from_checkpoint(
    model_pth,
    strict=True,
    map_location=device,
    denoiser_config=config.params.denoiser_config,
    loss_config=config.params.loss_config,
)
model.to(device)

# Inference only: switch to eval mode, then freeze the parameters.
model.eval()
model.freeze()
# --- Main Inference Function ---
@spaces.GPU()
def infer(
    img,
    adaptive_iter,
    max_iter=None,
):
    """Denoise one image with the module-level IDF model.

    Args:
        img: Noisy input as a PIL image (the UI component uses ``type="pil"``),
            or ``None`` when nothing has been provided.
        adaptive_iter: Passed to the model as its second positional argument
            (wired to the "Dynamic Iteration Control" checkbox).
        max_iter: Passed to the model as its third positional argument
            (wired to the "Maximum Denoising Iteration" slider).

    Returns:
        The denoised PIL image, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # to_tensor gives a (C, H, W) float tensor (scaled to [0, 1] for 8-bit
    # images); add a batch axis and move it to the model's device.
    batch = TF.to_tensor(img).unsqueeze(0).to(device)
    denoised = model(batch, adaptive_iter, max_iter)

    # Clip any overshoot so the PIL conversion sees values in [0, 1].
    denoised = denoised.clamp(0.0, 1.0)
    return TF.to_pil_image(denoised.squeeze(0))
# --- Examples and UI Layout ---
# NOTE(review): `examples` is not referenced anywhere in this file; the
# gr.Examples component builds its own inline list of examples.
examples = []
# Custom stylesheet passed to gr.Blocks(css=...):
#   - #col-container centres the main column (elem_id set on the top gr.Column).
#   - .badge-row lays out the arXiv / project / GitHub badges in the header HTML.
#   - NOTE(review): no component in this file sets elem_id="edit_text", so the
#     #edit_text rule appears to be unused here.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#edit_text{
margin-top: -62px !important
}
.badge-row {
display: flex;
justify-content: center;
align-items: center;
gap: 12px;
flex-wrap: nowrap; /* force single line; change to wrap if needed */
margin-top: 8px;
}
.badge-row a { text-decoration: none; }
"""
# Shared defaults for the "Advanced Settings" controls. The slider, the examples
# table and the Clear button all read these, so a reset restores exactly what
# the UI starts with.
DEFAULT_ADAPTIVE_ITER = False
DEFAULT_MAX_ITER = 10

# (image path, gallery label) for each bundled example, one per noise type.
# Keeping each path next to its label stops the examples list and the label
# list from drifting out of sync when entries are added, removed or reordered.
EXAMPLE_IMAGES = [
    ("assets/demo/noisy/DSLR.png", "📸DSLR"),
    ("assets/demo/noisy/Smartphone.png", "📱Smartphone"),
    ("assets/demo/noisy/Salt_Pepper.png", "🧂Salt & Pepper"),
    ("assets/demo/noisy/Gaussian.png", "🌫️Gaussian"),
    ("assets/demo/noisy/Spatial_Gaussian.png", "🌀Spatial Gaussian"),
    ("assets/demo/noisy/Poisson.png", "🎲Poisson"),
    ("assets/demo/noisy/Speckle.png", "✨Speckle"),
    ("assets/demo/noisy/Mixture.png", "🧪Mixture"),
]

# Page layout: header (logo, title, badges, summary), noisy input beside the
# denoised result, Clear / Denoise buttons, collapsible iteration settings and
# an examples gallery. Event handlers are attached inside the Blocks context.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(
            """
<div align='center'>
<img src="https://dongjinkim9.github.io/projects/idf/assets/idf_logo.png" alt="IDF Logo" width="200" style="display: block; margin: 0 auto;">
<h1>IDF: <span style="color: red;">I</span>terative <span style="color: red;">D</span>ynamic <span style="color: red;">F</span>iltering Networks <br>for Generalizable Image Denoising</h1>
<div class="badge-row">
<a href="https://arxiv.org/abs/2508.19649" target="_blank">
<img src="https://img.shields.io/badge/Arxiv-📄Paper-8A2BE2" alt="Arxiv" />
</a>
<a href="https://dongjinkim9.github.io/projects/idf" target="_blank">
<img src="https://img.shields.io/badge/Project-📖Page-8A2BE2" alt="Project Page" />
</a>
<a href="https://github.com/dongjinkim9/IDF" target="_blank">
<img src="https://img.shields.io/badge/Github-💻Code-8A2BE2" alt="Github Code" />
</a>
</div>
</div>
<br>
<blockquote>
<p><i>Even though IDF is trained with extremely limited data (e.g., a single-level Gaussian noise),
it generalizes effectively to diverse unseen noise types and levels with only ~0.04M parameters.</i></p>
</blockquote>
"""
        )
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Noisy Image", show_label=True, type="pil")
            result = gr.Image(label="Denoised Image", show_label=True, type="pil")
        with gr.Row():
            clear_button = gr.Button("Clear", variant="secondary")
            run_button = gr.Button("Denoise!", variant="primary")
        with gr.Accordion("Advanced Settings", open=False):
            adaptive_iter = gr.Checkbox(
                value=DEFAULT_ADAPTIVE_ITER, label="Dynamic Iteration Control"
            )
            with gr.Row():
                max_iter = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=DEFAULT_MAX_ITER,
                    step=1,
                    label="Maximum Denoising Iteration",
                )
        gr.Examples(
            examples=[
                [path, DEFAULT_ADAPTIVE_ITER, DEFAULT_MAX_ITER]
                for path, _ in EXAMPLE_IMAGES
            ],
            inputs=[input_image, adaptive_iter, max_iter],
            outputs=[result],
            fn=infer,
            # Lazy caching: an example's output is computed on first use and
            # then reused, instead of running every example at startup.
            cache_examples=True,
            cache_mode="lazy",
            label="Examples (Noise Types)",
            example_labels=[label for _, label in EXAMPLE_IMAGES],
        )

    # Denoise the current input with the current settings.
    run_button.click(
        fn=infer,
        inputs=[input_image, adaptive_iter, max_iter],
        outputs=[result],
    )

    # Reset the whole form, including the previous result, so that a stale
    # denoised image is not left showing next to an empty input.
    clear_button.click(
        fn=lambda: [None, None, DEFAULT_ADAPTIVE_ITER, DEFAULT_MAX_ITER],
        outputs=[input_image, result, adaptive_iter, max_iter],
    )
# Start the Gradio server only when this file is executed directly; importing
# the module builds `demo` but does not launch it.
if __name__ == "__main__":
    demo.launch()