""" Residual Convolutional Autoencoder Model Usage: from model import ResidualConvAutoencoder, load_model import torch # Option 1: Create and load manually model = ResidualConvAutoencoder(latent_dim=512) checkpoint = torch.load('model_universal_best.ckpt') model.load_state_dict(checkpoint['model_state_dict']) model.eval() # Option 2: Use helper function model, checkpoint = load_model('model_universal_best.ckpt', device='cuda') """ import torch import torch.nn as nn class ResidualBlock(nn.Module): """Residual block with two convolutional layers and optional dropout""" def __init__(self, channels, dropout=0.1): super().__init__() self.conv1 = nn.Conv2d(channels, channels, 3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.conv2 = nn.Conv2d(channels, channels, 3, padding=1) self.bn2 = nn.BatchNorm2d(channels) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity() def forward(self, x): residual = x out = self.relu(self.bn1(self.conv1(x))) out = self.dropout(out) out = self.bn2(self.conv2(out)) out += residual return self.relu(out) class ResidualConvAutoencoder(nn.Module): """ Residual Convolutional Autoencoder for image reconstruction Args: latent_dim (int): Dimension of the latent space. Default: 512 dropout (float): Dropout rate for regularization. Default: 0.1 Input: x: Tensor of shape (batch_size, 3, 128, 128) Values should be normalized to [-1, 1] Output: reconstructed: Tensor of shape (batch_size, 3, 128, 128) latent: Tensor of shape (batch_size, latent_dim) """ def __init__(self, latent_dim=512, dropout=0.1): super().__init__() # Encoder: 128x128 -> 4x4 self.encoder = nn.Sequential( nn.Conv2d(3, 64, 4, stride=2, padding=1), # 128 -> 64 nn.BatchNorm2d(64), nn.ReLU(inplace=True), ResidualBlock(64, dropout), nn.Conv2d(64, 128, 4, stride=2, padding=1), # 64 -> 32 nn.BatchNorm2d(128), nn.ReLU(inplace=True), ResidualBlock(128, dropout), nn.Conv2d(128, 256, 4, stride=2, padding=1), # 32 -> 16 nn.BatchNorm2d(256), nn.ReLU(inplace=True), ResidualBlock(256, dropout), nn.Conv2d(256, 512, 4, stride=2, padding=1), # 16 -> 8 nn.BatchNorm2d(512), nn.ReLU(inplace=True), ResidualBlock(512, dropout), nn.Conv2d(512, 512, 4, stride=2, padding=1), # 8 -> 4 nn.BatchNorm2d(512), nn.ReLU(inplace=True), ) # Bottleneck self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim) self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4) # Decoder: 4x4 -> 128x128 self.decoder = nn.Sequential( nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1), # 4 -> 8 nn.BatchNorm2d(512), nn.ReLU(inplace=True), ResidualBlock(512, dropout), nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), # 8 -> 16 nn.BatchNorm2d(256), nn.ReLU(inplace=True), ResidualBlock(256, dropout), nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # 16 -> 32 nn.BatchNorm2d(128), nn.ReLU(inplace=True), ResidualBlock(128, dropout), nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # 32 -> 64 nn.BatchNorm2d(64), nn.ReLU(inplace=True), ResidualBlock(64, dropout), nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1), # 64 -> 128 nn.Tanh() ) def forward(self, x): """ Forward pass through the autoencoder Args: x: Input tensor of shape (batch_size, 3, 128, 128) Returns: reconstructed: Reconstructed tensor of shape (batch_size, 3, 128, 128) latent: Latent representation of shape (batch_size, latent_dim) """ # Encode x = self.encoder(x) x = x.view(x.size(0), -1) latent = self.fc_encoder(x) # Decode x = self.fc_decoder(latent) x = x.view(x.size(0), 512, 4, 4) reconstructed = self.decoder(x) return reconstructed, latent def reconstruction_error(self, x): """ Compute per-sample reconstruction error (MSE) Args: x: Input tensor of shape (batch_size, 3, 128, 128) Returns: error: Tensor of shape (batch_size,) containing MSE for each sample """ reconstructed, _ = self.forward(x) error = ((reconstructed - x) ** 2).view(x.size(0), -1).mean(dim=1) return error def load_model(checkpoint_path, device='cuda', dropout=0.1): """ Load a pretrained model from checkpoint Args: checkpoint_path: Path to the checkpoint file device: Device to load the model on ('cuda' or 'cpu') dropout: Dropout rate (must match training config) Returns: model: Loaded ResidualConvAutoencoder model in eval mode checkpoint: Full checkpoint dictionary with metadata """ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) # Get config if available config = checkpoint.get('config', {}) latent_dim = config.get('latent_dim', 512) dropout = config.get('dropout', dropout) model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout) model.load_state_dict(checkpoint['model_state_dict']) model.to(device) model.eval() return model, checkpoint if __name__ == "__main__": # Test the model model = ResidualConvAutoencoder(latent_dim=512, dropout=0.1) print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters") # Test forward pass x = torch.randn(2, 3, 128, 128) reconstructed, latent = model(x) print(f"Input shape: {x.shape}") print(f"Reconstructed shape: {reconstructed.shape}") print(f"Latent shape: {latent.shape}") # Test reconstruction error error = model.reconstruction_error(x) print(f"Reconstruction error shape: {error.shape}") print(f"Mean error: {error.mean().item():.6f}")