""" Model loading and device utilities. """ import torch from model import HAT from config import MODEL_CHECKPOINT, MODEL_CONFIG def get_device(): """Get the appropriate device for model inference.""" return torch.device('cuda' if torch.cuda.is_available() else 'cpu') def load_model(): """Load and initialize the HAT model with pre-trained weights.""" device = get_device() # Initialize model model = HAT(**MODEL_CONFIG) # Load the fine-tuned weights checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device) # Try different checkpoint formats state_dict = checkpoint.get('params_ema') or checkpoint.get('params') or checkpoint model.load_state_dict(state_dict) model.to(device) model.eval() return model, device