import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models import torchvision.transforms as transforms from torchvision.models.resnet import ResNet50_Weights from PIL import Image import numpy as np import os import requests import time from pathlib import Path # Check CUDA availability device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Constants MODEL_URLS = { 'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt', 'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt' } IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] # Default transform transform = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), ]) normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD) # Get ImageNet labels def get_imagenet_labels(): url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json" response = requests.get(url) if response.status_code == 200: return response.json() else: raise RuntimeError("Failed to fetch ImageNet labels") # Download model if needed def download_model(model_type): if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None: return None # Use PyTorch's pretrained model model_path = Path(f"models/{model_type}.pt") if not model_path.exists(): print(f"Downloading {model_type} model...") url = MODEL_URLS[model_type] response = requests.get(url, stream=True) if response.status_code == 200: with open(model_path, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) print(f"Model downloaded and saved to {model_path}") else: raise RuntimeError(f"Failed to download model: {response.status_code}") return model_path class NormalizeByChannelMeanStd(nn.Module): def __init__(self, mean, std): super(NormalizeByChannelMeanStd, self).__init__() if not isinstance(mean, torch.Tensor): mean = torch.tensor(mean) if not isinstance(std, torch.Tensor): std = torch.tensor(std) self.register_buffer("mean", mean) self.register_buffer("std", std) def forward(self, tensor): return self.normalize_fn(tensor, self.mean, self.std) def normalize_fn(self, tensor, mean, std): """Differentiable version of torchvision.functional.normalize""" # here we assume the color channel is at dim=1 mean = mean[None, :, None, None] std = std[None, :, None, None] return tensor.sub(mean).div(std) class InferStep: def __init__(self, orig_image, eps, step_size): self.orig_image = orig_image self.eps = eps self.step_size = step_size def project(self, x): diff = x - self.orig_image diff = torch.clamp(diff, -self.eps, self.eps) return torch.clamp(self.orig_image + diff, 0, 1) def step(self, x, grad): l = len(x.shape) - 1 grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1]*l)) scaled_grad = grad / (grad_norm + 1e-10) return scaled_grad * self.step_size def get_inference_configs(eps=0.5, n_itr=50): """Generate inference configuration with customizable parameters.""" config = { 'loss_infer': 'IncreaseConfidence', # How to guide the optimization 'loss_function': 'CE', # Loss function: Cross Entropy 'n_itr': n_itr, # Number of iterations 'eps': eps, # Maximum perturbation size 'step_size': 1, # Step size for each iteration 'diffusion_noise_ratio': 0.0, # No diffusion noise 'initial_inference_noise_ratio': 0.0, # No initial noise 'top_layer': 'all', # Use all layers of the model 'inference_normalization': 'off', # Apply normalization during inference 'recognition_normalization': 'off', # Apply normalization during recognition 'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr] # Specific iterations to visualize } return config class GenerativeInferenceModel: def __init__(self): self.models = {} self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device) self.labels = get_imagenet_labels() def verify_model_integrity(self, model, model_type): """ Verify model integrity by running a test input through it. Returns whether the model passes basic integrity check. """ try: print(f"\n=== Running model integrity check for {model_type} ===") # Create a deterministic test input test_input = torch.zeros(1, 3, 224, 224) test_input[0, 0, 100:124, 100:124] = 0.5 # Red square test_input = test_input.to(model.device if hasattr(model, 'device') else 'cpu') # Run forward pass with torch.no_grad(): output = model(test_input) # Check output shape if output.shape != (1, 1000): print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)") return False # Get top prediction probs = torch.nn.functional.softmax(output, dim=1) confidence, prediction = torch.max(probs, 1) # Calculate basic statistics on output mean = output.mean().item() std = output.std().item() min_val = output.min().item() max_val = output.max().item() print(f"Model integrity check results:") print(f"- Output shape: {output.shape}") print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence") print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}") # Basic sanity checks if torch.isnan(output).any(): print("❌ Model produced NaN outputs") return False if output.std().item() < 0.1: print("⚠️ Low output variance, model may not be discriminative") print("✅ Model passes basic integrity check") return True except Exception as e: print(f"❌ Model integrity check failed with error: {e}") return False def load_model(self, model_type): """Load model from checkpoint or use pretrained model.""" if model_type in self.models: return self.models[model_type] model_path = download_model(model_type) # Create a sequential model with normalizer and ResNet50 resnet = models.resnet50() model = nn.Sequential( self.normalizer, # Normalizer is part of the model sequence resnet ) # Load the model checkpoint if model_path: print(f"Loading {model_type} model from {model_path}...") try: checkpoint = torch.load(model_path, map_location=device) # Print checkpoint structure for better understanding print("\n=== Analyzing checkpoint structure ===") if isinstance(checkpoint, dict): print(f"Checkpoint contains keys: {list(checkpoint.keys())}") # Examine 'model' structure if it exists if 'model' in checkpoint and isinstance(checkpoint['model'], dict): model_dict = checkpoint['model'] # Get sample of keys to understand structure first_keys = list(model_dict.keys())[:5] print(f"'model' contains keys like: {first_keys}") # Check for common prefixes in the model dict prefixes = set() for key in list(model_dict.keys())[:100]: # Check first 100 keys parts = key.split('.') if len(parts) > 1: prefixes.add(parts[0]) if prefixes: print(f"Common prefixes in model dict: {prefixes}") else: print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}") # Handle different checkpoint formats if 'model' in checkpoint: # Format from madrylab robust models state_dict = checkpoint['model'] print("Using 'model' key from checkpoint") elif 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] print("Using 'state_dict' key from checkpoint") else: # Direct state dict state_dict = checkpoint print("Using checkpoint directly as state_dict") # Handle prefix in state dict keys for ResNet part resnet_state_dict = {} prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.'] resnet_keys = set(resnet.state_dict().keys()) # First check if we can find keys directly in the attacker.model path print("\n=== Phase 1: Checking for specific model structures ===") # Check for 'module.model' structure (seen in actual checkpoint) module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')] if module_model_keys: print(f"Found 'module.model' structure with {len(module_model_keys)} parameters") # Extract all parameters from module.model for source_key, value in state_dict.items(): if source_key.startswith('module.model.'): target_key = source_key[len('module.model.'):] resnet_state_dict[target_key] = value print(f"Extracted {len(resnet_state_dict)} parameters from module.model") # Check for 'attacker.model' structure attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')] if attacker_model_keys: print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters") # Extract all parameters from attacker.model for source_key, value in state_dict.items(): if source_key.startswith('attacker.model.'): target_key = source_key[len('attacker.model.'):] resnet_state_dict[target_key] = value print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model") # Check if 'model' (not attacker.model) exists as a fallback model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')] if model_keys and len(resnet_state_dict) < len(resnet_keys): print(f"Found additional 'model.' structure with {len(model_keys)} parameters") # Try to complete missing parameters for source_key, value in state_dict.items(): if source_key.startswith('model.'): target_key = source_key[len('model.'):] if target_key in resnet_keys and target_key not in resnet_state_dict: resnet_state_dict[target_key] = value else: # Check for other known structures structure_found = False # Check for 'model.' prefix model_keys = [key for key in state_dict.keys() if key.startswith('model.')] if model_keys: print(f"Found 'model.' structure with {len(model_keys)} parameters") for source_key, value in state_dict.items(): if source_key.startswith('model.'): target_key = source_key[len('model.'):] resnet_state_dict[target_key] = value structure_found = True # Check for ResNet parameters at the top level top_level_resnet_keys = 0 for key in resnet_keys: if key in state_dict: top_level_resnet_keys += 1 if top_level_resnet_keys > 0: print(f"Found {top_level_resnet_keys} ResNet parameters at top level") for target_key in resnet_keys: if target_key in state_dict: resnet_state_dict[target_key] = state_dict[target_key] structure_found = True # If no structure was recognized, try the prefix mapping approach if not structure_found: print("No standard model structure found, trying prefix mappings...") for target_key in resnet_keys: for prefix in prefixes_to_try: source_key = prefix + target_key if source_key in state_dict: resnet_state_dict[target_key] = state_dict[source_key] break # If we still can't find enough keys, try a final approach of removing prefixes if len(resnet_state_dict) < len(resnet_keys): print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...") # Track matches found through prefix removal prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']} layer_matches = {} # Track matches by layer type # Count parameter keys by layer type for analysis for key in resnet_keys: layer_name = key.split('.')[0] if '.' in key else key if layer_name not in layer_matches: layer_matches[layer_name] = {'total': 0, 'matched': 0} layer_matches[layer_name]['total'] += 1 # Try keys with common prefixes for source_key, value in state_dict.items(): # Skip if already found target_key = source_key matched_prefix = None # Try removing various prefixes for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']: if source_key.startswith(prefix): target_key = source_key[len(prefix):] matched_prefix = prefix break # If the target key is in the ResNet keys, add it to the state dict if target_key in resnet_keys and target_key not in resnet_state_dict: resnet_state_dict[target_key] = value # Update match statistics if matched_prefix: prefix_matches[matched_prefix] += 1 # Update layer matches layer_name = target_key.split('.')[0] if '.' in target_key else target_key if layer_name in layer_matches: layer_matches[layer_name]['matched'] += 1 # Print detailed prefix removal statistics print("\n=== Prefix Removal Statistics ===") total_matches = sum(prefix_matches.values()) print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)") # Show matches by prefix print("\nMatches by prefix:") for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True): if count > 0: print(f" {prefix}: {count} parameters") # Show matches by layer type print("\nMatches by layer type:") for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True): match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0 print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)") # Check for specific important layers (conv1, layer1, etc.) critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc'] print("\nStatus of critical layers:") for layer in critical_layers: if layer in layer_matches: match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100 status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE" print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}") else: print(f" {layer}: Not found in model") # Load the ResNet state dict if resnet_state_dict: try: # Use strict=False to allow missing keys result = resnet.load_state_dict(resnet_state_dict, strict=False) missing_keys, unexpected_keys = result # Generate detailed information with better formatting loading_report = [] loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====") loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}") loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}") loading_report.append(f"Missing keys: {len(missing_keys):,} parameters") loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters") # Calculate percentage of parameters loaded loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys) loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100 # Determine loading success status if loaded_percent >= 99.5: status = "✅ COMPLETE - All important parameters loaded" elif loaded_percent >= 90: status = "🟡 PARTIAL - Most parameters loaded, should still function" elif loaded_percent >= 50: status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly" else: status = "❌ FAILED - Critical parameters missing, will not function properly" loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)") loading_report.append(f"Loading status: {status}") # If loading is severely incomplete, fall back to PyTorch's pretrained model if loaded_percent < 50: loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.") loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.") # Create a new ResNet model with pretrained weights resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) model = nn.Sequential(self.normalizer, resnet) loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model") # Show missing keys by layer type if missing_keys: loading_report.append("\nMissing keys by layer type:") layer_types = {} for key in missing_keys: # Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.) parts = key.split('.') if len(parts) > 0: layer_type = parts[0] if layer_type not in layer_types: layer_types[layer_type] = 0 layer_types[layer_type] += 1 # Add counts by layer type for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True): loading_report.append(f" {layer_type}: {count:,} parameters") loading_report.append("\nFirst 10 missing keys:") for i, key in enumerate(sorted(missing_keys)[:10]): loading_report.append(f" {i+1}. {key}") # Show unexpected keys if any if unexpected_keys: loading_report.append("\nFirst 10 unexpected keys:") for i, key in enumerate(sorted(unexpected_keys)[:10]): loading_report.append(f" {i+1}. {key}") loading_report.append("========================================") # Convert report to string and print it report_text = "\n".join(loading_report) print(report_text) # Also save to a file for reference os.makedirs("logs", exist_ok=True) with open(f"logs/model_loading_{model_type}.log", "w") as f: f.write(report_text) # Look for normalizer parameters as well if any(key.startswith('attacker.normalize.') for key in state_dict.keys()): norm_state_dict = {} for key, value in state_dict.items(): if key.startswith('attacker.normalize.'): norm_key = key[len('attacker.normalize.'):] norm_state_dict[norm_key] = value if norm_state_dict: try: self.normalizer.load_state_dict(norm_state_dict, strict=False) print("Successfully loaded normalizer parameters") except Exception as e: print(f"Warning: Could not load normalizer parameters: {e}") except Exception as e: print(f"Warning: Error loading ResNet parameters: {e}") # Fall back to loading without normalizer model = resnet # Use just the ResNet model without normalizer except Exception as e: print(f"Error loading model checkpoint: {e}") # Fallback to PyTorch's pretrained model print("Falling back to PyTorch's pretrained model") resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) model = nn.Sequential(self.normalizer, resnet) else: # Fallback to PyTorch's pretrained model print("No checkpoint available, using PyTorch's pretrained model") resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1) model = nn.Sequential(self.normalizer, resnet) model = model.to(device) model.eval() # Set to evaluation mode # Verify model integrity self.verify_model_integrity(model, model_type) # Store the model for future use self.models[model_type] = model return model def inference(self, image, model_type, config): """Run generative inference on the image.""" # Load model if not already loaded model = self.load_model(model_type) # Check if image is a file path if isinstance(image, str): if os.path.exists(image): image = Image.open(image).convert('RGB') else: raise ValueError(f"Image path does not exist: {image}") # Prepare image tensor image_tensor = transform(image).unsqueeze(0).to(device) image_tensor.requires_grad = True # Check model structure is_sequential = isinstance(model, nn.Sequential) # Get original predictions with torch.no_grad(): # If the model is sequential with a normalizer, skip the normalization step if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd): print("Model is sequential with normalization") output_original = model(image_tensor) # Model includes normalization # Get the core model part (typically at index 1 in Sequential) core_model = model[1] else: print("Model is not sequential with normalization") # Use manual normalization for non-sequential models normalized_tensor = normalize_transform(image_tensor) output_original = model(normalized_tensor) core_model = model probs_orig = F.softmax(output_original, dim=1) conf_orig, classes_orig = torch.max(probs_orig, 1) # Get least confident classes _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False) # Initialize inference step infer_step = InferStep(image_tensor, config['eps'], config['step_size']) # Storage for inference steps # Create a new tensor that requires gradients x = image_tensor.clone().detach().requires_grad_(True) all_steps = [image_tensor[0].detach().cpu()] # Main inference loop for i in range(config['n_itr']): # Reset gradients x.grad = None # Forward pass if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd): output = model(x) # Model includes normalization else: # Use manual normalization for non-sequential models normalized_x = normalize_transform(x) output = model(normalized_x) # Calculate loss to maximize confidence for least confident classes try: # Get the least confident classes num_classes = min(10, least_confident_classes.size(1)) target_classes = least_confident_classes[0, :num_classes] # Create a combined loss (avoid accumulating in a loop) targets = torch.tensor([idx.item() for idx in target_classes], device=device) # Method 1: Use a single combined loss loss = 0 for target in targets: # Create one-hot target one_hot = torch.zeros_like(output) one_hot[0, target] = 1 # Use negative loss to maximize confidence loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot) # Method 2: Try direct gradient calculation # Instead of loss.backward(), which might be failing grad = torch.autograd.grad(loss, x, retain_graph=True)[0] if grad is None: print("Warning: Direct gradient calculation failed") # Fall back to random perturbation random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size'] x = x + random_noise else: # Update image with gradient step = infer_step.step(x, grad) x = x + step x = infer_step.project(x) except Exception as e: print(f"Error in gradient calculation: {e}") # Fall back to random perturbation random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size'] x = x + random_noise x = infer_step.project(x) # Store step if in iterations_to_show if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']: all_steps.append(x[0].detach().cpu()) # Print some info about the inference with torch.no_grad(): if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd): final_output = model(x) else: normalized_x = normalize_transform(x) final_output = model(normalized_x) final_probs = F.softmax(final_output, dim=1) final_conf, final_classes = torch.max(final_probs, 1) print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})") print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})") # Return final image and all stored steps return x[0].detach().cpu(), all_steps # Utility function to show inference steps def show_inference_steps(steps, figsize=(15, 10)): import matplotlib.pyplot as plt n_steps = len(steps) fig, axes = plt.subplots(1, n_steps, figsize=figsize) for i, step_img in enumerate(steps): img = step_img.permute(1, 2, 0).numpy() axes[i].imshow(img) axes[i].set_title(f"Step {i}") axes[i].axis('off') plt.tight_layout() return fig