# dimemex / src/inference.py
import os
import sys

import torch
import easyocr
from transformers import AutoTokenizer

# Configure the path so the local modules (nlp_utils and utils) can be imported
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(CURRENT_DIR)

from nlp_utils import BetoClassifier
from utils import clean_text, preprocess_image_for_ocr

# --- CONFIGURATION ---
MODEL_VERSION = "v4"
MODEL_NAME = "dccuchile/bert-base-spanish-wwm-cased"
MAX_LEN = 128
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the path to the project root
PROJECT_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../"))

# --- FIX HERE ---
# Previously: os.path.join(PROJECT_ROOT, "models", MODEL_VERSION)
# Now we point directly at 'models', because that is where the .pth files live
MODEL_DIR = os.path.join(PROJECT_ROOT, "models")
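
# Expected layout on disk, inferred from the filename pattern used in
# _get_model_instance below (an assumption, not verified against the repo):
#   models/
#     beto_simple_v4.pth
#     beto_complex_v4.pth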

class MemePredictor:
    def __init__(self):
        print(f"Initializing engine on: {DEVICE}")
        print(f"Looking for models in: {MODEL_DIR}")

        # Load the OCR reader once (singleton)
        self.reader = easyocr.Reader(['es', 'en'], gpu=(DEVICE.type == 'cuda'))
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Cache so the .pth models are not reloaded on every request
        self.loaded_models = {}

        # Label maps
        # NOTE: make sure these match the class order used during training
        self.labels_map = {
            "simple": ["None", "Inappropriate", "Hate"],
            "complex": ["None", "Inapp", "Sexism", "Racism", "Classicism", "Other"]
        }
    def _get_model_instance(self, task):
        if task in self.loaded_models:
            return self.loaded_models[task]

        # Build the filename, e.g. beto_simple_v4.pth
        filename = f"beto_{task}_{MODEL_VERSION}.pth"
        path = os.path.join(MODEL_DIR, filename)

        if not os.path.exists(path):
            raise FileNotFoundError(f"No model found for task '{task}' at {path}")

        print(f"Loading model from: {path}")
        n_classes = len(self.labels_map[task])

        # Instantiate the architecture
        model = BetoClassifier(n_classes, MODEL_NAME)

        # Load the weights
        try:
            model.load_state_dict(torch.load(path, map_location=DEVICE))
        except Exception as e:
            raise RuntimeError(f"Failed to read the .pth file: {e}")

        model.to(DEVICE)
        model.eval()

        self.loaded_models[task] = model
        return model
    def predict(self, image_file, task="simple"):
        # 1. Reset the file pointer
        image_file.seek(0)

        # 2. Image pre-processing (vision)
        proc_img, raw_img = preprocess_image_for_ocr(image_file)
        if proc_img is None:
            return {"error": "Error processing the image (corrupt file or invalid format)"}

        # 3. OCR with fallback
        try:
            # First attempt: the processed image
            ocr_result = self.reader.readtext(proc_img, detail=0, paragraph=True)
            raw_text = " ".join(ocr_result)
            # If it read very little (<3 chars), retry with the original image
            if len(raw_text) < 3:
                ocr_result = self.reader.readtext(raw_img, detail=0, paragraph=True)
                raw_text = " ".join(ocr_result)
        except Exception as e:
            return {"error": f"OCR failure: {e}"}

        # 4. Text cleanup (NLP)
        text_ready = clean_text(raw_text)

        # 5. Tokenization
        encoding = self.tokenizer.encode_plus(
            text_ready,
            add_special_tokens=True,
            max_length=MAX_LEN,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(DEVICE)
        attention_mask = encoding['attention_mask'].to(DEVICE)

        # 6. Model inference
        try:
            model = self._get_model_instance(task)
            with torch.no_grad():
                outputs = model(input_ids, attention_mask)
                probs = torch.nn.functional.softmax(outputs, dim=1)
                conf, idx = torch.max(probs, dim=1)

            label_str = self.labels_map[task][idx.item()]

            return {
                "ocr_text": raw_text,
                "clean_text": text_ready,
                "label": label_str,
                "confidence": conf.item(),
                "probabilities": probs.cpu().numpy()[0],
                "all_labels": self.labels_map[task]
            }
        except Exception as e:
            return {"error": f"Model inference error: {e}"}