Antuke committed
Commit 576fd2d · 1 Parent(s): fbe756a
Files changed (3):
  1. .gitattributes +37 -39
  2. app.py +59 -26
  3. requirements.txt +2 -1
.gitattributes CHANGED
@@ -1,39 +1,37 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
- checkpoints/dora.pt filter=lfs diff=lfs merge=lfs -text
- checkpoints/mtlora.pt filter=lfs diff=lfs merge=lfs -text
- utils/res10_300x300_ssd_iter_140000_fp16.caffemodel filter=lfs diff=lfs merge=lfs -text
- core/vision_encoder/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ utils/res10_300x300_ssd_iter_140000_fp16.caffemodel filter=lfs diff=lfs merge=lfs -text
+ core/vision_encoder/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
 
 
app.py CHANGED
@@ -11,7 +11,7 @@ from PIL import Image, ImageDraw, ImageFont
  import base64
  from io import BytesIO
  import traceback # Import traceback at the top
-
+ from huggingface_hub import snapshot_download # Use snapshot_download for startup
  from utils.face_detector import FaceDetector
  
  # Class definitions
@@ -36,8 +36,9 @@ model = None
  transform = None
  detector = None
  device = None
- current_ckpt_dir = None
+ current_ckpt_dir = None # This will now store the full path to the loaded model
  CHECKPOINTS_DIR = './checkpoints/'
+ MODEL_REPO_ID = "Antuke/FaR-FT-PE"
  
  def scan_checkpoints(ckpt_dir):
      """Scans a directory for .pt or .pth files."""
@@ -49,6 +50,7 @@ def scan_checkpoints(ckpt_dir):
          ckpt_files = [
              os.path.join(ckpt_dir, f)
              for f in sorted(os.listdir(ckpt_dir))
+             if f.endswith(('.pt', '.pth')) # Ensure we only scan for model files
          ]
      except Exception as e:
          print(f"Error scanning checkpoint directory {ckpt_dir}: {e}")
