"""
VLM Soft Biometrics - Gradio Interface

A web application for analyzing facial soft biometrics (age, gender, emotion) using Vision-Language Models.
"""
import os
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import base64
from io import BytesIO
import traceback
from huggingface_hub import snapshot_download
from utils.face_detector import FaceDetector

from src.model import MTLModel
from utils.commons import get_backbone_pe
from utils.task_config import Task

TASKS = [
    Task(name='Age', class_labels=["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"], criterion=None),
    Task(name='Gender', class_labels=["Male", "Female"], criterion=None),
    Task(name='Emotion', class_labels=["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"], criterion=None)
]

CLASSES = [
    ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"],
    ["M", "F"],
    ["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"]
]

# Global state shared across Gradio callbacks (populated at startup and on checkpoint change).
model = None
transform = None
detector = None
device = None
current_ckpt_dir = None

CHECKPOINTS_DIR = './checkpoints/'
MODEL_REPO_ID = "Antuke/FaR-FT-PE"

def scan_checkpoints(ckpt_dir):
    """Scan a directory for .pt or .pth checkpoint files.

    Returns a (choices, default) tuple: a list of (display_name, path) pairs
    for the dropdown and the path of the default checkpoint, or ([], None).
    """
    if not os.path.exists(ckpt_dir):
        print(f"Warning: Checkpoint directory not found: {ckpt_dir}")
        return [], None

    try:
        ckpt_files = [
            os.path.join(ckpt_dir, f)
            for f in sorted(os.listdir(ckpt_dir))
            if f.endswith(('.pt', '.pth'))
        ]
    except Exception as e:
        print(f"Error scanning checkpoint directory {ckpt_dir}: {e}")
        return [], None

    choices_list = [(os.path.basename(f), f) for f in ckpt_files]

    # Prefer the MTL-LoRA checkpoint as the default whenever it is present.
    default_ckpt_path = os.path.join(ckpt_dir, 'mtlora.pt')

    if default_ckpt_path in ckpt_files:
        return choices_list, default_ckpt_path
    elif ckpt_files:
        return choices_list, ckpt_files[0]
    else:
        print(f"No checkpoints found in {ckpt_dir}")
        return [], None

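# Illustrative example (hypothetical directory contents): if ./checkpoints/
# holds lora.pt and mtlora.pt, scan_checkpoints('./checkpoints/') returns
#     ([('lora.pt', './checkpoints/lora.pt'),
#       ('mtlora.pt', './checkpoints/mtlora.pt')],
#      './checkpoints/mtlora.pt')
# because mtlora.pt is picked as the default when it exists.
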
def load_model(device, ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"):
    """Load the backbone and multi-task model from a checkpoint."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    backbone, transform, _ = get_backbone_pe(version=pe_vision_config, apply_migration_flag=True, pretrained=False)
    model = MTLModel(
        backbone,
        device=device,
        tasks=TASKS,
        use_lora=True,
        use_deep_head=True,
        use_mtl_lora=('mtlora' in ckpt_dir),
    )
    print(f'Loading weights from {ckpt_dir}')
    model.load_model(filepath=ckpt_dir, map_location=device)
    return model, transform

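# Typical usage (a sketch; the checkpoint path is whatever scan_checkpoints() picked):
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model, transform = load_model(device, ckpt_dir='./checkpoints/mtlora.pt')
#     model.eval()
#     model.to(device)
# init_model() below performs exactly these steps and also builds the face detector.
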
def load_model_and_update_status(model_filepath):
    """Load the selected checkpoint (if it is not already active) and return a status message."""
    global model, current_ckpt_dir

    if model_filepath is None or model_filepath == "":
        return "No checkpoint selected."

    # Skip reloading if the requested checkpoint is already the active one.
    if model is not None and model_filepath == current_ckpt_dir:
        status = f"Model already loaded: {os.path.basename(model_filepath)}"
        print(status)
        return status

    gr.Info(f"Loading model: {os.path.basename(model_filepath)}...")
    try:
        init_model(ckpt_dir=model_filepath, detection_confidence=0.5)
        current_ckpt_dir = model_filepath
        status = f"Successfully loaded: {os.path.basename(model_filepath)}"
        gr.Info("Model loaded successfully!")
        print(status)
        return status
    except Exception as e:
        traceback.print_exc()
        status = f"Failed to load {os.path.basename(model_filepath)}: {e}"
        gr.Info(f"Error: {status}")
        print(f"ERROR: {status}")
        return status

def predict(model, image):
    """Make predictions for age, gender, and emotion."""
    with torch.no_grad():
        outputs = model(image)

    age_logits, gender_logits, emotion_logits = outputs['Age'], outputs['Gender'], outputs['Emotion']

    age_probs = torch.softmax(age_logits, dim=-1)
    gender_probs = torch.softmax(gender_logits, dim=-1)
    emotion_probs = torch.softmax(emotion_logits, dim=-1)

    ages = torch.argmax(age_logits, dim=-1).cpu().tolist()
    genders = torch.argmax(gender_logits, dim=-1).cpu().tolist()
    emotions = torch.argmax(emotion_logits, dim=-1).cpu().tolist()

    results = []
    for i in range(len(ages)):
        age_all_probs = {
            CLASSES[0][j]: float(age_probs[i][j].cpu().detach())
            for j in range(len(CLASSES[0]))
        }
        gender_all_probs = {
            CLASSES[1][j]: float(gender_probs[i][j].cpu().detach())
            for j in range(len(CLASSES[1]))
        }
        emotion_all_probs = {
            CLASSES[2][j]: float(emotion_probs[i][j].cpu().detach())
            for j in range(len(CLASSES[2]))
        }

        results.append({
            'age': {
                'predicted_class': CLASSES[0][ages[i]],
                'predicted_confidence': float(age_probs[i][ages[i]].cpu().detach()),
                'all_probabilities': age_all_probs
            },
            'gender': {
                'predicted_class': CLASSES[1][genders[i]],
                'predicted_confidence': float(gender_probs[i][genders[i]].cpu().detach()),
                'all_probabilities': gender_all_probs
            },
            'emotion': {
                'predicted_class': CLASSES[2][emotions[i]],
                'predicted_confidence': float(emotion_probs[i][emotions[i]].cpu().detach()),
                'all_probabilities': emotion_all_probs
            }
        })

    return results

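# Shape of the value returned by predict(), one dict per face in the batch
# (the numbers below are illustrative only, not real model output):
#     [{'age':     {'predicted_class': '20-29', 'predicted_confidence': 0.81,
#                   'all_probabilities': {'0-2': 0.0, ..., '70+': 0.01}},
#       'gender':  {'predicted_class': 'M', 'predicted_confidence': 0.97, ...},
#       'emotion': {'predicted_class': 'Neutral', 'predicted_confidence': 0.66, ...}}]
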
def get_centroid_weighted_age(probs):
    """Estimate a continuous age as the probability-weighted average of the
    age-group centroids, using the predicted distribution over age bins."""
    probs = list(probs.values())
    centroids = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]
    age = 0
    for i, p in enumerate(probs):
        age += p * centroids[i]
    return age

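# Worked example (illustrative distribution): with 0.7 probability on "20-29"
# (centroid 24.5) and 0.3 on "30-39" (centroid 34.5),
#     0.7 * 24.5 + 0.3 * 34.5 = 27.5
# so get_centroid_weighted_age() returns roughly 27.5 for that face.
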
def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
    """Initialize the global model, transform, and face detector."""
    global model, transform, detector, device

    print(f"\n{'='*60}")
    print(f"INITIALIZING MODEL: {ckpt_dir}")
    print(f"{'='*60}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if not os.path.exists(ckpt_dir):
        error_msg = f"Model weights not found: {ckpt_dir}."
        print(f"ERROR: {error_msg}")
        raise FileNotFoundError(error_msg)

    print(f"Model weights found: {ckpt_dir}")

    model, transform = load_model(ckpt_dir=ckpt_dir, device=device)
    model.eval()
    model.to(device)

    detector = FaceDetector(confidence_threshold=detection_confidence)

    print("✓ Model and detector initialized successfully")
    print(f"{'='*60}\n")

def process_image(image, selected_checkpoint_path):
    """
    Process an uploaded image and return predictions with an annotated image.

    Args:
        image: PIL Image or numpy array
        selected_checkpoint_path: The path from the checkpoint dropdown

    Returns:
        tuple: (annotated_image, results_html)
    """
    if image is None:
        return None, "<p style='color: red;'>Please upload an image</p>"

    # Reload the model if none is loaded or a different checkpoint was selected.
    if model is None or selected_checkpoint_path != current_ckpt_dir:
        print(f"Model mismatch or not loaded. Selected: {selected_checkpoint_path}, Current: {current_ckpt_dir}")
        status = load_model_and_update_status(selected_checkpoint_path)
        if "Failed" in status or "Error" in status:
            return image, f"<p style='color: red;'>Model Error: {status}</p>"

    try:
        # Normalize the input: an OpenCV BGR array for detection, a PIL image for drawing.
        if isinstance(image, Image.Image):
            img_pil = image
        else:
            img_pil = Image.fromarray(image)
        img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

        img_pil_annotated = img_pil.copy()
        draw = ImageDraw.Draw(img_pil_annotated)

        faces = detector.detect(img_cv, pad_rect=True)

        if faces is None or len(faces) == 0:
            return image, "<p style='color: orange;'>No faces detected in the image</p>"

        crops_pil = []
        face_data = []

        for idx, (crop, confidence, bbox) in enumerate(faces):
            crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            crop_pil = Image.fromarray(crop_rgb)
            crops_pil.append(crop_pil)

            # Resized copy is only used for display in the results panel.
            crop_resized = crop_pil.resize((336, 336), Image.Resampling.LANCZOS)

            face_data.append({
                'bbox': bbox,
                'detection_confidence': float(confidence),
                'crop_image': crop_resized
            })

        # Run all crops through the model in a single batched forward pass.
        crop_tensors = [transform(crop_pil) for crop_pil in crops_pil]
        batch_tensor = torch.stack(crop_tensors).to(device)

        predictions = predict(model, batch_tensor)

        for face, pred in zip(face_data, predictions):
            face['predictions'] = pred

        # Draw bounding boxes and prediction labels on the annotated image.
        for idx, face in enumerate(face_data):
            bbox = face['bbox']
            pred = face['predictions']
            x, y, w, h = bbox

            # Scale the label font with the face width, clamped to a sane range.
            font_size_ratio = 0.08
            min_font_size = 12
            max_font_size = 48
            adaptive_font_size = max(min_font_size, min(int(w * font_size_ratio), max_font_size))
            try:
                font = ImageFont.load_default(size=adaptive_font_size)
            except (TypeError, IOError):
                # Older Pillow versions do not accept a size argument.
                font = ImageFont.load_default()

            draw.rectangle([(x, y), (x + w, y + h)], outline="lime", width=2)

            lines_to_draw = []

            age_label = pred['age']['predicted_class']
            age_conf = pred['age']['predicted_confidence']
            lines_to_draw.append(f"Age: {age_label} ({age_conf*100:.0f}%)")

            gen_label = pred['gender']['predicted_class']
            gen_conf = pred['gender']['predicted_confidence']
            lines_to_draw.append(f"Gender: {gen_label} ({gen_conf*100:.0f}%)")

            emo_label = pred['emotion']['predicted_class']
            emo_conf = pred['emotion']['predicted_confidence']
            lines_to_draw.append(f"Emotion: {emo_label} ({emo_conf*100:.0f}%)")

            line_spacing = 10
            total_text_height = 0
            for line in lines_to_draw:
                _left, top, _right, bottom = draw.textbbox((0, 0), line, font=font)
                total_text_height += (bottom - top) + line_spacing

            # Place the labels above the box when there is room, otherwise below it.
            if y - total_text_height > 0:
                text_y = y - line_spacing
                for line in reversed(lines_to_draw):
                    left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="ls")
                    draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
                    draw.text((x, text_y), line, font=font, fill="white", anchor="ls")
                    text_y = top - line_spacing
            else:
                text_y = y + h + line_spacing
                for line in lines_to_draw:
                    left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="lt")
                    draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
                    draw.text((x, text_y), line, font=font, fill="white", anchor="lt")
                    text_y = bottom + line_spacing

        def pil_to_base64(img_pil):
            buffered = BytesIO()
            img_pil.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            return f"data:image/jpeg;base64,{img_str}"

        results_html = f"""
        <style>
            :root {{
                --primary-color: #4f46e5;
                --success-color: #10b981;

                --text-primary: var(--body-text-color);
                --text-secondary: var(--body-text-color-subdued);
                --background-dark: var(--background-fill-primary);
                --background-darker: var(--background-fill-secondary);
                --border-color: var(--border-color-primary);
            }}
            .results-container {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                background: var(--background-darker);
                padding: 20px;
                border-radius: 12px;
                color: var(--text-primary);
            }}
            .results-container h2 {{
                color: var(--text-primary);
                margin-bottom: 20px;
            }}
            .face-count {{
                display: inline-block;
                background: var(--primary-color);
                color: white;
                padding: 4px 12px;
                border-radius: 20px;
                font-size: 0.9rem;
                font-weight: 500;
                margin-left: 8px;
            }}
            .face-card {{
                background: var(--background-dark);
                border-radius: 8px;
                padding: 20px;
                margin-top: 15px;
                border: 1px solid var(--border-color);
                display: flex;
                gap: 20px;
                align-items: flex-start;
            }}
            .face-header {{
                font-size: 1rem;
                font-weight: 600;
                margin-bottom: 20px;
                color: var(--text-primary);
            }}
            .face-image-left {{
                flex-shrink: 0;
                width: 336px;
                height: 336px;
                background: var(--background-darker);
                border-radius: 8px;
                overflow: hidden;
                border: 1px solid var(--border-color);
            }}
            .face-image-left img {{
                width: 100%;
                height: 100%;
                object-fit: cover;
            }}
            .face-predictions-right {{
                flex: 1;
                display: flex;
                flex-direction: column;
                gap: 10px;
            }}
            .predictions-horizontal {{
                display: flex;
                flex-direction: row;
                gap: 30px;
                justify-content: space-between;
            }}
            .prediction-section {{
                flex: 1;
                min-width: 0;
            }}
            .prediction-category-label {{
                font-size: 0.8rem;
                font-weight: 700;
                text-transform: uppercase;
                letter-spacing: 0.5px;
                color: var(--primary-color);
                margin-bottom: 8px;
                border-bottom: 2px solid var(--primary-color);
                padding-bottom: 4px;
            }}
            .probabilities-list {{
                display: flex;
                flex-direction: column;
                gap: 6px;
            }}
            .probability-item {{
                display: grid;
                grid-template-columns: 70px 1fr 55px;
                align-items: center;
                gap: 8px;
                padding: 4px 6px;
                border-radius: 4px;
            }}
            .probability-item.predicted {{
                background: rgba(79, 70, 229, 0.2);
                border-left: 3px solid var(--primary-color);
                padding-left: 8px;
            }}
            .prob-class {{
                font-size: 0.8rem;
                font-weight: 600;
                color: var(--text-primary);
                word-wrap: break-word; /* Ensure long class names wrap */
            }}
            .probability-item.predicted .prob-class {{
                color: var(--primary-color);
                font-weight: 700;
            }}
            .prob-bar-container {{
                height: 6px;
                background: var(--border-color);
                border-radius: 3px;
                overflow: hidden;
            }}
            .prob-bar {{
                height: 100%;
                background: linear-gradient(90deg, var(--primary-color), var(--success-color));
                border-radius: 3px;
                transition: width 0.6s ease;
            }}
            .probability-item.predicted .prob-bar {{
                background: var(--primary-color);
            }}
            .prob-percentage {{
                font-size: 0.75rem;
                font-weight: 500;
                color: var(--text-secondary);
                text-align: right;
            }}
            .probability-item.predicted .prob-percentage {{
                color: var(--primary-color);
                font-weight: 700;
            }}
            @media (max-width: 1200px) {{
                .predictions-horizontal {{
                    flex-direction: column;
                    gap: 15px;
                }}
            }}
            @media (max-width: 900px) {{
                .face-card {{
                    flex-direction: column;
                }}
                .face-image-left {{
                    width: 100%;
                    max-width: 336px;
                    margin: 0 auto;
                }}
                .probability-item {{
                    grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
                }}
                .prob-class {{
                    font-size: 0.75rem;
                }}
            }}
        </style>

        <div class='results-container'>
            <h2 style='margin-top: 0;'>Classification Results <span class='face-count'>{len(face_data)} face(s)</span></h2>
        """

        for idx, face in enumerate(face_data):
            pred = face['predictions']
            face_img_base64 = pil_to_base64(face['crop_image'])
            age = get_centroid_weighted_age(pred['age']['all_probabilities'])
            results_html += f"""
            <div class='face-card'>
                <div class='face-image-left'>
                    <img src='{face_img_base64}' alt='Face {idx+1}'>
                </div>
                <div class='face-predictions-right'>
                    <div class='face-header'>Face {idx+1} - Detection Confidence: {face['detection_confidence']:.1%} - Centroid Age: {int(age)}</div>
                    <div class='predictions-horizontal'>
                        <div class='prediction-section'>
                            <div class='prediction-category-label'>Age</div>
                            <div class='probabilities-list'>
            """
            for age_class in CLASSES[0]:
                prob = pred['age']['all_probabilities'][age_class]
                is_predicted = (age_class == pred['age']['predicted_class'])
                predicted_class = 'predicted' if is_predicted else ''
                results_html += f"""
                                <div class='probability-item {predicted_class}'>
                                    <span class='prob-class'>{age_class}</span>
                                    <div class='prob-bar-container'>
                                        <div class='prob-bar' style='width: {prob*100}%'></div>
                                    </div>
                                    <span class='prob-percentage'>{prob*100:.1f}%</span>
                                </div>
                """
            results_html += """
                            </div>
                        </div>
                        <div class='prediction-section'>
                            <div class='prediction-category-label'>Gender</div>
                            <div class='probabilities-list'>
            """
            for gender_class in CLASSES[1]:
                prob = pred['gender']['all_probabilities'][gender_class]
                is_predicted = (gender_class == pred['gender']['predicted_class'])
                predicted_class = 'predicted' if is_predicted else ''
                results_html += f"""
                                <div class='probability-item {predicted_class}'>
                                    <span class='prob-class'>{gender_class}</span>
                                    <div class='prob-bar-container'>
                                        <div class='prob-bar' style='width: {prob*100}%'></div>
                                    </div>
                                    <span class='prob-percentage'>{prob*100:.1f}%</span>
                                </div>
                """
            results_html += """
                            </div>
                        </div>
                        <div class='prediction-section'>
                            <div class='prediction-category-label'>Emotion</div>
                            <div class='probabilities-list'>
            """
            for emotion_class in CLASSES[2]:
                prob = pred['emotion']['all_probabilities'][emotion_class]
                is_predicted = (emotion_class == pred['emotion']['predicted_class'])
                predicted_class = 'predicted' if is_predicted else ''
                results_html += f"""
                                <div class='probability-item {predicted_class}'>
                                    <span class='prob-class'>{emotion_class}</span>
                                    <div class='prob-bar-container'>
                                        <div class='prob-bar' style='width: {prob*100}%'></div>
                                    </div>
                                    <span class='prob-percentage'>{prob*100:.1f}%</span>
                                </div>
                """
            results_html += """
                            </div>
                        </div>
                    </div>
                </div>
            </div>
            """
        results_html += "</div>"

        return img_pil_annotated, results_html

    except Exception as e:
        traceback.print_exc()
        return image, f"<p style='color: red;'>Error processing image: {str(e)}</p>"

def create_interface(checkpoint_list, default_checkpoint, initial_status):
    """Create and configure the Gradio interface."""

    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .output-html {
        max-height: none !important;
        overflow-y: auto;
    }
    :root {
        --primary-color: #4f46e5;
        --success-color: #10b981;

        --text-primary: var(--body-text-color);
        --text-secondary: var(--body-text-color-subdued);
        --background-dark: var(--background-fill-primary);
        --background-darker: var(--background-fill-secondary);
        --border-color: var(--border-color-primary);
    }
    .results-container {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        background: var(--background-darker);
        padding: 20px;
        border-radius: 12px;
        color: var(--text-primary);
    }
    .results-container h2 {
        color: var(--text-primary);
        margin-bottom: 20px;
    }
    .face-count {
        display: inline-block;
        background: var(--primary-color);
        color: white;
        padding: 4px 12px;
        border-radius: 20px;
        font-size: 0.9rem;
        font-weight: 500;
        margin-left: 8px;
    }
    .face-card {
        background: var(--background-dark);
        border-radius: 8px;
        padding: 20px;
        margin-top: 15px;
        border: 1px solid var(--border-color);
        display: flex;
        gap: 20px;
        align-items: flex-start;
    }
    .face-header {
        font-size: 1rem;
        font-weight: 600;
        margin-bottom: 20px;
        color: var(--text-primary);
    }
    .face-image-left {
        flex-shrink: 0;
        width: 336px;
        height: 336px;
        background: var(--background-darker);
        border-radius: 8px;
        overflow: hidden;
        border: 1px solid var(--border-color);
    }
    .face-image-left img {
        width: 100%;
        height: 100%;
        object-fit: cover;
    }
    .face-predictions-right {
        flex: 1;
        display: flex;
        flex-direction: column;
        gap: 10px;
    }
    .predictions-horizontal {
        display: flex;
        flex-direction: row;
        gap: 30px;
        justify-content: space-between;
    }
    .prediction-section {
        flex: 1;
        min-width: 0;
    }
    .prediction-category-label {
        font-size: 0.8rem;
        font-weight: 700;
        text-transform: uppercase;
        letter-spacing: 0.5px;
        color: var(--primary-color);
        margin-bottom: 8px;
        border-bottom: 2px solid var(--primary-color);
        padding-bottom: 4px;
    }
    .probabilities-list {
        display: flex;
        flex-direction: column;
        gap: 6px;
    }
    .probability-item {
        display: grid;
        grid-template-columns: 70px 1fr 55px;
        align-items: center;
        gap: 8px;
        padding: 4px 6px;
        border-radius: 4px;
    }
    .probability-item.predicted {
        background: rgba(79, 70, 229, 0.2);
        border-left: 3px solid var(--primary-color);
        padding-left: 8px;
    }
    .prob-class {
        font-size: 0.8rem;
        font-weight: 600;
        color: var(--text-primary);
        word-wrap: break-word; /* Ensure long class names wrap */
    }
    .probability-item.predicted .prob-class {
        color: var(--primary-color);
        font-weight: 700;
    }
    .prob-bar-container {
        height: 6px;
        background: var(--border-color);
        border-radius: 3px;
        overflow: hidden;
    }
    .prob-bar {
        height: 100%;
        background: linear-gradient(90deg, var(--primary-color), var(--success-color));
        border-radius: 3px;
        transition: width 0.6s ease;
    }
    .probability-item.predicted .prob-bar {
        background: var(--primary-color);
    }
    .prob-percentage {
        font-size: 0.75rem;
        font-weight: 500;
        color: var(--text-secondary);
        text-align: right;
    }
    .probability-item.predicted .prob-percentage {
        color: var(--primary-color);
        font-weight: 700;
    }
    @media (max-width: 1200px) {
        .predictions-horizontal {
            flex-direction: column;
            gap: 15px;
        }
    }
    @media (max-width: 900px) {
        .face-card {
            flex-direction: column;
        }
        .face-image-left {
            width: 100%;
            max-width: 336px;
            margin: 0 auto;
        }
        .probability-item {
            grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
        }
        .prob-class {
            font-size: 0.75rem;
        }
    }
    """

    with gr.Blocks(css=custom_css, title="Face Classification System", theme=gr.themes.Default()) as demo:

        with gr.Row():
            gr.Markdown("# Face Classification System")

        with gr.Row():
            with gr.Column(scale=3):
                checkpoint_dropdown = gr.Dropdown(
                    label="Select Model Checkpoint",
                    choices=checkpoint_list,
                    value=default_checkpoint,
                )
            with gr.Column(scale=2):
                model_status_text = gr.Textbox(
                    label="Model Status",
                    value=initial_status,
                    interactive=False,
                )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("""
                ### Features
                - **Age Classification**: 9 categories (0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+), plus age estimation via a centroid-weighted average
                - **Gender Classification**: M/F
                - **Emotion Recognition**: 7 categories (Surprise, Fear, Disgust, Happy, Sad, Angry, Neutral)
                - **Automatic Face Detection**: Detects and analyzes multiple faces
                - **Detailed Probability Distributions**: View confidence for all classes
                """)

            with gr.Column(scale=1):
                gr.Markdown("""
                ### Instructions
                1. (Optional) Select a model checkpoint from the dropdown.
                2. Upload an image, capture one from the webcam, or pick an example below.
                3. Click "Classify Image".
                4. View detected faces with age, gender, and emotion predictions below.

                Demo video of the usage of this space: https://youtu.be/V6-9QTf1xaQ
                """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    height=400
                )

            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Annotated Image",
                    type="pil",
                    height=400
                )

        with gr.Row():
            with gr.Column(scale=1):
                analyze_btn = gr.Button(
                    "Classify Image",
                    variant="primary",
                    size="lg"
                )

        # Collect any example images bundled with the repo.
        example_dir = "example"
        example_images = []
        if os.path.exists(example_dir):
            try:
                example_images = [
                    os.path.join(example_dir, f)
                    for f in sorted(os.listdir(example_dir))
                    if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
                ]
            except Exception as e:
                print(f"Error reading example images from {example_dir}: {e}")

        if example_images:
            gr.Markdown("### 📸 Try with example images")
            gr.Examples(
                examples=example_images,
                inputs=input_image,
                cache_examples=False
            )

        with gr.Row():
            with gr.Column(scale=1):
                output_html = gr.HTML(
                    label="Classification Results",
                    elem_classes="output-html"
                )

        # Wire up the interactions.
        analyze_btn.click(
            fn=process_image,
            inputs=[input_image, checkpoint_dropdown],
            outputs=[output_image, output_html]
        )

        checkpoint_dropdown.change(
            fn=load_model_and_update_status,
            inputs=[checkpoint_dropdown],
            outputs=[model_status_text]
        )

    return demo

print("="*60)
|
|
|
print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
|
|
|
print("="*60)
|
|
|
|
|
|
|
|
|
print(f"Downloading model weights from {MODEL_REPO_ID} to {CHECKPOINTS_DIR}...")
|
|
|
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
|
|
|
try:
|
|
|
snapshot_download(
|
|
|
repo_id=MODEL_REPO_ID,
|
|
|
local_dir=CHECKPOINTS_DIR,
|
|
|
allow_patterns=["*.pt", "*.pth"],
|
|
|
local_dir_use_symlinks=False,
|
|
|
)
|
|
|
print("Model download complete.")
|
|
|
except Exception as e:
|
|
|
print(f"CRITICAL: Failed to download models from Hub. {e}")
|
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
|
checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)
|
|
|
|
|
|
if not checkpoint_list:
|
|
|
print(f"CRITICAL: No checkpoints found in {CHECKPOINTS_DIR}. App may not function.")
|
|
|
else:
|
|
|
print(f"Found checkpoints: {len(checkpoint_list)} file(s).")
|
|
|
print(f"Default checkpoint: {default_checkpoint}")
|
|
|
|
|
|
|
|
|
initial_status_msg = "No default model found. Please select one."
|
|
|
if default_checkpoint:
|
|
|
print(f"\nInitializing default model: {default_checkpoint}")
|
|
|
|
|
|
|
|
|
initial_status_msg = load_model_and_update_status(default_checkpoint)
|
|
|
print(initial_status_msg)
|
|
|
else:
|
|
|
print("Warning: No default model to load.")
|
|
|
|
|
|
|
|
|
|
|
|
print("Creating Gradio interface...")
|
|
|
demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg)
|
|
|
print("โ Interface created successfully!")
|
|
|
|
|
|
|
|
|
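# Example invocations (the module filename "app.py" is an assumption; use the
# actual script name in your checkout):
#     python app.py                           # local launch on 0.0.0.0:7860
#     python app.py --port 8080 --share       # public Gradio share link
#     python app.py --ckpt_dir ./checkpoints  # custom checkpoint directory
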
if __name__ == "__main__":
|
|
|
import argparse
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface")
|
|
|
parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/",
|
|
|
help="Path to the checkpoint directory (will be populated from HF Hub)")
|
|
|
parser.add_argument("--detection_confidence", type=float, default=0.5,
|
|
|
help="Confidence threshold for face detection")
|
|
|
parser.add_argument("--port", type=int, default=7860,
|
|
|
help="Port to run the Gradio app")
|
|
|
parser.add_argument("--share", action="store_true",
|
|
|
help="Create a public share link")
|
|
|
parser.add_argument("--server_name", type=str, default="0.0.0.0",
|
|
|
help="Server name/IP to bind to")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
CHECKPOINTS_DIR = args.ckpt_dir
|
|
|
|
|
|
print(f"\nLaunching server on {args.server_name}:{args.port}")
|
|
|
print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}")
|
|
|
print("="*60)
|
|
|
|
|
|
demo.launch(
|
|
|
share=args.share,
|
|
|
server_name=args.server_name,
|
|
|
server_port=args.port,
|
|
|
show_error=True,
|
|
|
) |