# Source: model.py uploaded to the Hugging Face Hub by ash12321
# (commit cc156cd, "Upload model.py with huggingface_hub").
"""
Residual Convolutional Autoencoder Model
Usage:
from model import ResidualConvAutoencoder, load_model
import torch
# Option 1: Create and load manually
model = ResidualConvAutoencoder(latent_dim=512)
checkpoint = torch.load('model_universal_best.ckpt')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Option 2: Use helper function
model, checkpoint = load_model('model_universal_best.ckpt', device='cuda')
"""
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with an identity shortcut and optional dropout.

    Computes ``relu(bn2(conv2(dropout(relu(bn1(conv1(x)))))) + x)``. Both
    convolutions keep the channel count, and with kernel 3 / padding 1 /
    stride 1 they also keep the spatial size, so the shortcut needs no
    projection.

    Args:
        channels: Number of input (and output) channels.
        dropout: Drop probability for the ``nn.Dropout2d`` applied between
            the two convolutions; a value of 0 or less disables dropout
            (an ``nn.Identity`` is used instead).
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        # Attribute names and registration order define the state_dict keys
        # and parameter order that saved checkpoints rely on; keep them fixed.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Apply the residual transform; output has the same shape as ``x``."""
        shortcut = x
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.conv2(self.dropout(h))
        h = self.bn2(h)
        # In-place add onto the freshly produced BN output; ``x`` is not modified.
        h += shortcut
        return self.relu(h)
class ResidualConvAutoencoder(nn.Module):
    """
    Residual Convolutional Autoencoder for image reconstruction.

    The encoder halves the spatial size five times (stride-2 convolutions),
    turning a 3x128x128 image into a 512x4x4 feature map that is projected to
    a ``latent_dim`` vector. The decoder mirrors this path back to a 3x128x128
    image whose values are squashed into [-1, 1] by a final ``tanh``.

    Args:
        latent_dim (int): Dimension of the latent space. Default: 512
        dropout (float): Dropout rate used inside every residual block.
            Default: 0.1

    Input:
        x: Tensor of shape (batch_size, 3, 128, 128).
           Values should be normalized to [-1, 1] to match the ``tanh``
           output range (``reconstruction_error`` compares them directly).
           Other spatial sizes are not supported: ``fc_encoder`` expects
           exactly 512 * 4 * 4 features.

    Output:
        reconstructed: Tensor of shape (batch_size, 3, 128, 128)
        latent: Tensor of shape (batch_size, latent_dim)
    """

    def __init__(self, latent_dim=512, dropout=0.1):
        super().__init__()
        # NOTE: the order of layers inside each nn.Sequential determines the
        # state_dict keys (encoder.0, encoder.1, ...) used by saved
        # checkpoints; do not reorder, insert or remove layers.

        # Encoder: 128x128 -> 4x4
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),  # 128 -> 64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, dropout),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 64 -> 32
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128, dropout),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 32 -> 16
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256, dropout),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),  # 16 -> 8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512, dropout),
            nn.Conv2d(512, 512, 4, stride=2, padding=1),  # 8 -> 4
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Bottleneck: flattened 512x4x4 feature map <-> latent vector
        self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)

        # Decoder: 4x4 -> 128x128
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),  # 4 -> 8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512, dropout),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 8 -> 16
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256, dropout),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 16 -> 32
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128, dropout),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 32 -> 64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, dropout),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),  # 64 -> 128
            nn.Tanh(),  # output range [-1, 1]
        )

    def forward(self, x):
        """
        Forward pass through the autoencoder.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            reconstructed: Reconstructed tensor of shape (batch_size, 3, 128, 128)
            latent: Latent representation of shape (batch_size, latent_dim)
        """
        # Encode: (B, 3, 128, 128) -> (B, 512, 4, 4) -> (B, 8192) -> (B, latent_dim)
        features = self.encoder(x)
        # flatten(1) instead of view(): for a channels_last (or otherwise
        # non-contiguous) input the feature map may not be contiguous, and
        # view() would raise. flatten() copies only when necessary and keeps
        # the same logical (C, H, W) element order.
        latent = self.fc_encoder(features.flatten(1))

        # Decode: (B, latent_dim) -> (B, 8192) -> (B, 512, 4, 4) -> (B, 3, 128, 128)
        x = self.fc_decoder(latent)
        x = x.view(x.size(0), 512, 4, 4)  # Linear output is contiguous, so view is safe
        reconstructed = self.decoder(x)
        return reconstructed, latent

    def reconstruction_error(self, x):
        """
        Compute per-sample reconstruction error (MSE).

        Runs a full forward pass through the module (so registered forward
        hooks run too). Gradients are tracked unless the caller wraps the call
        in ``torch.no_grad()``, and BatchNorm/Dropout follow the module's
        train/eval mode.

        Args:
            x: Input tensor of shape (batch_size, 3, 128, 128)

        Returns:
            error: Tensor of shape (batch_size,) containing the mean squared
                error over all channels and pixels of each sample
        """
        reconstructed, _ = self(x)
        squared = (reconstructed - x) ** 2
        # flatten(1) rather than view(): the difference can be non-contiguous
        # (e.g. channels_last input), where view() would raise.
        return squared.flatten(1).mean(dim=1)
def load_model(checkpoint_path, device='cuda', dropout=0.1):
    """
    Load a pretrained ResidualConvAutoencoder from a checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file. It must contain a
            'model_state_dict' entry and may contain a 'config' mapping with
            'latent_dim' and/or 'dropout' keys.
        device: Device to load the model on ('cuda', 'cpu', or a
            torch.device). Pass None to pick CUDA when available, else CPU.
        dropout: Fallback dropout rate, used only when the checkpoint config
            does not specify one. Dropout layers hold no parameters, so this
            value does not affect loading the weights, and dropout is inactive
            in eval mode; it only matters if the model is trained further.

    Returns:
        model: Loaded ResidualConvAutoencoder model in eval mode on ``device``
        checkpoint: Full checkpoint dictionary with metadata

    Raises:
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Hyper-parameters saved at training time override the defaults; a
    # missing *or* None 'config' entry falls back to the defaults.
    config = checkpoint.get('config') or {}
    latent_dim = config.get('latent_dim', 512)
    dropout = config.get('dropout', dropout)

    model = ResidualConvAutoencoder(latent_dim=latent_dim, dropout=dropout)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # disable dropout and use BatchNorm running statistics
    return model, checkpoint
if __name__ == "__main__":
    # Smoke test: build a default model, run a random batch through it in
    # (default) training mode, and report shapes plus the untrained error.
    model = ResidualConvAutoencoder(latent_dim=512, dropout=0.1)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Model created with {n_params:,} parameters")

    # Forward pass on two random 3x128x128 inputs
    x = torch.randn(2, 3, 128, 128)
    reconstructed, latent = model(x)
    for label, tensor in (("Input", x), ("Reconstructed", reconstructed), ("Latent", latent)):
        print(f"{label} shape: {tensor.shape}")

    # Per-sample reconstruction error (second forward pass)
    error = model.reconstruction_error(x)
    print(f"Reconstruction error shape: {error.shape}")
    mean_error = error.mean().item()
    print(f"Mean error: {mean_error:.6f}")