@@ -78,33 +80,41 @@ def load_model(device,ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-C
      model.load_model(filepath=ckpt_dir,map_location=device)
      return model,transform
  
- def load_model_and_update_status(ckpt_dir):
+ def load_model_and_update_status(model_filepath):
      """
-     Wrapper function to load a model and return a status message.
-     This is used by the dropdown's 'change' event.
+     Wrapper function to load a model *from a local file path* and return a status.
+     The file path is provided by the dropdown.
      """
      global model, current_ckpt_dir
  
-     if ckpt_dir is None or ckpt_dir == "":
+     if model_filepath is None or model_filepath == "":
          return "No checkpoint selected."
  
-     if model is not None and ckpt_dir == current_ckpt_dir:
-         status = f"Model already loaded: {os.path.basename(ckpt_dir)}"
+     # Check if this model *filepath* is already loaded
+     if model is not None and model_filepath == current_ckpt_dir:
+         status = f"Model already loaded: {os.path.basename(model_filepath)}"
          print(status)
          return status
  
-     gr.Info(f"Loading model: {os.path.basename(ckpt_dir)}...")
+     gr.Info(f"Loading model: {os.path.basename(model_filepath)}...")
      try:
-         init_model(ckpt_dir=ckpt_dir, detection_confidence=0.5)
-         current_ckpt_dir = ckpt_dir # Set global directory on successful load
-         status = f"Successfully loaded: {os.path.basename(ckpt_dir)}"
+         # --- This is the new logic ---
+         # The file is already local. Just initialize it.
+         # The 'model_filepath' *is* the ckpt_dir for init_model
+         init_model(ckpt_dir=model_filepath, detection_confidence=0.5)
+         # --- End of new logic ---
+
+         current_ckpt_dir = model_filepath # Set global path on successful load
+         status = f"Successfully loaded: {os.path.basename(model_filepath)}"
          gr.Info("Model loaded successfully!")
          print(status)
          return status
+
      except Exception as e:
-         status = f"Failed to load {os.path.basename(ckpt_dir)}: {str(e)}"
-         print(status)
          traceback.print_exc()
+         status = f"Failed to load {os.path.basename(model_filepath)}: {e}"
+         gr.Info(f"Error: {status}")
+         print(f"ERROR: {status}")
          return status
  
  def predict(model, image):
@@ -182,10 +192,10 @@ def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
  
      # Verify model weights exist
      if not os.path.exists(ckpt_dir):
-         error_msg = f"Model weights not found: {ckpt_dir}."
-         print(f"ERROR: {error_msg}")
-         raise FileNotFoundError(error_msg)
-
+         error_msg = f"Model weights not found: {ckpt_dir}."
+         print(f"ERROR: {error_msg}")
+         raise FileNotFoundError(error_msg)
+
      print(f"Model weights found: {ckpt_dir}")
  
      # Load the perception encoder
@@ -214,7 +224,11 @@ def process_image(image, selected_checkpoint_path):
          return None, "<p style='color: red;'>Please upload an image</p>"
  
      # Ensure model is initialized
+     # This check is crucial. If the user changes the dropdown, it triggers
+     # load_model_and_update_status. If they just hit "Classify",
+     # this check ensures the selected model is loaded.
      if model is None or selected_checkpoint_path != current_ckpt_dir:
+         print(f"Model mismatch or not loaded. Selected: {selected_checkpoint_path}, Current: {current_ckpt_dir}")
          status = load_model_and_update_status(selected_checkpoint_path)
          if "Failed" in status or "Error" in status:
              return image, f"<p style'color: red;'>Model Error: {status}</p>"
@@ -621,8 +635,8 @@ def create_interface(checkpoint_list, default_checkpoint, initial_status):
          with gr.Column(scale=3):
              checkpoint_dropdown = gr.Dropdown(
                  label="Select Model Checkpoint",
-                 choices=checkpoint_list,
-                 value=default_checkpoint,
+                 choices=checkpoint_list, # This is now a list of (label, path) tuples
+                 value=default_checkpoint, # This is the full path to the default
              )
          with gr.Column(scale=2):
              model_status_text = gr.Textbox(
@@ -710,13 +724,13 @@ def create_interface(checkpoint_list, default_checkpoint, initial_status):
      # Event handlers
      analyze_btn.click(
          fn=process_image,
-         inputs=[input_image, checkpoint_dropdown], # Pass dropdown value
+         inputs=[input_image, checkpoint_dropdown], # Pass dropdown value (which is the path)
          outputs=[output_image, output_html]
      )
  
      checkpoint_dropdown.change(
          fn=load_model_and_update_status,
-         inputs=[checkpoint_dropdown],
+         inputs=[checkpoint_dropdown], # Pass dropdown value (which is the path)
          outputs=[model_status_text]
      )
  
@@ -725,12 +739,29 @@ def create_interface(checkpoint_list, default_checkpoint, initial_status):
  
  # === Main Application Startup ===
  
- # Initialize for Hugging Face Spaces (module-level)
  print("="*60)
  print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
  print("="*60)
  
+ # --- NEW: Download models BEFORE scanning ---
+ print(f"Downloading model weights from {MODEL_REPO_ID} to {CHECKPOINTS_DIR}...")
+ os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
+ try:
+     snapshot_download(
+         repo_id=MODEL_REPO_ID,
+         local_dir=CHECKPOINTS_DIR,
+         allow_patterns=["*.pt", "*.pth"], # Grabs all weight files
+         local_dir_use_symlinks=False,
+     )
+     print("Model download complete.")
+ except Exception as e:
+     print(f"CRITICAL: Failed to download models from Hub. {e}")
+     traceback.print_exc()
+ # --- End of NEW code ---
+
+
  # --- 1. Scan for models first ---
+ # This will now find the files you just downloaded
  checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)
  
  if not checkpoint_list:
@@ -744,6 +775,7 @@ initial_status_msg = "No default model found. Please select one."
  if default_checkpoint:
      print(f"\nInitializing default model: {default_checkpoint}")
      # This will load the model AND set current_ckpt_dir
+     # It now correctly uses the local file path
      initial_status_msg = load_model_and_update_status(default_checkpoint)
      print(initial_status_msg)
  else:
@@ -760,8 +792,9 @@ if __name__ == "__main__":
      import argparse
  
      parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface")
+     # ckpt_dir is now managed by the startup download, so this arg is less relevant
      parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/",
-                         help="Path to the checkpoint directory (overridden by UI)")
+                         help="Path to the checkpoint directory (will be populated from HF Hub)")
      parser.add_argument("--detection_confidence", type=float, default=0.5,
                          help="Confidence threshold for face detection")
      parser.add_argument("--port", type=int, default=7860,
@@ -772,9 +805,9 @@ if __name__ == "__main__":
                          help="Server name/IP to bind to")
      args = parser.parse_args()
  
-     # Update global config if args are provided (though UI dropdown is primary)
+     # Update global config if args are provided
      CHECKPOINTS_DIR = args.ckpt_dir
-     # Note: detection_confidence is passed to init_model, so it's handled.
+     # Note: detection_confidence is passed to init_model during load, so it's handled.
  
      print(f"\nLaunching server on {args.server_name}:{args.port}")
      print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}")
requirements.txt CHANGED
@@ -8,4 +8,5 @@ peft
  python-dotenv
  tqdm
  gradio
- timm
+ timm
+ huggingface-hub