Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
| from diffusers import DiffusionPipeline | |
| from huggingface_hub import hf_hub_download | |
| import os | |
# Settings
# Whether to load the fine-tuned conv_in weights on top of the base Marigold U-Net.
use_custom_weights = True
# Fetch the checkpoint only when it will actually be used, so turning the flag
# off also skips the network download at startup. None means "no custom weights".
custom_weights_path = (
    hf_hub_download(
        repo_id="focuzz/depth-estimation",
        filename="unet_weights.pth",
    )
    if use_custom_weights
    else None
)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is used only on CUDA; on CPU everything runs in float32.
dtype = torch.float16 if device == "cuda" else torch.float32
# Load the Marigold depth-estimation pipeline through the "marigold_depth_estimation"
# custom pipeline, in the dtype selected above, and move it to the selected device.
pipe = DiffusionPipeline.from_pretrained(
    "prs-eth/marigold-v1-0",
    torch_dtype=dtype,
    custom_pipeline="marigold_depth_estimation",
)
pipe = pipe.to(device)
# Load the fine-tuned weights
if use_custom_weights:
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # checkpoint fetched from the Hub cannot execute arbitrary code on load.
    state_dict = torch.load(custom_weights_path, map_location=device, weights_only=True)
    # The checkpoint may key its tensors relative to the whole pipeline
    # ("unet.conv_in.*") or relative to the U-Net itself ("conv_in.*").
    prefix = "unet.conv_in." if any(k.startswith("unet.conv_in.") for k in state_dict) else "conv_in."
    # Strip only the leading prefix; str.replace would also remove any later occurrence.
    conv_in_dict = {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }
    if not conv_in_dict:
        raise RuntimeError(
            f"No conv_in weights found in {custom_weights_path!r} "
            "(expected keys starting with 'unet.conv_in.' or 'conv_in.')"
        )
    # Only the U-Net input convolution is replaced; every other U-Net parameter
    # keeps the weights loaded with the base pipeline.
    pipe.unet.conv_in.load_state_dict(conv_in_dict)
    print("Загружены дообученные веса conv_in из:", custom_weights_path)
