""" VLM Soft Biometrics - Gradio Interface A web application for analyzing facial soft biometrics (age, gender, emotion) using Vision-Language Models. """ import os import gradio as gr import torch import cv2 import numpy as np from PIL import Image, ImageDraw, ImageFont import base64 from io import BytesIO import traceback from huggingface_hub import snapshot_download from utils.face_detector import FaceDetector # Class definitions from src.model import MTLModel from utils.commons import get_backbone_pe from utils.task_config import Task TASKS = [ Task(name='Age', class_labels=["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"], criterion=None), Task(name='Gender', class_labels=["Male", "Female"], criterion=None), Task(name='Emotion', class_labels=["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"], criterion=None) ] CLASSES = [ ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"], ["M", "F"], ["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"] ] # Global variables for model and detector model = None transform = None detector = None device = None current_ckpt_dir = None CHECKPOINTS_DIR = './checkpoints/' MODEL_REPO_ID = "Antuke/FaR-FT-PE" def scan_checkpoints(ckpt_dir): """Scans a directory for .pt or .pth files.""" if not os.path.exists(ckpt_dir): print(f"Warning: Checkpoint directory not found: {ckpt_dir}") return [], None try: ckpt_files = [ os.path.join(ckpt_dir, f) for f in sorted(os.listdir(ckpt_dir)) if f.endswith(('.pt', '.pth')) ] except Exception as e: print(f"Error scanning checkpoint directory {ckpt_dir}: {e}") return [], None choices_list = [(os.path.basename(f), f) for f in ckpt_files] default_ckpt_path = os.path.join(ckpt_dir, 'mtlora.pt') if default_ckpt_path in ckpt_files: return choices_list, default_ckpt_path elif ckpt_files: return choices_list, ckpt_files[0] else: print(f"No checkpoints found in {ckpt_dir}") return [], None def load_model(device,ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"): """Load and configure model.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") backbone, transform, _ = get_backbone_pe(version='PE-Core-L14-336', apply_migration_flag=True, pretrained=False) model = MTLModel(backbone,device=device,tasks=TASKS,use_lora=True,use_deep_head=True, use_mtl_lora=('mtlora' in ckpt_dir), ) print(f'loading from {ckpt_dir}') model.load_model(filepath=ckpt_dir,map_location=device) return model,transform def load_model_and_update_status(model_filepath): """Wrapper function to load a model """ global model, current_ckpt_dir if model_filepath is None or model_filepath == "": return "No checkpoint selected." 
def load_model_and_update_status(model_filepath):
    """Load the selected checkpoint (unless already loaded) and return a status string."""
    global model, current_ckpt_dir

    if model_filepath is None or model_filepath == "":
        return "No checkpoint selected."

    # Skip reloading if this checkpoint is already the active one
    if model is not None and model_filepath == current_ckpt_dir:
        status = f"Model already loaded: {os.path.basename(model_filepath)}"
        print(status)
        return status

    gr.Info(f"Loading model: {os.path.basename(model_filepath)}...")
    try:
        init_model(ckpt_dir=model_filepath, detection_confidence=0.5)
        current_ckpt_dir = model_filepath  # Set global path on successful load
        status = f"Successfully loaded: {os.path.basename(model_filepath)}"
        gr.Info("Model loaded successfully!")
        print(status)
        return status
    except Exception as e:
        traceback.print_exc()
        status = f"Failed to load {os.path.basename(model_filepath)}: {e}"
        gr.Info(f"Error: {status}")
        print(f"ERROR: {status}")
        return status


def predict(model, image):
    """Run the model on a batch of face crops and return per-face predictions
    for age, gender, and emotion."""
    with torch.no_grad():
        outputs = model(image)
        age_logits, gender_logits, emotion_logits = outputs['Age'], outputs['Gender'], outputs['Emotion']

        # Convert logits to probabilities with softmax
        age_probs = torch.softmax(age_logits, dim=-1)
        gender_probs = torch.softmax(gender_logits, dim=-1)
        emotion_probs = torch.softmax(emotion_logits, dim=-1)

        ages = torch.argmax(age_logits, dim=-1).cpu().tolist()
        genders = torch.argmax(gender_logits, dim=-1).cpu().tolist()
        emotions = torch.argmax(emotion_logits, dim=-1).cpu().tolist()

        results = []
        for i in range(len(ages)):
            # Full probability distribution for each task
            age_all_probs = {
                CLASSES[0][j]: float(age_probs[i][j]) for j in range(len(CLASSES[0]))
            }
            gender_all_probs = {
                CLASSES[1][j]: float(gender_probs[i][j]) for j in range(len(CLASSES[1]))
            }
            emotion_all_probs = {
                CLASSES[2][j]: float(emotion_probs[i][j]) for j in range(len(CLASSES[2]))
            }

            results.append({
                'age': {
                    'predicted_class': CLASSES[0][ages[i]],
                    'predicted_confidence': float(age_probs[i][ages[i]]),
                    'all_probabilities': age_all_probs
                },
                'gender': {
                    'predicted_class': CLASSES[1][genders[i]],
                    'predicted_confidence': float(gender_probs[i][genders[i]]),
                    'all_probabilities': gender_all_probs
                },
                'emotion': {
                    'predicted_class': CLASSES[2][emotions[i]],
                    'predicted_confidence': float(emotion_probs[i][emotions[i]]),
                    'all_probabilities': emotion_all_probs
                }
            })

    return results


def get_centroid_weighted_age(probs):
    """Estimate a scalar age as the expectation of the age-group centroids
    under the predicted probability distribution."""
    probs = list(probs.values())
    centroids = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]
    age = 0.0
    for i, p in enumerate(probs):
        age += p * centroids[i]
    return age


def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
    """Initialize the model and the face detector."""
    global model, transform, detector, device

    print(f"\n{'='*60}")
    print(f"INITIALIZING MODEL: {ckpt_dir}")
    print(f"{'='*60}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if not os.path.exists(ckpt_dir):
        error_msg = f"Model weights not found: {ckpt_dir}."
        print(f"ERROR: {error_msg}")
        raise FileNotFoundError(error_msg)
    print(f"Model weights found: {ckpt_dir}")

    # Load the perception encoder backbone and task heads
    model, transform = load_model(ckpt_dir=ckpt_dir, device=device)
    model.eval()
    model.to(device)

    detector = FaceDetector(confidence_threshold=detection_confidence)
    print("✓ Model and detector initialized successfully")
    print(f"{'='*60}\n")
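
# Example (sketch): the weighted-centroid age estimate for a toy distribution.
# The probability values below are illustrative, not real model output:
#
#   toy = {"0-2": 0.0, "3-9": 0.0, "10-19": 0.1, "20-29": 0.6, "30-39": 0.3,
#          "40-49": 0.0, "50-59": 0.0, "60-69": 0.0, "70+": 0.0}
#   get_centroid_weighted_age(toy)  # 0.1*14.5 + 0.6*24.5 + 0.3*34.5 = 26.5
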
def process_image(image, selected_checkpoint_path):
    """
    Process an uploaded image and return predictions with an annotated image.

    Args:
        image: PIL Image or numpy array
        selected_checkpoint_path: The path from the checkpoint dropdown

    Returns:
        tuple: (annotated_image, results_html)
    """
    if image is None:
        return None, "<p>Please upload an image.</p>"

    # Ensure the selected model is loaded (handles first call and checkpoint switches)
    if model is None or selected_checkpoint_path != current_ckpt_dir:
        print(f"Model mismatch or not loaded. Selected: {selected_checkpoint_path}, Current: {current_ckpt_dir}")
        status = load_model_and_update_status(selected_checkpoint_path)
        if "Failed" in status or "Error" in status:
            return image, f"<p>Model Error: {status}</p>"

    try:
        # Convert PIL to OpenCV format (BGR) for the detector
        if isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            img_cv = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        # Create a PIL copy to draw annotations on
        img_pil_annotated = image.copy()
        draw = ImageDraw.Draw(img_pil_annotated)

        faces = detector.detect(img_cv, pad_rect=True)
        if faces is None or len(faces) == 0:
            return image, "<p>No faces detected in the image.</p>"

        # --- Process detected faces ---
        crops_pil = []
        face_data = []
        for idx, (crop, confidence, bbox) in enumerate(faces):
            crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            crop_pil = Image.fromarray(crop_rgb)
            crops_pil.append(crop_pil)

            # Resize crop to 336x336 for display, to match the model input size
            crop_resized = crop_pil.resize((336, 336), Image.Resampling.LANCZOS)

            face_data.append({
                'bbox': bbox,
                'detection_confidence': float(confidence),
                'crop_image': crop_resized
            })

        # --- Batch transform and predict ---
        crop_tensors = [transform(crop_pil) for crop_pil in crops_pil]
        batch_tensor = torch.stack(crop_tensors).to(device)
        predictions = predict(model, batch_tensor)

        # Combine face data with predictions
        for face, pred in zip(face_data, predictions):
            face['predictions'] = pred

        # --- Create annotated image (using PIL) ---
        for idx, face in enumerate(face_data):
            bbox = face['bbox']
            pred = face['predictions']
            x, y, w, h = bbox

            # --- Calculate an adaptive font size from the box width ---
            font_size_ratio = 0.08
            min_font_size = 12
            max_font_size = 48
            adaptive_font_size = max(min_font_size, min(int(w * font_size_ratio), max_font_size))
            try:
                font = ImageFont.load_default(size=adaptive_font_size)
            except IOError:
                font = ImageFont.load_default()

            # --- Draw bounding box ---
            draw.rectangle([(x, y), (x + w, y + h)], outline="lime", width=2)

            # --- Prepare text lines ---
            lines_to_draw = []

            # Age
            age_label = pred['age']['predicted_class']
            age_conf = pred['age']['predicted_confidence']
            lines_to_draw.append(f"Age: {age_label} ({age_conf*100:.0f}%)")

            # Gender
            gen_label = pred['gender']['predicted_class']
            gen_conf = pred['gender']['predicted_confidence']
            lines_to_draw.append(f"Gender: {gen_label} ({gen_conf*100:.0f}%)")

            # Emotion
            emo_label = pred['emotion']['predicted_class']
            emo_conf = pred['emotion']['predicted_confidence']
            lines_to_draw.append(f"Emotion: {emo_label} ({emo_conf*100:.0f}%)")

            # --- Calculate the total height of the text block ---
            line_spacing = 10
            total_text_height = 0
            for line in lines_to_draw:
                _left, top, _right, bottom = draw.textbbox((0, 0), line, font=font)
                total_text_height += (bottom - top) + line_spacing

            # --- Place text ABOVE the box if there is room, otherwise BELOW ---
            if y - total_text_height > 0:
                text_y = y - line_spacing
                for line in reversed(lines_to_draw):
                    # anchor="ls": left-baseline
                    left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="ls")
                    draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
                    draw.text((x, text_y), line, font=font, fill="white", anchor="ls")
                    text_y = top - line_spacing  # Move the y-position up for the next line
            else:
                text_y = y + h + line_spacing
                for line in lines_to_draw:
                    # anchor="lt": left-top
                    left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="lt")
                    draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
                    draw.text((x, text_y), line, font=font, fill="white", anchor="lt")
                    text_y = bottom + line_spacing

        # Helper to embed a PIL image in the results HTML as a base64 data URI
        def pil_to_base64(img_pil):
            buffered = BytesIO()
            img_pil.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            return f"data:image/jpeg;base64,{img_str}"

        # --- Build the results HTML (one card per detected face) ---
        # Minimal reconstruction of the results markup, using only the CSS
        # classes defined in create_interface().
        face_cards = []
        for idx, face in enumerate(face_data):
            pred = face['predictions']
            sections = []
            for task_key, task_title in (('age', 'Age'), ('gender', 'Gender'), ('emotion', 'Emotion')):
                rows = []
                for cls, prob in pred[task_key]['all_probabilities'].items():
                    predicted = " predicted" if cls == pred[task_key]['predicted_class'] else ""
                    rows.append(
                        f"<div class='probability-item{predicted}'>"
                        f"<span class='prob-class'>{cls}</span>"
                        f"<div class='prob-bar-container'><div class='prob-bar' style='width: {prob*100:.1f}%;'></div></div>"
                        f"<span class='prob-percentage'>{prob*100:.1f}%</span>"
                        f"</div>"
                    )
                sections.append(
                    f"<div class='prediction-section'>"
                    f"<div class='prediction-category-label'>{task_title}</div>"
                    f"<div class='probabilities-list'>{''.join(rows)}</div>"
                    f"</div>"
                )

            estimated_age = get_centroid_weighted_age(pred['age']['all_probabilities'])
            face_cards.append(
                f"<div class='face-card'>"
                f"<div class='face-image-left'><img src='{pil_to_base64(face['crop_image'])}' alt='Face {idx + 1}'/></div>"
                f"<div class='face-predictions-right'>"
                f"<div class='face-header'>Face {idx + 1} &middot; Detection confidence: {face['detection_confidence']*100:.1f}% "
                f"&middot; Estimated age: {estimated_age:.1f}</div>"
                f"<div class='predictions-horizontal'>{''.join(sections)}</div>"
                f"</div>"
                f"</div>"
            )

        results_html = (
            f"<div class='results-container'>"
            f"<h2>Classification Results <span class='face-count'>{len(face_data)} face(s)</span></h2>"
            f"{''.join(face_cards)}"
            f"</div>"
        )

        return img_pil_annotated, results_html

    except Exception as e:
        traceback.print_exc()
        return image, f"<p>Error processing image: {str(e)}</p>"
" def create_interface(checkpoint_list, default_checkpoint, initial_status): """Create and configure the Gradio interface.""" custom_css = """ .gradio-container { font-family: 'Arial', sans-serif; } .output-html { max-height: none !important; overflow-y: auto; } :root { --primary-color: #4f46e5; --success-color: #10b981; --text-primary: var(--body-text-color); --text-secondary: var(--body-text-color-subdued); --background-dark: var(--background-fill-primary); --background-darker: var(--background-fill-secondary); --border-color: var(--border-color-primary); } .results-container { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: var(--background-darker); padding: 20px; border-radius: 12px; color: var(--text-primary); } .results-container h2 { color: var(--text-primary); margin-bottom: 20px; } .face-count { display: inline-block; background: var(--primary-color); color: white; padding: 4px 12px; border-radius: 20px; font-size: 0.9rem; font-weight: 500; margin-left: 8px; } .face-card { background: var(--background-dark); border-radius: 8px; padding: 20px; margin-top: 15px; border: 1px solid var(--border-color); display: flex; gap: 20px; align-items: flex-start; } .face-header { font-size: 1rem; font-weight: 600; margin-bottom: 20px; color: var(--text-primary); } .face-image-left { flex-shrink: 0; width: 336px; height: 336px; background: var(--background-darker); border-radius: 8px; overflow: hidden; border: 1px solid var(--border-color); } .face-image-left img { width: 100%; height: 100%; object-fit: cover; } .face-predictions-right { flex: 1; display: flex; flex-direction: column; gap: 10px; } .predictions-horizontal { display: flex; flex-direction: row; gap: 30px; justify-content: space-between; } .prediction-section { flex: 1; min-width: 0; } .prediction-category-label { font-size: 0.8rem; font-weight: 700; text-transform: uppercase; letter-spacing: 0.5px; color: var(--primary-color); margin-bottom: 8px; border-bottom: 2px solid var(--primary-color); padding-bottom: 4px; } .probabilities-list { display: flex; flex-direction: column; gap: 6px; } .probability-item { display: grid; grid-template-columns: 70px 1fr 55px; align-items: center; gap: 8px; padding: 4px 6px; border-radius: 4px; } .probability-item.predicted { background: rgba(79, 70, 229, 0.2); border-left: 3px solid var(--primary-color); padding-left: 8px; } .prob-class { font-size: 0.8rem; font-weight: 600; color: var(--text-primary); word-wrap: break-word; /* Ensure long class names wrap */ } .probability-item.predicted .prob-class { color: var(--primary-color); font-weight: 700; } .prob-bar-container { height: 6px; background: var(--border-color); border-radius: 3px; overflow: hidden; } .prob-bar { height: 100%; background: linear-gradient(90deg, var(--primary-color), var(--success-color)); border-radius: 3px; transition: width 0.6s ease; } .probability-item.predicted .prob-bar { background: var(--primary-color); } .prob-percentage { font-size: 0.75rem; font-weight: 500; color: var(--text-secondary); text-align: right; } .probability-item.predicted .prob-percentage { color: var(--primary-color); font-weight: 700; } @media (max-width: 1200px) { .predictions-horizontal { flex-direction: column; gap: 15px; } } @media (max-width: 900px) { .face-card { flex-direction: column; } .face-image-left { width: 100%; max-width: 336px; margin: 0 auto; } .probability-item { grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */ } .prob-class { font-size: 0.75rem; } } """ # Create 
    with gr.Blocks(css=custom_css, title="Face Classification System", theme=gr.themes.Default()) as demo:
        with gr.Row():
            gr.Markdown("# Face Classification System")

        # --- Model selection ---
        with gr.Row():
            with gr.Column(scale=3):
                checkpoint_dropdown = gr.Dropdown(
                    label="Select Model Checkpoint",
                    choices=checkpoint_list,
                    value=default_checkpoint,
                )
            with gr.Column(scale=2):
                model_status_text = gr.Textbox(
                    label="Model Status",
                    value=initial_status,
                    interactive=False,
                )

        # Features | Instructions
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("""
                ### Features
                - **Age Classification**: 9 categories (0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+), plus age estimation via a weighted centroid average
                - **Gender Classification**: M/F
                - **Emotion Recognition**: 7 categories (Surprise, Fear, Disgust, Happy, Sad, Angry, Neutral)
                - **Automatic Face Detection**: Detects and analyzes multiple faces
                - **Detailed Probability Distributions**: View confidence for all classes
                """)
            with gr.Column(scale=1):
                gr.Markdown("""
                ### Instructions
                1. (Optional) Select a model checkpoint from the dropdown.
                2. Upload an image, capture one from the webcam, or pick an example below.
                3. Click "Classify Image".
                4. View detected faces with age, gender, and emotion predictions below.

                Demo video of this Space: https://youtu.be/V6-9QTf1xaQ
                """)

        # Upload Image | Annotated Image
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    height=400
                )
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Annotated Image",
                    type="pil",
                    height=400
                )

        with gr.Row():
            with gr.Column(scale=1):
                analyze_btn = gr.Button(
                    "Classify Image",
                    variant="primary",
                    size="lg"
                )

        # Dynamically load example images from the example directory
        example_dir = "example"
        example_images = []
        if os.path.exists(example_dir):
            try:
                example_images = [
                    os.path.join(example_dir, f)
                    for f in sorted(os.listdir(example_dir))
                    if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
                ]
            except Exception as e:
                print(f"Error reading example images from {example_dir}: {e}")

        if example_images:
            gr.Markdown("### 📸 Try with example images")
            gr.Examples(
                examples=example_images,
                inputs=input_image,
                cache_examples=False
            )

        # Results section
        with gr.Row():
            with gr.Column(scale=1):
                output_html = gr.HTML(
                    label="Classification Results",
                    elem_classes="output-html"
                )

        # Event handlers
        analyze_btn.click(
            fn=process_image,
            inputs=[input_image, checkpoint_dropdown],
            outputs=[output_image, output_html]
        )

        checkpoint_dropdown.change(
            fn=load_model_and_update_status,
            inputs=[checkpoint_dropdown],
            outputs=[model_status_text]
        )

    return demo


# Application startup
print("=" * 60)
print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
print("=" * 60)

# --- Model download from the HF Hub repo ---
print(f"Downloading model weights from {MODEL_REPO_ID} to {CHECKPOINTS_DIR}...")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
try:
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=CHECKPOINTS_DIR,
        allow_patterns=["*.pt", "*.pth"],  # Grab all weight files
        local_dir_use_symlinks=False,
    )
    print("Model download complete.")
except Exception as e:
    print(f"CRITICAL: Failed to download models from Hub. {e}")
    traceback.print_exc()

checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)

if not checkpoint_list:
    print(f"CRITICAL: No checkpoints found in {CHECKPOINTS_DIR}. App may not function.")
App may not function.") else: print(f"Found checkpoints: {len(checkpoint_list)} file(s).") print(f"Default checkpoint: {default_checkpoint}") # --- Try to initialize default model --- initial_status_msg = "No default model found. Please select one." if default_checkpoint: print(f"\nInitializing default model: {default_checkpoint}") # This will load the model AND set current_ckpt_dir # It now correctly uses the local file path initial_status_msg = load_model_and_update_status(default_checkpoint) print(initial_status_msg) else: print("Warning: No default model to load.") # --- Create interface --- print("Creating Gradio interface...") demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg) print("✓ Interface created successfully!") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface") parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/", help="Path to the checkpoint directory (will be populated from HF Hub)") parser.add_argument("--detection_confidence", type=float, default=0.5, help="Confidence threshold for face detection") parser.add_argument("--port", type=int, default=7860, help="Port to run the Gradio app") parser.add_argument("--share", action="store_true", help="Create a public share link") parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server name/IP to bind to") args = parser.parse_args() CHECKPOINTS_DIR = args.ckpt_dir print(f"\nLaunching server on {args.server_name}:{args.port}") print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}") print("="*60) demo.launch( share=args.share, server_name=args.server_name, server_port=args.port, show_error=True, )