julianzrmrz commited on
Commit
6bb1c3d
verified
1 Parent(s): d84c710

Update src/inference.py

Browse files
Files changed (1) hide show
  1. src/inference.py +10 -7
src/inference.py CHANGED
@@ -11,20 +11,24 @@ sys.path.append(CURRENT_DIR)
11
  from nlp_utils import BetoClassifier
12
  from utils import clean_text, preprocess_image_for_ocr
13
 
 
14
  MODEL_VERSION = "v4"
15
  MODEL_NAME = "dccuchile/bert-base-spanish-wwm-cased"
16
  MAX_LEN = 128
17
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
- # Construcci贸n de la ruta: .../DIMEMEX/models/v4/
20
- # Subimos dos niveles desde src/inference.py para llegar a la raiz, luego models, luego v4
21
  PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../"))
22
- MODEL_DIR = os.path.join(PROJECT_ROOT, "models", MODEL_VERSION)
 
 
 
 
23
 
24
  class MemePredictor:
25
  def __init__(self):
26
  print(f"Inicializando motor en: {DEVICE}")
27
- print(f"Buscando modelos versi贸n {MODEL_VERSION} en: {MODEL_DIR}")
28
 
29
  # Cargar OCR (Singleton)
30
  self.reader = easyocr.Reader(['es', 'en'], gpu=(DEVICE.type == 'cuda'))
@@ -34,13 +38,13 @@ class MemePredictor:
34
  self.loaded_models = {}
35
 
36
  # Mapas de etiquetas
 
37
  self.labels_map = {
38
  "simple": ["None", "Inappropriate", "Hate"],
39
- "complex": ["None", "Inapp", "Sexism", "Racism", "Classicism", "Hate"]
40
  }
41
 
42
  def _get_model_instance(self, task):
43
- # Si ya est谩 en RAM, devolverlo
44
  if task in self.loaded_models:
45
  return self.loaded_models[task]
46
 
@@ -59,7 +63,6 @@ class MemePredictor:
59
  model = BetoClassifier(n_classes, MODEL_NAME)
60
 
61
  # Cargar pesos
62
- # map_location es vital para evitar errores si entrenaste en GPU y corres en CPU
63
  try:
64
  model.load_state_dict(torch.load(path, map_location=DEVICE))
65
  except Exception as e:
 
11
  from nlp_utils import BetoClassifier
12
  from utils import clean_text, preprocess_image_for_ocr
13
 
14
+ # --- CONFIGURACI脫N ---
15
  MODEL_VERSION = "v4"
16
  MODEL_NAME = "dccuchile/bert-base-spanish-wwm-cased"
17
  MAX_LEN = 128
18
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
+ # Construcci贸n de la ruta
 
21
  PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../"))
22
+
23
+ # --- CORRECCI脫N AQU脥 ---
24
+ # Antes ten铆as: os.path.join(PROJECT_ROOT, "models", MODEL_VERSION)
25
+ # Ahora apuntamos directo a 'models', porque ah铆 viven los archivos .pth
26
+ MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
27
 
28
  class MemePredictor:
29
  def __init__(self):
30
  print(f"Inicializando motor en: {DEVICE}")
31
+ print(f"Buscando modelos en: {MODEL_DIR}")
32
 
33
  # Cargar OCR (Singleton)
34
  self.reader = easyocr.Reader(['es', 'en'], gpu=(DEVICE.type == 'cuda'))
 
38
  self.loaded_models = {}
39
 
40
  # Mapas de etiquetas
41
+ # NOTA: Aseg煤rate que coincidan con el orden de tu entrenamiento
42
  self.labels_map = {
43
  "simple": ["None", "Inappropriate", "Hate"],
44
+ "complex": ["None", "Inapp", "Sexism", "Racism", "Classicism", "Other"]
45
  }
46
 
47
  def _get_model_instance(self, task):
 
48
  if task in self.loaded_models:
49
  return self.loaded_models[task]
50
 
 
63
  model = BetoClassifier(n_classes, MODEL_NAME)
64
 
65
  # Cargar pesos
 
66
  try:
67
  model.load_state_dict(torch.load(path, map_location=DEVICE))
68
  except Exception as e: