| """ | |
| Model loading and device utilities. | |
| """ | |
| import torch | |
| from model import HAT | |
| from config import MODEL_CHECKPOINT, MODEL_CONFIG | |
def get_device():
    """Return the torch device used for model inference.

    Chooses ``'cuda'`` whenever PyTorch reports CUDA as available, and falls
    back to ``'cpu'`` otherwise.
    """
    if torch.cuda.is_available():
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    return torch.device(device_name)
def _extract_state_dict(checkpoint):
    """Return the weights mapping held by a deserialized checkpoint.

    Some checkpoints nest the weights under ``'params_ema'`` or ``'params'``.
    Those keys are tried in that order; a key is used only when its value is a
    non-empty dict. Otherwise the checkpoint itself is treated as the state
    dict.

    Raises:
        TypeError: if ``checkpoint`` is not a dict (e.g. a pickled module).
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint type {type(checkpoint).__name__!r}; "
            "expected a dict-like state dict"
        )
    for key in ('params_ema', 'params'):
        candidate = checkpoint.get(key)
        # Check the type explicitly instead of relying on `or` truthiness:
        # truth-testing an arbitrary value (e.g. a multi-element tensor stored
        # under one of these keys in a raw state dict) can raise.
        if isinstance(candidate, dict) and candidate:
            return candidate
    return checkpoint


def load_model(checkpoint_path=None, device=None):
    """Build the HAT model, load pre-trained weights, and prepare it for inference.

    Args:
        checkpoint_path: Path of the checkpoint file to load. Defaults to
            ``MODEL_CHECKPOINT`` from ``config``.
        device: Target device as a ``torch.device`` or a string such as
            ``'cpu'`` or ``'cuda'``. Defaults to ``get_device()``.

    Returns:
        A tuple ``(model, device)``. ``model`` is in eval mode and moved to
        ``device``; ``device`` is always a ``torch.device``.

    Raises:
        TypeError: if the checkpoint does not deserialize to a dict.
        RuntimeError: from ``load_state_dict`` when the checkpoint's keys or
            shapes do not match the architecture built from ``MODEL_CONFIG``.
    """
    if checkpoint_path is None:
        checkpoint_path = MODEL_CHECKPOINT
    device = get_device() if device is None else torch.device(device)

    model = HAT(**MODEL_CONFIG)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints. Passing weights_only=True
    # (available since PyTorch 1.13, default since 2.6) restricts what can be
    # unpickled.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint))

    model.to(device)
    model.eval()  # inference mode for layers such as dropout / batch norm
    return model, device