"""
Residual Convolutional Autoencoder Model

Usage:
    from model import ResidualConvAutoencoder, load_model
    import torch

    # Option 1: Create and load manually
    model = ResidualConvAutoencoder(latent_dim=512)
    checkpoint = torch.load('model_universal_best.ckpt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Option 2: Use the helper function
    model, checkpoint = load_model('model_universal_best.ckpt', device='cuda')
"""

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Residual block with two convolutional layers and optional dropout."""

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.bn2(self.conv2(out))
        out += residual  # identity skip connection
        return self.relu(out)

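
# Note: ResidualBlock is shape-preserving. Illustrative check (assumes nothing
# beyond this file):
#
#     block = ResidualBlock(64)
#     assert block(torch.randn(1, 64, 32, 32)).shape == (1, 64, 32, 32)
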

class ResidualConvAutoencoder(nn.Module):
    """
    Residual Convolutional Autoencoder for image reconstruction.

    Args:
        latent_dim (int): Dimension of the latent space. Default: 512
        dropout (float): Dropout rate for regularization. Default: 0.1

    Input:
        x: Tensor of shape (batch_size, 3, 128, 128).
           Values should be normalized to [-1, 1].

    Output:
        reconstructed: Tensor of shape (batch_size, 3, 128, 128)
        latent: Tensor of shape (batch_size, latent_dim)
    """

    def __init__(self, latent_dim=512, dropout=0.1):
        super().__init__()

        # Encoder: five stride-2 convolutions, 128x128 -> 4x4 spatially.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),      # 128 -> 64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, dropout),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),    # 64 -> 32
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128, dropout),

            nn.Conv2d(128, 256, 4, stride=2, padding=1),   # 32 -> 16
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256, dropout),

            nn.Conv2d(256, 512, 4, stride=2, padding=1),   # 16 -> 8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512, dropout),

            nn.Conv2d(512, 512, 4, stride=2, padding=1),   # 8 -> 4
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Bottleneck: flatten the 512x4x4 feature map to/from the latent vector.
        self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)

        # Decoder: five stride-2 transposed convolutions, 4x4 -> 128x128.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),   # 4 -> 8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512, dropout),

            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),   # 8 -> 16
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256, dropout),

            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),   # 16 -> 32
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128, dropout),

            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),    # 32 -> 64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, dropout),

            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),      # 64 -> 128
            nn.Tanh()  # outputs in [-1, 1], matching the input normalization
        )

    def forward(self, x):
        """
        Forward pass through the autoencoder.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            reconstructed: Reconstructed tensor of shape (batch_size, 3, 128, 128)
            latent: Latent representation of shape (batch_size, latent_dim)
        """
        # Encode: conv features, flattened and projected to the latent space.
        x = self.encoder(x)
        x = x.view(x.size(0), -1)
        latent = self.fc_encoder(x)

        # Decode: project back to a 512x4x4 feature map and upsample.
        x = self.fc_decoder(latent)
        x = x.view(x.size(0), 512, 4, 4)
        reconstructed = self.decoder(x)

        return reconstructed, latent

    def reconstruction_error(self, x):
        """
        Compute per-sample reconstruction error (MSE).

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            error: Tensor of shape (batch_size,) containing the MSE for each sample
        """
        reconstructed, _ = self.forward(x)
        error = ((reconstructed - x) ** 2).view(x.size(0), -1).mean(dim=1)
        return error

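
# The original preprocessing pipeline is not shown in this file, so the helper
# below is an illustrative sketch: it assumes torchvision (not otherwise
# required by this module) and produces the 128x128, [-1, 1]-normalized tensors
# the model's docstring calls for.
def make_preprocess(image_size=128):
    """Return a torchvision transform yielding [-1, 1]-normalized tensors."""
    from torchvision import transforms
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),                       # PIL image -> [0, 1] tensor
        transforms.Normalize([0.5] * 3, [0.5] * 3),  # [0, 1] -> [-1, 1]
    ])
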

def load_model(checkpoint_path, device='cuda', dropout=0.1):
    """
    Load a pretrained model from a checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file
        device: Device to load the model on ('cuda' or 'cpu')
        dropout: Fallback dropout rate, used only when the checkpoint stores
                 no config (must match the training configuration)

    Returns:
        model: Loaded ResidualConvAutoencoder model in eval mode
        checkpoint: Full checkpoint dictionary with metadata
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Prefer the hyperparameters stored in the checkpoint; fall back to the
    # caller-supplied defaults.
    config = checkpoint.get('config', {})
    latent_dim = config.get('latent_dim', 512)
    dropout = config.get('dropout', dropout)

    model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    return model, checkpoint

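
# Convenience sketch, not part of the original API: encode a batch of images
# that have already been preprocessed to (N, 3, 128, 128) in [-1, 1].
@torch.no_grad()
def encode_images(model, images, device='cuda'):
    """Return latent vectors of shape (N, latent_dim) for a batch of images."""
    model.eval()
    _, latent = model(images.to(device))
    return latent
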

if __name__ == "__main__":
    # Smoke test: build the model and push a random batch through it.
    model = ResidualConvAutoencoder(latent_dim=512, dropout=0.1)
    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    x = torch.randn(2, 3, 128, 128)
    reconstructed, latent = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Reconstructed shape: {reconstructed.shape}")
    print(f"Latent shape: {latent.shape}")

    error = model.reconstruction_error(x)
    print(f"Reconstruction error shape: {error.shape}")
    print(f"Mean error: {error.mean().item():.6f}")
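
    # Illustrative anomaly check: the mean + 2*std threshold below is an
    # assumption for demonstration, not a value from the original training setup.
    threshold = error.mean() + 2 * error.std()
    print(f"Samples above threshold: {(error > threshold).sum().item()}")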