HATSAT / utils /model_utils.py
BorisEm's picture
Broke down code base into smaller files for readability
0def483
raw
history blame contribute delete
783 Bytes
"""
Model loading and device utilities.
"""
import torch
from model import HAT
from config import MODEL_CHECKPOINT, MODEL_CONFIG
def get_device():
    """Select the torch device to run the model on.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` reports
        true, otherwise ``torch.device('cpu')``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
def load_model():
    """Build the HAT model, load its checkpoint weights, and prepare it for inference.

    The model is constructed from ``MODEL_CONFIG`` and its weights are read
    from ``MODEL_CHECKPOINT``, mapped onto the device chosen by
    ``get_device()``.

    Returns:
        A ``(model, device)`` tuple: the model moved to ``device`` and
        switched to eval mode, and the device it was placed on.
    """
    device = get_device()
    model = HAT(**MODEL_CONFIG)

    # NOTE(review): torch.load unpickles the file; confirm the checkpoint
    # always comes from a trusted source.
    checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device)

    # The weights may be nested under 'params_ema' or 'params' (tried in that
    # order); if neither holds a non-empty value, the checkpoint object itself
    # is used as the state dict.
    state_dict = checkpoint
    for key in ('params_ema', 'params'):
        candidate = checkpoint.get(key)
        if candidate:
            state_dict = candidate
            break

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device