# Overlay text helper
def add_overlay(image: Image.Image, label: str) -> Image.Image:
    """Return a copy of *image* with *label* drawn in white at pixel (10, 10).

    The caller's image is not modified.

    Args:
        image: Source PIL image.
        label: Text to draw in the top-left corner.

    Returns:
        A new PIL image with the label drawn on it.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.load_default()
    except OSError:
        # Font loading failed; with font=None ImageDraw falls back to its own default.
        font = None
    # NOTE(review): on older Pillow releases load_default() is a Latin-1 bitmap
    # font, so non-Latin labels (this file passes Cyrillic ones) may not render
    # or may raise — confirm against the deployed Pillow version.
    draw.text((10, 10), label, fill="white", font=font)
    return annotated
# Gallery generation from the bundled example images.
# (width, height) that every gallery tile — RGB and both depth views — is resized to.
TARGET_SIZE = (768,) * 2
def normalize_depth(depth_np, low=1, high=99):
    """Map a depth array to 8-bit grayscale using robust percentile scaling.

    Values at or below the *low* percentile map to 0 and values at or above
    the *high* percentile map to 255. Values in between are scaled linearly,
    and the uint8 cast truncates them toward zero.

    Args:
        depth_np: Array-like of depth values.
        low: Percentile used as the black point (default 1).
        high: Percentile used as the white point (default 99).

    Returns:
        A new ``np.uint8`` array with the same shape as *depth_np*. If the
        percentile range is empty or not finite (e.g. a constant depth map),
        an all-zero array is returned rather than dividing by zero.
    """
    d = np.asarray(depth_np)
    d_min = np.percentile(d, low)
    d_max = np.percentile(d, high)
    span = d_max - d_min
    # Guard the degenerate case: span == 0 would give 0/0 = NaN, and NaN has no
    # defined uint8 value.
    if not np.isfinite(span) or span <= 0:
        return np.zeros(d.shape, dtype=np.uint8)
    scaled = np.clip((d - d_min) / span, 0, 1)
    return (scaled * 255).astype(np.uint8)
def generate_gallery():
    """Run depth estimation on the bundled example images and build gallery tiles.

    Example files that do not exist are skipped. Every tile is resized to
    TARGET_SIZE and labelled via add_overlay.

    Returns:
        A flat list of PIL images ordered as: all RGB inputs, then all coloured
        depth maps, then all grayscale depth maps. For example, with four
        examples and a 4-column gallery this gives one row per kind.
    """
    example_files = ["example1.jpg", "example2.jpg", "example3.jpg", "example4.jpg"]
    rgbs = []
    depths_gray = []
    depths_color = []
    for path in example_files:
        if not os.path.exists(path):
            continue
        # The context manager closes the file handle. convert() returns an
        # in-memory copy, so `rgb` stays valid after the file is closed.
        with Image.open(path) as src:
            rgb = src.convert("RGB").resize(TARGET_SIZE)
        with torch.no_grad():
            output = pipe(
                rgb,
                denoising_steps=4,
                ensemble_size=5,
                processing_res=768,
                match_input_res=True,
                batch_size=0,
                color_map="Spectral",
                show_progress_bar=False,
            )
        gray_normalized = normalize_depth(output.depth_np)
        depth_gray = Image.fromarray(gray_normalized).convert("RGB").resize(TARGET_SIZE, Image.BILINEAR)
        depth_color = output.depth_colored.resize(TARGET_SIZE, Image.BILINEAR)
        rgbs.append(add_overlay(rgb, "RGB"))
        depths_gray.append(add_overlay(depth_gray, "Глубина (серая)"))
        depths_color.append(add_overlay(depth_color, "Глубина (цветная)"))
    return rgbs + depths_color + depths_gray
# Blocks interface with the example gallery and interactive inference
with gr.Blocks() as demo:
    gr.Markdown("## Генерация карт глубины")
    gr.Markdown(
        "Модель основана на Marigold (ETH), дообучена на indoor-сценах из NYUv2. "
        "Сохраняет способность обрабатывать произвольные изображения благодаря наличию оригинальных U-Net весов."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Загрузите RGB изображение", type="pil")
            denoise = gr.Slider(1, 50, value=4, step=1, label="Шаги денойзинга")
            ensemble = gr.Slider(1, 10, value=5, step=1, label="Размер ансамбля (количество запусков для одной картинки)")
            resolution = gr.Slider(256, 1024, value=768, step=64, label="Разрешение обработки изображений")
            match_res = gr.Checkbox(value=True, label="Сохранять исходное разрешение")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Карта глубины")

    def predict_depth(image, denoising_steps, ensemble_size, processing_res, match_input_res):
        """Run the pipeline on one uploaded image and return the coloured depth map.

        Raises:
            gr.Error: if the button is pressed before an image is uploaded.
        """
        if image is None:
            # Show a readable message in the UI instead of a pipeline traceback.
            raise gr.Error("Сначала загрузите изображение.")
        with torch.no_grad():
            output = pipe(
                image,
                # Slider values can arrive as floats. Step count, ensemble size
                # and resolution are integral quantities, so cast them explicitly.
                denoising_steps=int(denoising_steps),
                ensemble_size=int(ensemble_size),
                processing_res=int(processing_res),
                match_input_res=bool(match_input_res),
                batch_size=0,
                color_map="Spectral",
                show_progress_bar=False,
            )
        return output.depth_colored

    submit_btn = gr.Button("Выполнить предсказание")
    submit_btn.click(
        predict_depth,
        inputs=[input_image, denoise, ensemble, resolution, match_res],
        outputs=output_image,
    )
    gr.Markdown("### Примеры:")
    gallery = gr.Gallery(label="Сравнение RGB и Глубины", columns=4)

    _gallery_cache = {}

    def _load_gallery():
        """Return the example gallery, computing it on first use only."""
        # generate_gallery runs the full diffusion pipeline for every example
        # image. Cache the result per process instead of redoing it on every page load.
        if "images" not in _gallery_cache:
            _gallery_cache["images"] = generate_gallery()
        return _gallery_cache["images"]

    demo.load(fn=_load_gallery, outputs=gallery)

demo.launch(ssr_mode=False)