"""
VLM Soft Biometrics - Gradio Interface
A web application for analyzing facial soft biometrics (age, gender, emotion) using Vision-Language Models.
"""
import os
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import base64
from io import BytesIO
import traceback
from huggingface_hub import snapshot_download
from utils.face_detector import FaceDetector
# Class definitions
from src.model import MTLModel
from utils.commons import get_backbone_pe
from utils.task_config import Task
TASKS = [
Task(name='Age', class_labels=["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"], criterion=None),
Task(name='Gender', class_labels=["Male", "Female"], criterion=None),
Task(name='Emotion', class_labels=["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"], criterion=None)
]
CLASSES = [
["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"],
["M", "F"],
["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"]
]
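# Note: CLASSES mirrors the label sets in TASKS, with gender abbreviated to "M"/"F" for compact display.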
# Global variables for model and detector
model = None
transform = None
detector = None
device = None
current_ckpt_dir = None
CHECKPOINTS_DIR = './checkpoints/'
MODEL_REPO_ID = "Antuke/FaR-FT-PE"
def scan_checkpoints(ckpt_dir):
"""Scans a directory for .pt or .pth files."""
if not os.path.exists(ckpt_dir):
print(f"Warning: Checkpoint directory not found: {ckpt_dir}")
return [], None
try:
ckpt_files = [
os.path.join(ckpt_dir, f)
for f in sorted(os.listdir(ckpt_dir))
if f.endswith(('.pt', '.pth'))
]
except Exception as e:
print(f"Error scanning checkpoint directory {ckpt_dir}: {e}")
return [], None
choices_list = [(os.path.basename(f), f) for f in ckpt_files]
default_ckpt_path = os.path.join(ckpt_dir, 'mtlora.pt')
if default_ckpt_path in ckpt_files:
return choices_list, default_ckpt_path
elif ckpt_files:
return choices_list, ckpt_files[0]
else:
print(f"No checkpoints found in {ckpt_dir}")
return [], None
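# Illustrative return value (hypothetical directory contents):
#   scan_checkpoints('./checkpoints/')
#   -> ([('lora.pt', './checkpoints/lora.pt'), ('mtlora.pt', './checkpoints/mtlora.pt')],
#       './checkpoints/mtlora.pt')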
def load_model(device, ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"):
    """Load the perception-encoder backbone and wrap it in the multi-task model."""
    backbone, transform, _ = get_backbone_pe(version=pe_vision_config, apply_migration_flag=True, pretrained=False)
    # The checkpoint filename encodes the adapter type: files containing 'mtlora' use MTL-LoRA heads.
    model = MTLModel(backbone, device=device, tasks=TASKS, use_lora=True, use_deep_head=True,
                     use_mtl_lora=('mtlora' in ckpt_dir))
    print(f'Loading weights from {ckpt_dir}')
    model.load_model(filepath=ckpt_dir, map_location=device)
    return model, transform
def load_model_and_update_status(model_filepath):
"""Wrapper function to load a model """
global model, current_ckpt_dir
if model_filepath is None or model_filepath == "":
return "No checkpoint selected."
# Check if this model filepath is already loaded
if model is not None and model_filepath == current_ckpt_dir:
status = f"Model already loaded: {os.path.basename(model_filepath)}"
print(status)
return status
gr.Info(f"Loading model: {os.path.basename(model_filepath)}...")
try:
init_model(ckpt_dir=model_filepath, detection_confidence=0.5)
current_ckpt_dir = model_filepath # Set global path on successful load
status = f"Successfully loaded: {os.path.basename(model_filepath)}"
gr.Info("Model loaded successfully!")
print(status)
return status
except Exception as e:
traceback.print_exc()
status = f"Failed to load {os.path.basename(model_filepath)}: {e}"
gr.Info(f"Error: {status}")
print(f"ERROR: {status}")
return status
def predict(model, image):
"""Make predictions for age, gender, and emotion."""
    with torch.no_grad():
        outputs = model(image)
        age_logits, gender_logits, emotion_logits = outputs['Age'], outputs['Gender'], outputs['Emotion']
# Get probabilities using softmax
age_probs = torch.softmax(age_logits, dim=-1)
gender_probs = torch.softmax(gender_logits, dim=-1)
emotion_probs = torch.softmax(emotion_logits, dim=-1)
ages = torch.argmax(age_logits, dim=-1).cpu().tolist()
genders = torch.argmax(gender_logits, dim=-1).cpu().tolist()
emotions = torch.argmax(emotion_logits, dim=-1).cpu().tolist()
results = []
for i in range(len(ages)):
# Get all probabilities for each class
age_all_probs = {
CLASSES[0][j]: float(age_probs[i][j].cpu().detach())
for j in range(len(CLASSES[0]))
}
gender_all_probs = {
CLASSES[1][j]: float(gender_probs[i][j].cpu().detach())
for j in range(len(CLASSES[1]))
}
emotion_all_probs = {
CLASSES[2][j]: float(emotion_probs[i][j].cpu().detach())
for j in range(len(CLASSES[2]))
}
results.append({
'age': {
'predicted_class': CLASSES[0][ages[i]],
'predicted_confidence': float(age_probs[i][ages[i]].cpu().detach()),
'all_probabilities': age_all_probs
},
'gender': {
'predicted_class': CLASSES[1][genders[i]],
'predicted_confidence': float(gender_probs[i][genders[i]].cpu().detach()),
'all_probabilities': gender_all_probs
},
'emotion': {
'predicted_class': CLASSES[2][emotions[i]],
'predicted_confidence': float(emotion_probs[i][emotions[i]].cpu().detach()),
'all_probabilities': emotion_all_probs
}
})
return results
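# Shape of each entry returned by predict() (values are illustrative):
#   {'age':     {'predicted_class': '20-29', 'predicted_confidence': 0.84, 'all_probabilities': {...}},
#    'gender':  {'predicted_class': 'M',     'predicted_confidence': 0.97, 'all_probabilities': {...}},
#    'emotion': {'predicted_class': 'Happy', 'predicted_confidence': 0.91, 'all_probabilities': {...}}}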
def get_centroid_weighted_age(probs):
    """
    Estimate a single age value as the expected value of the predicted distribution,
    using the centroid of each age group as its representative age.
    """
    probs = list(probs.values())
    centroids = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]
    age = 0
    for i, p in enumerate(probs):
        age += p * centroids[i]
    return age
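# Worked example (illustrative probabilities): a distribution with 0.7 on "20-29" and 0.3 on "30-39"
# gives 0.7 * 24.5 + 0.3 * 34.5 = 27.5 years.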
def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
"""Initialize model and detector."""
global model, transform, detector, device
print(f"\n{'='*60}")
print(f"INITIALIZING MODEL: {ckpt_dir}")
print(f"{'='*60}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if not os.path.exists(ckpt_dir):
error_msg = f"Model weights not found: {ckpt_dir}."
print(f"ERROR: {error_msg}")
raise FileNotFoundError(error_msg)
print(f"Model weights found: {ckpt_dir}")
    # Load the perception-encoder backbone and the task heads
    model, transform = load_model(ckpt_dir=ckpt_dir, device=device)
    model.eval()
    model.to(device)
    detector = FaceDetector(confidence_threshold=detection_confidence)
    print("✓ Model and detector initialized successfully")
print(f"{'='*60}\n")
def process_image(image, selected_checkpoint_path):
"""
Process an uploaded image and return predictions with annotated image.
Args:
image: PIL Image or numpy array
selected_checkpoint_path: The path from the checkpoint dropdown
Returns:
tuple: (annotated_image, results_html)
"""
if image is None:
return None, "<p style='color: red;'>Please upload an image</p>"
    # Ensure the model matching the selected checkpoint is loaded
if model is None or selected_checkpoint_path != current_ckpt_dir:
print(f"Model mismatch or not loaded. Selected: {selected_checkpoint_path}, Current: {current_ckpt_dir}")
status = load_model_and_update_status(selected_checkpoint_path)
if "Failed" in status or "Error" in status:
return image, f"<p style'color: red;'>Model Error: {status}</p>"
try:
        # Convert the input to OpenCV format (BGR) for the detector
        if isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            # Gradio may hand over an RGB numpy array; convert it for the detector and
            # wrap it as a PIL image so the annotation code below can draw on it
            img_cv = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            image = Image.fromarray(image)
        # Create a PIL copy to draw annotations on
        img_pil_annotated = image.copy()
draw = ImageDraw.Draw(img_pil_annotated)
faces = detector.detect(img_cv, pad_rect=True)
if faces is None or len(faces) == 0:
return image, "<p style='color: orange;'>No faces detected in the image</p>"
# --- Process detected faces ---
crops_pil = []
face_data = []
for idx, (crop, confidence, bbox) in enumerate(faces):
crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
crop_pil = Image.fromarray(crop_rgb)
crops_pil.append(crop_pil)
# Resize crop to 336x336 for display, to match model input size
crop_resized = crop_pil.resize((336, 336), Image.Resampling.LANCZOS)
face_data.append({
'bbox': bbox,
'detection_confidence': float(confidence),
'crop_image': crop_resized
})
# --- Batch transform and predict ---
crop_tensors = [transform(crop_pil) for crop_pil in crops_pil]
batch_tensor = torch.stack(crop_tensors).to(device)
predictions = predict(model, batch_tensor)
# Combine face data with predictions
for face, pred in zip(face_data, predictions):
face['predictions'] = pred
# --- Create annotated image (using PIL) ---
for idx, face in enumerate(face_data):
bbox = face['bbox']
pred = face['predictions']
x, y, w, h = bbox
# --- Calculate Adaptive Font ---
font_size_ratio = 0.08
min_font_size = 12
max_font_size = 48
adaptive_font_size = max(min_font_size, min(int(w * font_size_ratio), max_font_size))
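            # e.g. a 300 px wide box gives int(300 * 0.08) = 24 px, clamped to the [12, 48] range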
            try:
                font = ImageFont.load_default(size=adaptive_font_size)
            except (TypeError, OSError):
                # Older Pillow versions do not accept a size argument here
                font = ImageFont.load_default()
# --- Draw Bounding Box ---
draw.rectangle([(x, y), (x + w, y + h)], outline="lime", width=2)
# --- Prepare Text Lines ---
lines_to_draw = []
# Age
age_label = pred['age']['predicted_class']
age_conf = pred['age']['predicted_confidence']
lines_to_draw.append(f"Age: {age_label} ({age_conf*100:.0f}%)")
# Gender
gen_label = pred['gender']['predicted_class']
gen_conf = pred['gender']['predicted_confidence']
lines_to_draw.append(f"Gender: {gen_label} ({gen_conf*100:.0f}%)")
# Emotion
emo_label = pred['emotion']['predicted_class']
emo_conf = pred['emotion']['predicted_confidence']
lines_to_draw.append(f"Emotion: {emo_label} ({emo_conf*100:.0f}%)")
# --- Calculate total height of the text block ---
line_spacing = 10
total_text_height = 0
for line in lines_to_draw:
_left, top, _right, bottom = draw.textbbox((0, 0), line, font=font)
total_text_height += (bottom - top) + line_spacing
# --- Place text ABOVE or BELOW the box ---
if y - total_text_height > 0:
text_y = y - line_spacing
for line in reversed(lines_to_draw):
left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="ls") # anchor left-baseline
draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
draw.text((x, text_y), line, font=font, fill="white", anchor="ls")
text_y = top - line_spacing # Move y-position up for the next line
else:
text_y = y + h + line_spacing
for line in lines_to_draw:
left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="lt")
draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
draw.text((x, text_y), line, font=font, fill="white", anchor="lt")
text_y = bottom + line_spacing
# Helper function to convert PIL image to base64
def pil_to_base64(img_pil):
buffered = BytesIO()
img_pil.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/jpeg;base64,{img_str}"
results_html = f"""
<style>
:root {{
--primary-color: #4f46e5;
--success-color: #10b981;
--text-primary: var(--body-text-color);
--text-secondary: var(--body-text-color-subdued);
--background-dark: var(--background-fill-primary);
--background-darker: var(--background-fill-secondary);
--border-color: var(--border-color-primary);
}}
.results-container {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: var(--background-darker);
padding: 20px;
border-radius: 12px;
color: var(--text-primary);
}}
.results-container h2 {{
color: var(--text-primary);
margin-bottom: 20px;
}}
.face-count {{
display: inline-block;
background: var(--primary-color);
color: white;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.9rem;
font-weight: 500;
margin-left: 8px;
}}
.face-card {{
background: var(--background-dark);
border-radius: 8px;
padding: 20px;
margin-top: 15px;
border: 1px solid var(--border-color);
display: flex;
gap: 20px;
align-items: flex-start;
}}
.face-header {{
font-size: 1rem;
font-weight: 600;
margin-bottom: 20px;
color: var(--text-primary);
}}
.face-image-left {{
flex-shrink: 0;
width: 336px;
height: 336px;
background: var(--background-darker);
border-radius: 8px;
overflow: hidden;
border: 1px solid var(--border-color);
}}
.face-image-left img {{
width: 100%;
height: 100%;
object-fit: cover;
}}
.face-predictions-right {{
flex: 1;
display: flex;
flex-direction: column;
gap: 10px;
}}
.predictions-horizontal {{
display: flex;
flex-direction: row;
gap: 30px;
justify-content: space-between;
}}
.prediction-section {{
flex: 1;
min-width: 0;
}}
.prediction-category-label {{
font-size: 0.8rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.5px;
color: var(--primary-color);
margin-bottom: 8px;
border-bottom: 2px solid var(--primary-color);
padding-bottom: 4px;
}}
.probabilities-list {{
display: flex;
flex-direction: column;
gap: 6px;
}}
.probability-item {{
display: grid;
grid-template-columns: 70px 1fr 55px;
align-items: center;
gap: 8px;
padding: 4px 6px;
border-radius: 4px;
}}
.probability-item.predicted {{
background: rgba(79, 70, 229, 0.2);
border-left: 3px solid var(--primary-color);
padding-left: 8px;
}}
.prob-class {{
font-size: 0.8rem;
font-weight: 600;
color: var(--text-primary);
word-wrap: break-word; /* Ensure long class names wrap */
}}
.probability-item.predicted .prob-class {{
color: var(--primary-color);
font-weight: 700;
}}
.prob-bar-container {{
height: 6px;
background: var(--border-color);
border-radius: 3px;
overflow: hidden;
}}
.prob-bar {{
height: 100%;
background: linear-gradient(90deg, var(--primary-color), var(--success-color));
border-radius: 3px;
transition: width 0.6s ease;
}}
.probability-item.predicted .prob-bar {{
background: var(--primary-color);
}}
.prob-percentage {{
font-size: 0.75rem;
font-weight: 500;
color: var(--text-secondary);
text-align: right;
}}
.probability-item.predicted .prob-percentage {{
color: var(--primary-color);
font-weight: 700;
}}
@media (max-width: 1200px) {{
.predictions-horizontal {{
flex-direction: column;
gap: 15px;
}}
}}
@media (max-width: 900px) {{
.face-card {{
flex-direction: column;
}}
.face-image-left {{
width: 100%;
max-width: 336px;
margin: 0 auto;
}}
.probability-item {{
grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
}}
.prob-class {{
font-size: 0.75rem;
}}
}}
</style>
<div class='results-container'>
<h2 style='margin-top: 0;'>Classification Results <span class='face-count'>{len(face_data)} face(s)</span></h2>
"""
for idx, face in enumerate(face_data):
pred = face['predictions']
face_img_base64 = pil_to_base64(face['crop_image'])
age = get_centroid_weighted_age(pred['age']['all_probabilities'])
results_html += f"""
<div class='face-card'>
<div class='face-image-left'>
<img src='{face_img_base64}' alt='Face {idx+1}'>
</div>
<div class='face-predictions-right'>
<div class='face-header'>Face {idx+1} - Detection Confidence: {face['detection_confidence']:.1%} - Centroid Age: {int(age)}</div>
<div class='predictions-horizontal'>
<div class='prediction-section'>
<div class='prediction-category-label'>Age</div>
<div class='probabilities-list'>
"""
for age_class in CLASSES[0]:
prob = pred['age']['all_probabilities'][age_class]
is_predicted = (age_class == pred['age']['predicted_class'])
predicted_class = 'predicted' if is_predicted else ''
results_html += f"""
<div class='probability-item {predicted_class}'>
<span class='prob-class'>{age_class}</span>
<div class='prob-bar-container'>
<div class='prob-bar' style='width: {prob*100}%'></div>
</div>
<span class='prob-percentage'>{prob*100:.1f}%</span>
</div>
"""
results_html += f"""
</div>
</div>
<div class='prediction-section'>
<div class='prediction-category-label'>Gender</div>
<div class='probabilities-list'>
"""
for gender_class in CLASSES[1]:
prob = pred['gender']['all_probabilities'][gender_class]
is_predicted = (gender_class == pred['gender']['predicted_class'])
predicted_class = 'predicted' if is_predicted else ''
results_html += f"""
<div class='probability-item {predicted_class}'>
<span class='prob-class'>{gender_class}</span>
<div class='prob-bar-container'>
<div class='prob-bar' style='width: {prob*100}%'></div>
</div>
<span class='prob-percentage'>{prob*100:.1f}%</span>
</div>
"""
results_html += """
</div>
</div>
<div class='prediction-section'>
<div class='prediction-category-label'>Emotion</div>
<div class='probabilities-list'>
"""
for emotion_class in CLASSES[2]:
prob = pred['emotion']['all_probabilities'][emotion_class]
is_predicted = (emotion_class == pred['emotion']['predicted_class'])
predicted_class = 'predicted' if is_predicted else ''
results_html += f"""
<div class='probability-item {predicted_class}'>
<span class='prob-class'>{emotion_class}</span>
<div class='prob-bar-container'>
<div class='prob-bar' style='width: {prob*100}%'></div>
</div>
<span class='prob-percentage'>{prob*100:.1f}%</span>
</div>
"""
results_html += """
</div>
</div>
</div>
</div>
</div>
"""
results_html += "</div>"
# --- Return the annotated PIL image and HTML ---
return img_pil_annotated, results_html
except Exception as e:
traceback.print_exc()
return image, f"<p style='color: red;'>Error processing image: {str(e)}</p>"
def create_interface(checkpoint_list, default_checkpoint, initial_status):
"""Create and configure the Gradio interface."""
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.output-html {
max-height: none !important;
overflow-y: auto;
}
:root {
--primary-color: #4f46e5;
--success-color: #10b981;
--text-primary: var(--body-text-color);
--text-secondary: var(--body-text-color-subdued);
--background-dark: var(--background-fill-primary);
--background-darker: var(--background-fill-secondary);
--border-color: var(--border-color-primary);
}
.results-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: var(--background-darker);
padding: 20px;
border-radius: 12px;
color: var(--text-primary);
}
.results-container h2 {
color: var(--text-primary);
margin-bottom: 20px;
}
.face-count {
display: inline-block;
background: var(--primary-color);
color: white;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.9rem;
font-weight: 500;
margin-left: 8px;
}
.face-card {
background: var(--background-dark);
border-radius: 8px;
padding: 20px;
margin-top: 15px;
border: 1px solid var(--border-color);
display: flex;
gap: 20px;
align-items: flex-start;
}
.face-header {
font-size: 1rem;
font-weight: 600;
margin-bottom: 20px;
color: var(--text-primary);
}
.face-image-left {
flex-shrink: 0;
width: 336px;
height: 336px;
background: var(--background-darker);
border-radius: 8px;
overflow: hidden;
border: 1px solid var(--border-color);
}
.face-image-left img {
width: 100%;
height: 100%;
object-fit: cover;
}
.face-predictions-right {
flex: 1;
display: flex;
flex-direction: column;
gap: 10px;
}
.predictions-horizontal {
display: flex;
flex-direction: row;
gap: 30px;
justify-content: space-between;
}
.prediction-section {
flex: 1;
min-width: 0;
}
.prediction-category-label {
font-size: 0.8rem;
font-weight: 700;
text-transform: uppercase;
letter-spacing: 0.5px;
color: var(--primary-color);
margin-bottom: 8px;
border-bottom: 2px solid var(--primary-color);
padding-bottom: 4px;
}
.probabilities-list {
display: flex;
flex-direction: column;
gap: 6px;
}
.probability-item {
display: grid;
grid-template-columns: 70px 1fr 55px;
align-items: center;
gap: 8px;
padding: 4px 6px;
border-radius: 4px;
}
.probability-item.predicted {
background: rgba(79, 70, 229, 0.2);
border-left: 3px solid var(--primary-color);
padding-left: 8px;
}
.prob-class {
font-size: 0.8rem;
font-weight: 600;
color: var(--text-primary);
word-wrap: break-word; /* Ensure long class names wrap */
}
.probability-item.predicted .prob-class {
color: var(--primary-color);
font-weight: 700;
}
.prob-bar-container {
height: 6px;
background: var(--border-color);
border-radius: 3px;
overflow: hidden;
}
.prob-bar {
height: 100%;
background: linear-gradient(90deg, var(--primary-color), var(--success-color));
border-radius: 3px;
transition: width 0.6s ease;
}
.probability-item.predicted .prob-bar {
background: var(--primary-color);
}
.prob-percentage {
font-size: 0.75rem;
font-weight: 500;
color: var(--text-secondary);
text-align: right;
}
.probability-item.predicted .prob-percentage {
color: var(--primary-color);
font-weight: 700;
}
@media (max-width: 1200px) {
.predictions-horizontal {
flex-direction: column;
gap: 15px;
}
}
@media (max-width: 900px) {
.face-card {
flex-direction: column;
}
.face-image-left {
width: 100%;
max-width: 336px;
margin: 0 auto;
}
.probability-item {
grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
}
.prob-class {
font-size: 0.75rem;
}
}
"""
# Create interface
with gr.Blocks(css=custom_css, title="Face Classification System", theme=gr.themes.Default()) as demo:
with gr.Row():
gr.Markdown("# Face Classification System")
# --- Model Selection ---
with gr.Row():
with gr.Column(scale=3):
checkpoint_dropdown = gr.Dropdown(
label="Select Model Checkpoint",
choices=checkpoint_list,
value=default_checkpoint,
)
with gr.Column(scale=2):
model_status_text = gr.Textbox(
label="Model Status",
value=initial_status,
interactive=False,
)
# Features | Instructions
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("""
### Features
- **Age Classification**: 9 categories (0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+), plus a continuous age estimate from a centroid-weighted average
- **Gender Classification**: M/F
- **Emotion Recognition**: 7 categories (Surprise, Fear, Disgust, Happy, Sad, Angry, Neutral)
- **Automatic Face Detection**: Detects and analyzes multiple faces
- **Detailed Probability Distributions**: View confidence for all classes
""")
with gr.Column(scale=1):
gr.Markdown("""
### Instructions
1. (Optional) Select a model checkpoint from the dropdown.
2. Upload an image or capture from webcam (or select an example below)
3. Click "Classify Image"
4. View detected faces with age, gender, and emotion predictions below
\n
Demo video for this Space: https://youtu.be/V6-9QTf1xaQ
""")
# Upload Image | Annotated Image
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(
label="Upload Image",
type="pil",
sources=["upload", "webcam"],
height=400
)
with gr.Column(scale=1):
output_image = gr.Image(
label="Annotated Image",
type="pil",
height=400
)
with gr.Row():
with gr.Column(scale=1):
analyze_btn = gr.Button(
"Classify Image",
variant="primary",
size="lg"
)
# Dynamically load example images from example directory
example_dir = "example"
example_images = []
if os.path.exists(example_dir):
try:
example_images = [
os.path.join(example_dir, f)
for f in sorted(os.listdir(example_dir))
if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
except Exception as e:
print(f"Error reading example images from {example_dir}: {e}")
if example_images:
gr.Markdown("### ๐Ÿ“ธ Try with example images")
gr.Examples(
examples=example_images,
inputs=input_image,
cache_examples=False
)
# Results section
with gr.Row():
with gr.Column(scale=1):
output_html = gr.HTML(
label="Classification Results",
elem_classes="output-html"
)
# Event handlers
analyze_btn.click(
fn=process_image,
inputs=[input_image, checkpoint_dropdown],
outputs=[output_image, output_html]
)
checkpoint_dropdown.change(
fn=load_model_and_update_status,
inputs=[checkpoint_dropdown],
outputs=[model_status_text]
)
return demo
# Application Startup
print("="*60)
print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
print("="*60)
# --- Model download from HF Repo ---
print(f"Downloading model weights from {MODEL_REPO_ID} to {CHECKPOINTS_DIR}...")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
try:
snapshot_download(
repo_id=MODEL_REPO_ID,
local_dir=CHECKPOINTS_DIR,
allow_patterns=["*.pt", "*.pth"], # Grabs all weight files
local_dir_use_symlinks=False,
)
print("Model download complete.")
except Exception as e:
print(f"CRITICAL: Failed to download models from Hub. {e}")
traceback.print_exc()
checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)
if not checkpoint_list:
print(f"CRITICAL: No checkpoints found in {CHECKPOINTS_DIR}. App may not function.")
else:
print(f"Found checkpoints: {len(checkpoint_list)} file(s).")
print(f"Default checkpoint: {default_checkpoint}")
# --- Try to initialize default model ---
initial_status_msg = "No default model found. Please select one."
if default_checkpoint:
print(f"\nInitializing default model: {default_checkpoint}")
    # Loads the model and sets current_ckpt_dir to the selected local checkpoint path
initial_status_msg = load_model_and_update_status(default_checkpoint)
print(initial_status_msg)
else:
print("Warning: No default model to load.")
# --- Create interface ---
print("Creating Gradio interface...")
demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg)
print("โœ“ Interface created successfully!")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface")
parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/",
help="Path to the checkpoint directory (will be populated from HF Hub)")
parser.add_argument("--detection_confidence", type=float, default=0.5,
help="Confidence threshold for face detection")
parser.add_argument("--port", type=int, default=7860,
help="Port to run the Gradio app")
parser.add_argument("--share", action="store_true",
help="Create a public share link")
parser.add_argument("--server_name", type=str, default="0.0.0.0",
help="Server name/IP to bind to")
args = parser.parse_args()
CHECKPOINTS_DIR = args.ckpt_dir
print(f"\nLaunching server on {args.server_name}:{args.port}")
print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}")
print("="*60)
demo.launch(
share=args.share,
server_name=args.server_name,
server_port=args.port,
show_error=True,
)