Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from pydantic import BaseModel | |
| from typing import Optional, Literal | |
| import asyncio | |
| import time | |
| import hashlib | |
| import io | |
| # Import our utilities | |
| import sys | |
| from pathlib import Path | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from config import get_settings | |
| from utils.model_loader import ModelManager | |
| from utils.image_processing import ( | |
| load_image_from_bytes, | |
| load_image_from_base64, | |
| array_to_base64, | |
| depth_to_colormap, | |
| create_side_by_side | |
| ) | |
| from utils.demo_depth import generate_smart_depth | |
# Runtime settings are read once at import time and shared by everything below
settings = get_settings()

# FastAPI application object
app = FastAPI(
    title="Dimensio API",
    description="Add Dimension to Everything - High-performance depth estimation and 3D visualization API",
    version="1.0.0",
)

# CORS: allowed origins come from settings; all methods and headers are accepted
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-wide manager that holds the loaded depth models
model_manager = ModelManager()

# Set to True during startup when no models are available
DEMO_MODE = False
| # Request/Response models | |
class DepthRequest(BaseModel):
    """Request body for depth estimation with custom options."""

    # Input image, base64 encoded
    image: str
    # Which model to run; the small one is the default
    model: Literal["small", "large"] = "small"
    # Kind of output to produce
    output_format: Literal["grayscale", "colormap", "both"] = "colormap"
    # Colormap name applied to colormapped output
    colormap: Literal["inferno", "viridis", "plasma", "turbo"] = "inferno"
class DepthResponse(BaseModel):
    """Response body returned by the depth estimation endpoints."""

    # Resulting depth map image, base64 encoded
    depth_map: str
    # Free-form details about the run (model used, sizes, ...)
    metadata: dict
    # Time spent handling the request, in milliseconds
    processing_time_ms: float
| # Startup/shutdown events | |
async def startup_event():
    """
    Load the depth models on startup and decide whether to run in demo mode.

    Each model file is looked up under ``settings.MODEL_CACHE_DIR`` and, when
    present, loaded into ``model_manager``. Models are loaded independently,
    so a failure while loading one no longer prevents the other from being
    attempted.

    ``DEMO_MODE`` is switched on only when no model is loaded and no load
    error occurred (i.e. the model files are simply absent). If a load error
    occurred and nothing is loaded, the server starts without working depth
    estimation, as before.
    """
    # NOTE(review): no startup-hook registration (e.g. a decorator) is visible
    # in this view of the file - confirm this coroutine is wired to the app.
    global DEMO_MODE

    print(">> Starting Dimensio API...")

    # (manager key, label used in log lines, settings attribute with the file name)
    model_specs = (
        ("small", "Small", "DEPTH_MODEL_SMALL"),  # fast preview
        ("large", "Large", "DEPTH_MODEL_LARGE"),  # high quality
    )

    load_failed = False
    for key, label, filename_setting in model_specs:
        try:
            model_path = Path(settings.MODEL_CACHE_DIR) / getattr(settings, filename_setting)
            if not model_path.exists():
                print(f"[!] {label} model not found: {model_path}")
                continue
            model_manager.load_model(
                key,
                str(model_path),
                use_gpu=settings.USE_GPU,
                use_tensorrt=settings.TRT_OPTIMIZATION,
            )
            print(f"[+] {label} model loaded")
        except Exception as e:
            # Isolate the failure to this model so the next one is still tried
            load_failed = True
            print(f"[X] Error loading {key} model: {e}")

    if model_manager.models:
        return

    if load_failed:
        # A model file exists but could not be loaded: do not silently fall
        # back to synthetic output.
        print("Server will start but depth estimation will not work.")
        return

    DEMO_MODE = True
    print("\n[!] No models loaded - Running in DEMO MODE")
    print("Demo mode uses synthetic depth maps for testing the UI.")
    print("\nTo use real AI models:")
    print("1. Run: python download_models.py")
    print("2. Place ONNX models in models/cache/")
    print("3. Restart the server")
async def shutdown_event():
    """Log that the API is shutting down; there is nothing else to clean up here."""
    # Uses the same product name as the app title and the startup log line
    print(">> Shutting down Dimensio API...")
| # Health check | |
async def root():
    """
    Basic API health check.

    Returns:
        dict with the API name and version (read from the FastAPI app object
        so they cannot drift from the title/version declared there), a fixed
        "online" status, and the names of the currently loaded models.
    """
    return {
        "name": app.title,
        "version": app.version,
        "status": "online",
        "models_loaded": list(model_manager.models),
    }
async def health_check():
    """Report service health: loaded models plus the GPU / TensorRT settings."""
    # Every model currently held by the manager is reported as "loaded"
    model_status = {}
    for model_name in model_manager.models.keys():
        model_status[model_name] = "loaded"

    report = {
        "status": "healthy",
        "models": model_status,
        "gpu_enabled": settings.USE_GPU,
        "tensorrt_enabled": settings.TRT_OPTIMIZATION,
    }
    return report
| # Depth estimation endpoints | |
async def estimate_depth_preview(file: UploadFile = File(...)):
    """
    Fast depth estimation using the small model (preview quality).

    Reads the uploaded file, decodes it with ``load_image_from_bytes`` and
    returns a colormapped depth map as a base64 PNG. When ``DEMO_MODE`` is
    set, a synthetic depth map from ``generate_smart_depth`` is used instead
    of a model. Intended for speed (quoted by the author as ~50-100ms on GPU).

    Args:
        file: Uploaded image file.

    Returns:
        DepthResponse with the base64 depth map, metadata (model name,
        input/output sizes, demo flag) and processing time in milliseconds.

    Raises:
        HTTPException: 503 when the small model is not loaded; 500 for any
            other failure while handling the request.
    """
    try:
        start_time = time.perf_counter()

        # Load image
        image_bytes = await file.read()
        image = load_image_from_bytes(image_bytes)

        if DEMO_MODE:
            # Synthetic depth, no model involved
            depth = generate_smart_depth(image)
            model_name = "demo"
        else:
            model = model_manager.get_model("small")
            if model is None:
                raise HTTPException(
                    status_code=503,
                    detail="Small model not loaded. Please check server logs."
                )
            # NOTE: predict() is called synchronously inside this coroutine
            depth = model.predict(image)
            model_name = "small"

        # Colorize and encode the result
        depth_colored = depth_to_colormap(depth)
        depth_base64 = array_to_base64(depth_colored, format='PNG')

        processing_time = (time.perf_counter() - start_time) * 1000
        return DepthResponse(
            depth_map=depth_base64,
            metadata={
                "model": model_name,
                "input_size": image.shape[:2],
                "output_size": depth.shape[:2],
                "demo_mode": DEMO_MODE
            },
            processing_time_ms=round(processing_time, 2)
        )
    except HTTPException:
        # Deliberate HTTP errors (the 503 above) must keep their status code
        # instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        # ASCII marker, matching the "[X]" style of the startup log
        print(f"[X] Error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
async def estimate_depth_hq(file: UploadFile = File(...)):
    """
    High-quality depth estimation using the large model.

    Reads the uploaded file, decodes it with ``load_image_from_bytes`` and
    returns a colormapped depth map as a base64 PNG. If the large model is not
    loaded the small model is used as a fallback. When ``DEMO_MODE`` is set, a
    synthetic depth map from ``generate_smart_depth`` is used instead. Slower
    than the preview endpoint (quoted by the author as ~500-1500ms on GPU).

    Args:
        file: Uploaded image file.

    Returns:
        DepthResponse with the base64 depth map, metadata (model name,
        input/output sizes, demo flag) and processing time in milliseconds.

    Raises:
        HTTPException: 503 when neither model is loaded; 500 for any other
            failure while handling the request.
    """
    try:
        start_time = time.perf_counter()

        # Load image
        image_bytes = await file.read()
        image = load_image_from_bytes(image_bytes)

        if DEMO_MODE:
            # Synthetic depth, no model involved
            depth = generate_smart_depth(image)
            model_name = "demo (HQ)"
        else:
            model = model_manager.get_model("large")
            model_name = "large"
            if model is None:
                # Fall back to the small model when the large one is missing
                model = model_manager.get_model("small")
                model_name = "small (fallback)"
            if model is None:
                raise HTTPException(
                    status_code=503,
                    detail="No models loaded. Please check server logs."
                )
            # NOTE: predict() is called synchronously inside this coroutine
            depth = model.predict(image)

        # Colorize and encode the result
        depth_colored = depth_to_colormap(depth)
        depth_base64 = array_to_base64(depth_colored, format='PNG')

        processing_time = (time.perf_counter() - start_time) * 1000
        return DepthResponse(
            depth_map=depth_base64,
            metadata={
                "model": model_name,
                "input_size": image.shape[:2],
                "output_size": depth.shape[:2],
                "demo_mode": DEMO_MODE
            },
            processing_time_ms=round(processing_time, 2)
        )
    except HTTPException:
        # Deliberate HTTP errors (the 503 above) must keep their status code
        # instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        # ASCII marker, matching the "[X]" style of the startup log
        print(f"[X] Error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
async def estimate_depth(request: DepthRequest):
    """
    Depth estimation with custom options, for a base64 encoded image.

    The output depends on ``request.output_format``:
      * "grayscale": the depth array scaled by 255 and cast to uint8
      * "colormap":  the depth array colorized with ``request.colormap``
      * "both":      the result of ``create_side_by_side(image, depth)``

    Like the file-upload endpoints, a synthetic depth map from
    ``generate_smart_depth`` is used when ``DEMO_MODE`` is set.

    Args:
        request: Image and output options.

    Returns:
        DepthResponse with the base64 PNG, metadata and processing time in
        milliseconds.

    Raises:
        HTTPException: 503 when the requested model is not loaded; 500 for
            any other failure while handling the request.
    """
    try:
        start_time = time.perf_counter()

        # Load image from base64
        image = load_image_from_base64(request.image)

        if DEMO_MODE:
            # No models exist in demo mode; behave like the sibling endpoints
            depth = generate_smart_depth(image)
            model_name = "demo"
        else:
            model = model_manager.get_model(request.model)
            if model is None:
                raise HTTPException(
                    status_code=503,
                    detail=f"Model '{request.model}' not loaded"
                )
            # NOTE: predict() is called synchronously inside this coroutine
            depth = model.predict(image)
            model_name = request.model

        # Process output based on format
        if request.output_format == "grayscale":
            # presumably depth is normalized to [0, 1] - TODO confirm
            output = (depth * 255).astype('uint8')
            depth_base64 = array_to_base64(output, format='PNG')
        elif request.output_format == "colormap":
            import cv2
            colormap_dict = {
                "inferno": cv2.COLORMAP_INFERNO,
                "viridis": cv2.COLORMAP_VIRIDIS,
                "plasma": cv2.COLORMAP_PLASMA,
                "turbo": cv2.COLORMAP_TURBO
            }
            depth_colored = depth_to_colormap(depth, colormap_dict[request.colormap])
            depth_base64 = array_to_base64(depth_colored, format='PNG')
        else:  # both
            # NOTE(review): request.colormap is not forwarded here - confirm
            # whether create_side_by_side can accept a specific colormap.
            side_by_side = create_side_by_side(image, depth, colormap=True)
            depth_base64 = array_to_base64(side_by_side, format='PNG')

        processing_time = (time.perf_counter() - start_time) * 1000
        return DepthResponse(
            depth_map=depth_base64,
            metadata={
                "model": model_name,
                "output_format": request.output_format,
                "colormap": request.colormap,
                "input_size": image.shape[:2],
                "output_size": depth.shape[:2],
                "demo_mode": DEMO_MODE
            },
            processing_time_ms=round(processing_time, 2)
        )
    except HTTPException:
        # Deliberate HTTP errors (the 503 above) must keep their status code
        # instead of being rewrapped as a generic 500 below.
        raise
    except Exception as e:
        # ASCII marker, matching the "[X]" style of the startup log
        print(f"[X] Error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
| # WebSocket for streaming | |
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time depth estimation over a stream of images.

    Each incoming JSON message with ``"action": "estimate"`` is expected to
    carry a base64 ``"image"`` and an optional ``"model"`` (default "small").
    For every such message a progress message is sent, followed by a
    ``"complete"`` message with the colormapped depth map (base64 PNG) and the
    processing time. Messages with any other action are ignored.

    A missing image or an unloaded model produces an ``{"error": ...}`` reply
    and the stream continues. Any other exception is reported to the client
    (when the socket is still usable) and ends the handler. When ``DEMO_MODE``
    is set, ``generate_smart_depth`` is used instead of a model.

    Args:
        websocket: The client connection.
    """
    await websocket.accept()
    try:
        while True:
            # Receive image data
            data = await websocket.receive_json()
            if data.get("action") != "estimate":
                continue

            start_time = time.perf_counter()

            image_b64 = data.get("image")
            if not image_b64:
                # Report the problem explicitly instead of failing on a KeyError
                await websocket.send_json({
                    "error": "Missing 'image' in estimate request"
                })
                continue

            # Load image
            image = load_image_from_base64(image_b64)

            model = None
            if not DEMO_MODE:
                model_name = data.get("model", "small")
                model = model_manager.get_model(model_name)
                if model is None:
                    await websocket.send_json({
                        "error": f"Model '{model_name}' not loaded"
                    })
                    continue

            # Send progress update
            await websocket.send_json({
                "status": "processing",
                "progress": 50
            })

            if DEMO_MODE:
                # No models exist in demo mode; use synthetic depth like the
                # HTTP endpoints do.
                depth = generate_smart_depth(image)
            else:
                # NOTE: predict() is called synchronously inside this coroutine
                depth = model.predict(image)

            # Convert to colormap
            depth_colored = depth_to_colormap(depth)
            depth_base64 = array_to_base64(depth_colored, format='PNG')
            processing_time = (time.perf_counter() - start_time) * 1000

            # Send result
            await websocket.send_json({
                "status": "complete",
                "depth_map": depth_base64,
                "processing_time_ms": round(processing_time, 2)
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {e}")
        try:
            await websocket.send_json({"error": str(e)})
        except (WebSocketDisconnect, RuntimeError):
            # The socket is already closed; there is nobody left to notify
            pass
if __name__ == "__main__":
    import uvicorn

    # Run the server when this file is executed directly; host, port and
    # auto-reload are all taken from settings.
    server_options = {
        "host": settings.HOST,
        "port": settings.PORT,
        "reload": settings.DEBUG,
    }
    uvicorn.run("main:app", **server_options)