julianzrmrz commited on
Commit
aa1492c
verified
1 Parent(s): 26aa008

cambios en el front

Browse files
Files changed (1) hide show
  1. src/inference.py +6 -9
src/inference.py CHANGED
@@ -11,10 +11,7 @@ sys.path.append(CURRENT_DIR)
11
  from nlp_utils import BetoClassifier
12
  from utils import clean_text, preprocess_image_for_ocr
13
 
14
- # ==========================================
15
- # 鈿欙笍 CONFIGURACI脫N DEL MODELO
16
- # ==========================================
17
- MODEL_VERSION = "v4" # <--- CAMBIA ESTO si entrenas nuevas versiones
18
  MODEL_NAME = "dccuchile/bert-base-spanish-wwm-cased"
19
  MAX_LEN = 128
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -22,12 +19,12 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  # Construcci贸n de la ruta: .../DIMEMEX/models/v4/
23
  # Subimos dos niveles desde src/inference.py para llegar a la raiz, luego models, luego v4
24
  PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../"))
25
- MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
26
 
27
  class MemePredictor:
28
  def __init__(self):
29
- print(f"馃敡 Inicializando motor en: {DEVICE}")
30
- print(f"馃搨 Buscando modelos versi贸n {MODEL_VERSION} en: {MODEL_DIR}")
31
 
32
  # Cargar OCR (Singleton)
33
  self.reader = easyocr.Reader(['es', 'en'], gpu=(DEVICE.type == 'cuda'))
@@ -39,7 +36,7 @@ class MemePredictor:
39
  # Mapas de etiquetas
40
  self.labels_map = {
41
  "simple": ["None", "Inappropriate", "Hate"],
42
- "complex": ["None", "Inapp", "Sexism", "Racism", "Classicism", "Other"]
43
  }
44
 
45
  def _get_model_instance(self, task):
@@ -54,7 +51,7 @@ class MemePredictor:
54
  if not os.path.exists(path):
55
  raise FileNotFoundError(f"No se encontr贸 el modelo para la tarea '{task}' en {path}")
56
 
57
- print(f"馃摜 Cargando modelo desde: {path}")
58
 
59
  n_classes = len(self.labels_map[task])
60
 
 
11
  from nlp_utils import BetoClassifier
12
  from utils import clean_text, preprocess_image_for_ocr
13
 
14
+ MODEL_VERSION = "v4"
 
 
 
15
  MODEL_NAME = "dccuchile/bert-base-spanish-wwm-cased"
16
  MAX_LEN = 128
17
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
19
  # Construcci贸n de la ruta: .../DIMEMEX/models/v4/
20
  # Subimos dos niveles desde src/inference.py para llegar a la raiz, luego models, luego v4
21
  PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../"))
22
+ MODEL_DIR = os.path.join(PROJECT_ROOT, "models", MODEL_VERSION)
23
 
24
  class MemePredictor:
25
  def __init__(self):
26
+ print(f"Inicializando motor en: {DEVICE}")
27
+ print(f"Buscando modelos versi贸n {MODEL_VERSION} en: {MODEL_DIR}")
28
 
29
  # Cargar OCR (Singleton)
30
  self.reader = easyocr.Reader(['es', 'en'], gpu=(DEVICE.type == 'cuda'))
 
36
  # Mapas de etiquetas
37
  self.labels_map = {
38
  "simple": ["None", "Inappropriate", "Hate"],
39
+ "complex": ["None", "Inapp", "Sexism", "Racism", "Classicism", "Hate"]
40
  }
41
 
42
  def _get_model_instance(self, task):
 
51
  if not os.path.exists(path):
52
  raise FileNotFoundError(f"No se encontr贸 el modelo para la tarea '{task}' en {path}")
53
 
54
+ print(f"Cargando modelo desde: {path}")
55
 
56
  n_classes = len(self.labels_map[task])
57