ttoosi committed on
Commit
d3dc9e2
·
verified ·
1 Parent(s): 4752d1b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +434 -49
inference.py CHANGED
@@ -10,7 +10,6 @@ import os
10
  import requests
11
  import time
12
  from pathlib import Path
13
- from spaces import GPU
14
 
15
  # Check CUDA availability
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -21,6 +20,7 @@ MODEL_URLS = {
21
  'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
22
  'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt'
23
  }
 
24
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
25
  IMAGENET_STD = [0.229, 0.224, 0.225]
26
 
@@ -105,12 +105,12 @@ def get_inference_configs(eps=0.5, n_itr=50):
105
  'loss_function': 'CE', # Loss function: Cross Entropy
106
  'n_itr': n_itr, # Number of iterations
107
  'eps': eps, # Maximum perturbation size
108
- 'step_size': 0.02, # Step size for each iteration
109
  'diffusion_noise_ratio': 0.0, # No diffusion noise
110
  'initial_inference_noise_ratio': 0.0, # No initial noise
111
  'top_layer': 'all', # Use all layers of the model
112
- 'inference_normalization': 'on', # Apply normalization during inference
113
- 'recognition_normalization': 'on', # Apply normalization during recognition
114
  'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr] # Specific iterations to visualize
115
  }
116
  return config
@@ -121,52 +121,384 @@ class GenerativeInferenceModel:
121
  self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
122
  self.labels = get_imagenet_labels()
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  def load_model(self, model_type):
 
125
  if model_type in self.models:
126
  return self.models[model_type]
127
 
128
  model_path = download_model(model_type)
129
 
130
- # Create standard ResNet50 model
131
- model = models.resnet50()
 
 
 
 
132
 
133
  # Load the model checkpoint
134
  if model_path:
135
  print(f"Loading {model_type} model from {model_path}...")
136
- checkpoint = torch.load(model_path, map_location=device)
137
-
138
- # Handle different checkpoint formats
139
- if 'model' in checkpoint:
140
- # Format from madrylab robust models
141
- state_dict = checkpoint['model']
142
- elif 'state_dict' in checkpoint:
143
- state_dict = checkpoint['state_dict']
144
- else:
145
- # Direct state dict
146
- state_dict = checkpoint
147
-
148
- # Handle prefix in state dict keys
149
- new_state_dict = {}
150
- for key, value in state_dict.items():
151
- if key.startswith('module.'):
152
- new_key = key[7:] # Remove 'module.' prefix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  else:
154
- new_key = key
155
- new_state_dict[new_key] = value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- model.load_state_dict(new_state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  else:
159
  # Fallback to PyTorch's pretrained model
160
- model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
 
 
161
 
162
  model = model.to(device)
163
  model.eval() # Set to evaluation mode
164
 
 
 
 
165
  # Store the model for future use
166
  self.models[model_type] = model
167
  return model
168
- @GPU
169
  def inference(self, image, model_type, config):
 
170
  # Load model if not already loaded
171
  model = self.load_model(model_type)
172
 
@@ -181,12 +513,24 @@ class GenerativeInferenceModel:
181
  image_tensor = transform(image).unsqueeze(0).to(device)
182
  image_tensor.requires_grad = True
183
 
184
- # Normalize the image for model input
185
- normalized_tensor = normalize_transform(image_tensor)
186
 
187
  # Get original predictions
188
  with torch.no_grad():
189
- output_original = model(normalized_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
190
  probs_orig = F.softmax(output_original, dim=1)
191
  conf_orig, classes_orig = torch.max(probs_orig, 1)
192
 
@@ -197,7 +541,8 @@ class GenerativeInferenceModel:
197
  infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
198
 
199
  # Storage for inference steps
200
- x = image_tensor.clone()
 
201
  all_steps = [image_tensor[0].detach().cpu()]
202
 
203
  # Main inference loop
@@ -205,32 +550,72 @@ class GenerativeInferenceModel:
205
  # Reset gradients
206
  x.grad = None
207
 
208
- # Normalize input for the model
209
- normalized_x = normalize_transform(x)
210
-
211
  # Forward pass
212
- output = model(normalized_x)
 
 
 
 
 
213
 
214
  # Calculate loss to maximize confidence for least confident classes
215
- target_classes = least_confident_classes[:10] # Use top 10 least confident classes
216
- loss = 0
217
- for idx in target_classes:
218
- target = torch.tensor([idx.item()], device=device)
219
- loss = loss - F.cross_entropy(output, target) # Negative because we want to maximize confidence
220
-
221
- # Backward pass
222
- loss.backward()
223
-
224
- # Update image
225
- with torch.no_grad():
226
- step = infer_step.step(x, x.grad)
227
- x = x + step
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  x = infer_step.project(x)
229
 
230
  # Store step if in iterations_to_show
231
  if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
232
  all_steps.append(x[0].detach().cpu())
 
 
 
 
 
 
 
 
233
 
 
 
 
 
 
234
  # Return final image and all stored steps
235
  return x[0].detach().cpu(), all_steps
236
 
 
10
  import requests
11
  import time
12
  from pathlib import Path
 
13
 
14
  # Check CUDA availability
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
20
  'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
21
  'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt'
22
  }
23
+
24
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
25
  IMAGENET_STD = [0.229, 0.224, 0.225]
26
 
 
105
  'loss_function': 'CE', # Loss function: Cross Entropy
106
  'n_itr': n_itr, # Number of iterations
107
  'eps': eps, # Maximum perturbation size
108
+ 'step_size': 1, # Step size for each iteration
109
  'diffusion_noise_ratio': 0.0, # No diffusion noise
110
  'initial_inference_noise_ratio': 0.0, # No initial noise
111
  'top_layer': 'all', # Use all layers of the model
112
+ 'inference_normalization': 'off', # Apply normalization during inference
113
+ 'recognition_normalization': 'off', # Apply normalization during recognition
114
  'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr] # Specific iterations to visualize
115
  }
116
  return config
 
121
  self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
122
  self.labels = get_imagenet_labels()
123
 
124
def verify_model_integrity(self, model, model_type):
    """
    Verify model integrity by running a test input through it.

    A fixed 1x3x224x224 input (all zeros except a 24x24 patch of 0.5 in
    channel 0) is pushed through ``model`` with gradients disabled. The
    check fails when the forward pass raises, when the output shape is not
    (1, 1000), or when the output contains NaN. A low spread of the output
    values only prints a warning.

    Args:
        model: Callable that maps the test tensor to class scores.
        model_type: Label used in the printed messages.

    Returns:
        True if the model passes the basic integrity check, False otherwise.
    """
    try:
        print(f"\n=== Running model integrity check for {model_type} ===")

        # torch modules have no ``device`` attribute, so the probe has to be
        # placed wherever the weights actually live; otherwise a model that
        # was moved to CUDA fails this check with a device mismatch.
        target_device = getattr(model, 'device', None)
        if target_device is None and isinstance(model, torch.nn.Module):
            first_tensor = next(model.parameters(), None)
            if first_tensor is None:
                first_tensor = next(model.buffers(), None)
            if first_tensor is not None:
                target_device = first_tensor.device
        if target_device is None:
            target_device = 'cpu'

        # Create a deterministic test input
        test_input = torch.zeros(1, 3, 224, 224)
        test_input[0, 0, 100:124, 100:124] = 0.5  # Red square
        test_input = test_input.to(target_device)

        # Run forward pass
        with torch.no_grad():
            output = model(test_input)

        # Check output shape
        if output.shape != (1, 1000):
            print(f"❌ Unexpected output shape: {output.shape}, expected (1, 1000)")
            return False

        # Get top prediction
        probs = torch.nn.functional.softmax(output, dim=1)
        confidence, prediction = torch.max(probs, 1)

        # Calculate basic statistics on output
        mean = output.mean().item()
        std = output.std().item()
        min_val = output.min().item()
        max_val = output.max().item()

        print("Model integrity check results:")
        print(f"- Output shape: {output.shape}")
        print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
        print(f"- Output statistics: mean={mean:.3f}, std={std:.3f}, min={min_val:.3f}, max={max_val:.3f}")

        # Basic sanity checks
        if torch.isnan(output).any():
            print("❌ Model produced NaN outputs")
            return False

        # Nearly constant scores are suspicious but not treated as a failure.
        if std < 0.1:
            print("⚠️ Low output variance, model may not be discriminative")

        print("✅ Model passes basic integrity check")
        return True

    except Exception as e:
        # Best-effort diagnostic: any failure is reported, never propagated.
        print(f"❌ Model integrity check failed with error: {e}")
        return False
174
+
175
  def load_model(self, model_type):
176
+ """Load model from checkpoint or use pretrained model."""
177
  if model_type in self.models:
178
  return self.models[model_type]
179
 
180
  model_path = download_model(model_type)
181
 
182
+ # Create a sequential model with normalizer and ResNet50
183
+ resnet = models.resnet50()
184
+ model = nn.Sequential(
185
+ self.normalizer, # Normalizer is part of the model sequence
186
+ resnet
187
+ )
188
 
189
  # Load the model checkpoint
190
  if model_path:
191
  print(f"Loading {model_type} model from {model_path}...")
192
+ try:
193
+ checkpoint = torch.load(model_path, map_location=device)
194
+
195
+ # Print checkpoint structure for better understanding
196
+ print("\n=== Analyzing checkpoint structure ===")
197
+ if isinstance(checkpoint, dict):
198
+ print(f"Checkpoint contains keys: {list(checkpoint.keys())}")
199
+
200
+ # Examine 'model' structure if it exists
201
+ if 'model' in checkpoint and isinstance(checkpoint['model'], dict):
202
+ model_dict = checkpoint['model']
203
+ # Get sample of keys to understand structure
204
+ first_keys = list(model_dict.keys())[:5]
205
+ print(f"'model' contains keys like: {first_keys}")
206
+
207
+ # Check for common prefixes in the model dict
208
+ prefixes = set()
209
+ for key in list(model_dict.keys())[:100]: # Check first 100 keys
210
+ parts = key.split('.')
211
+ if len(parts) > 1:
212
+ prefixes.add(parts[0])
213
+ if prefixes:
214
+ print(f"Common prefixes in model dict: {prefixes}")
215
+ else:
216
+ print(f"Checkpoint is not a dictionary, but a {type(checkpoint)}")
217
+
218
+ # Handle different checkpoint formats
219
+ if 'model' in checkpoint:
220
+ # Format from madrylab robust models
221
+ state_dict = checkpoint['model']
222
+ print("Using 'model' key from checkpoint")
223
+ elif 'state_dict' in checkpoint:
224
+ state_dict = checkpoint['state_dict']
225
+ print("Using 'state_dict' key from checkpoint")
226
+ else:
227
+ # Direct state dict
228
+ state_dict = checkpoint
229
+ print("Using checkpoint directly as state_dict")
230
+
231
+ # Handle prefix in state dict keys for ResNet part
232
+ resnet_state_dict = {}
233
+ prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.']
234
+ resnet_keys = set(resnet.state_dict().keys())
235
+
236
+ # First check if we can find keys directly in the attacker.model path
237
+ print("\n=== Phase 1: Checking for specific model structures ===")
238
+
239
+ # Check for 'module.model' structure (seen in actual checkpoint)
240
+ module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.')]
241
+ if module_model_keys:
242
+ print(f"Found 'module.model' structure with {len(module_model_keys)} parameters")
243
+ # Extract all parameters from module.model
244
+ for source_key, value in state_dict.items():
245
+ if source_key.startswith('module.model.'):
246
+ target_key = source_key[len('module.model.'):]
247
+ resnet_state_dict[target_key] = value
248
+
249
+ print(f"Extracted {len(resnet_state_dict)} parameters from module.model")
250
+
251
+ # Check for 'attacker.model' structure
252
+ attacker_model_keys = [key for key in state_dict.keys() if key.startswith('attacker.model.')]
253
+ if attacker_model_keys:
254
+ print(f"Found 'attacker.model' structure with {len(attacker_model_keys)} parameters")
255
+ # Extract all parameters from attacker.model
256
+ for source_key, value in state_dict.items():
257
+ if source_key.startswith('attacker.model.'):
258
+ target_key = source_key[len('attacker.model.'):]
259
+ resnet_state_dict[target_key] = value
260
+
261
+ print(f"Extracted {len(resnet_state_dict)} parameters from attacker.model")
262
+
263
+ # Check if 'model' (not attacker.model) exists as a fallback
264
+ model_keys = [key for key in state_dict.keys() if key.startswith('model.') and not key.startswith('attacker.model.')]
265
+ if model_keys and len(resnet_state_dict) < len(resnet_keys):
266
+ print(f"Found additional 'model.' structure with {len(model_keys)} parameters")
267
+ # Try to complete missing parameters
268
+ for source_key, value in state_dict.items():
269
+ if source_key.startswith('model.'):
270
+ target_key = source_key[len('model.'):]
271
+ if target_key in resnet_keys and target_key not in resnet_state_dict:
272
+ resnet_state_dict[target_key] = value
273
+
274
  else:
275
+ # Check for other known structures
276
+ structure_found = False
277
+
278
+ # Check for 'model.' prefix
279
+ model_keys = [key for key in state_dict.keys() if key.startswith('model.')]
280
+ if model_keys:
281
+ print(f"Found 'model.' structure with {len(model_keys)} parameters")
282
+ for source_key, value in state_dict.items():
283
+ if source_key.startswith('model.'):
284
+ target_key = source_key[len('model.'):]
285
+ resnet_state_dict[target_key] = value
286
+ structure_found = True
287
+
288
+ # Check for ResNet parameters at the top level
289
+ top_level_resnet_keys = 0
290
+ for key in resnet_keys:
291
+ if key in state_dict:
292
+ top_level_resnet_keys += 1
293
+
294
+ if top_level_resnet_keys > 0:
295
+ print(f"Found {top_level_resnet_keys} ResNet parameters at top level")
296
+ for target_key in resnet_keys:
297
+ if target_key in state_dict:
298
+ resnet_state_dict[target_key] = state_dict[target_key]
299
+ structure_found = True
300
+
301
+ # If no structure was recognized, try the prefix mapping approach
302
+ if not structure_found:
303
+ print("No standard model structure found, trying prefix mappings...")
304
+ for target_key in resnet_keys:
305
+ for prefix in prefixes_to_try:
306
+ source_key = prefix + target_key
307
+ if source_key in state_dict:
308
+ resnet_state_dict[target_key] = state_dict[source_key]
309
+ break
310
 
311
+ # If we still can't find enough keys, try a final approach of removing prefixes
312
+ if len(resnet_state_dict) < len(resnet_keys):
313
+ print(f"Found only {len(resnet_state_dict)}/{len(resnet_keys)} parameters, trying prefix removal...")
314
+
315
+ # Track matches found through prefix removal
316
+ prefix_matches = {prefix: 0 for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']}
317
+ layer_matches = {} # Track matches by layer type
318
+
319
+ # Count parameter keys by layer type for analysis
320
+ for key in resnet_keys:
321
+ layer_name = key.split('.')[0] if '.' in key else key
322
+ if layer_name not in layer_matches:
323
+ layer_matches[layer_name] = {'total': 0, 'matched': 0}
324
+ layer_matches[layer_name]['total'] += 1
325
+
326
+ # Try keys with common prefixes
327
+ for source_key, value in state_dict.items():
328
+ # Skip if already found
329
+ target_key = source_key
330
+ matched_prefix = None
331
+
332
+ # Try removing various prefixes
333
+ for prefix in ['module.', 'model.', 'attacker.model.', 'attacker.']:
334
+ if source_key.startswith(prefix):
335
+ target_key = source_key[len(prefix):]
336
+ matched_prefix = prefix
337
+ break
338
+
339
+ # If the target key is in the ResNet keys, add it to the state dict
340
+ if target_key in resnet_keys and target_key not in resnet_state_dict:
341
+ resnet_state_dict[target_key] = value
342
+
343
+ # Update match statistics
344
+ if matched_prefix:
345
+ prefix_matches[matched_prefix] += 1
346
+
347
+ # Update layer matches
348
+ layer_name = target_key.split('.')[0] if '.' in target_key else target_key
349
+ if layer_name in layer_matches:
350
+ layer_matches[layer_name]['matched'] += 1
351
+
352
+ # Print detailed prefix removal statistics
353
+ print("\n=== Prefix Removal Statistics ===")
354
+ total_matches = sum(prefix_matches.values())
355
+ print(f"Total parameters matched through prefix removal: {total_matches}/{len(resnet_keys)} ({(total_matches/len(resnet_keys))*100:.1f}%)")
356
+
357
+ # Show matches by prefix
358
+ print("\nMatches by prefix:")
359
+ for prefix, count in sorted(prefix_matches.items(), key=lambda x: x[1], reverse=True):
360
+ if count > 0:
361
+ print(f" {prefix}: {count} parameters")
362
+
363
+ # Show matches by layer type
364
+ print("\nMatches by layer type:")
365
+ for layer, stats in sorted(layer_matches.items(), key=lambda x: x[1]['total'], reverse=True):
366
+ match_percent = (stats['matched'] / stats['total']) * 100 if stats['total'] > 0 else 0
367
+ print(f" {layer}: {stats['matched']}/{stats['total']} ({match_percent:.1f}%)")
368
+
369
+ # Check for specific important layers (conv1, layer1, etc.)
370
+ critical_layers = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']
371
+ print("\nStatus of critical layers:")
372
+ for layer in critical_layers:
373
+ if layer in layer_matches:
374
+ match_percent = (layer_matches[layer]['matched'] / layer_matches[layer]['total']) * 100
375
+ status = "✅ COMPLETE" if layer_matches[layer]['matched'] == layer_matches[layer]['total'] else "⚠️ INCOMPLETE"
376
+ print(f" {layer}: {layer_matches[layer]['matched']}/{layer_matches[layer]['total']} ({match_percent:.1f}%) - {status}")
377
+ else:
378
+ print(f" {layer}: Not found in model")
379
+
380
+ # Load the ResNet state dict
381
+ if resnet_state_dict:
382
+ try:
383
+ # Use strict=False to allow missing keys
384
+ result = resnet.load_state_dict(resnet_state_dict, strict=False)
385
+ missing_keys, unexpected_keys = result
386
+
387
+ # Generate detailed information with better formatting
388
+ loading_report = []
389
+ loading_report.append(f"\n===== MODEL LOADING REPORT: {model_type} =====")
390
+ loading_report.append(f"Total parameters in checkpoint: {len(resnet_state_dict):,}")
391
+ loading_report.append(f"Total parameters in model: {len(resnet.state_dict()):,}")
392
+ loading_report.append(f"Missing keys: {len(missing_keys):,} parameters")
393
+ loading_report.append(f"Unexpected keys: {len(unexpected_keys):,} parameters")
394
+
395
+ # Calculate percentage of parameters loaded
396
+ loaded_keys = set(resnet_state_dict.keys()) - set(unexpected_keys)
397
+ loaded_percent = (len(loaded_keys) / len(resnet.state_dict())) * 100
398
+
399
+ # Determine loading success status
400
+ if loaded_percent >= 99.5:
401
+ status = "✅ COMPLETE - All important parameters loaded"
402
+ elif loaded_percent >= 90:
403
+ status = "🟡 PARTIAL - Most parameters loaded, should still function"
404
+ elif loaded_percent >= 50:
405
+ status = "⚠️ INCOMPLETE - Many parameters missing, may not function properly"
406
+ else:
407
+ status = "❌ FAILED - Critical parameters missing, will not function properly"
408
+
409
+ loading_report.append(f"Successfully loaded: {len(loaded_keys):,} parameters ({loaded_percent:.1f}%)")
410
+ loading_report.append(f"Loading status: {status}")
411
+
412
+ # If loading is severely incomplete, fall back to PyTorch's pretrained model
413
+ if loaded_percent < 50:
414
+ loading_report.append("\n⚠️ WARNING: Loading from checkpoint is too incomplete.")
415
+ loading_report.append("⚠️ Falling back to PyTorch's pretrained model to avoid broken inference.")
416
+
417
+ # Create a new ResNet model with pretrained weights
418
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
419
+ model = nn.Sequential(self.normalizer, resnet)
420
+ loading_report.append("✅ Successfully loaded PyTorch's pretrained ResNet50 model")
421
+
422
+ # Show missing keys by layer type
423
+ if missing_keys:
424
+ loading_report.append("\nMissing keys by layer type:")
425
+ layer_types = {}
426
+ for key in missing_keys:
427
+ # Extract layer type (e.g., 'conv', 'bn', 'layer1', etc.)
428
+ parts = key.split('.')
429
+ if len(parts) > 0:
430
+ layer_type = parts[0]
431
+ if layer_type not in layer_types:
432
+ layer_types[layer_type] = 0
433
+ layer_types[layer_type] += 1
434
+
435
+ # Add counts by layer type
436
+ for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
437
+ loading_report.append(f" {layer_type}: {count:,} parameters")
438
+
439
+ loading_report.append("\nFirst 10 missing keys:")
440
+ for i, key in enumerate(sorted(missing_keys)[:10]):
441
+ loading_report.append(f" {i+1}. {key}")
442
+
443
+ # Show unexpected keys if any
444
+ if unexpected_keys:
445
+ loading_report.append("\nFirst 10 unexpected keys:")
446
+ for i, key in enumerate(sorted(unexpected_keys)[:10]):
447
+ loading_report.append(f" {i+1}. {key}")
448
+
449
+ loading_report.append("========================================")
450
+
451
+ # Convert report to string and print it
452
+ report_text = "\n".join(loading_report)
453
+ print(report_text)
454
+
455
+ # Also save to a file for reference
456
+ os.makedirs("logs", exist_ok=True)
457
+ with open(f"logs/model_loading_{model_type}.log", "w") as f:
458
+ f.write(report_text)
459
+
460
+ # Look for normalizer parameters as well
461
+ if any(key.startswith('attacker.normalize.') for key in state_dict.keys()):
462
+ norm_state_dict = {}
463
+ for key, value in state_dict.items():
464
+ if key.startswith('attacker.normalize.'):
465
+ norm_key = key[len('attacker.normalize.'):]
466
+ norm_state_dict[norm_key] = value
467
+
468
+ if norm_state_dict:
469
+ try:
470
+ self.normalizer.load_state_dict(norm_state_dict, strict=False)
471
+ print("Successfully loaded normalizer parameters")
472
+ except Exception as e:
473
+ print(f"Warning: Could not load normalizer parameters: {e}")
474
+ except Exception as e:
475
+ print(f"Warning: Error loading ResNet parameters: {e}")
476
+ # Fall back to loading without normalizer
477
+ model = resnet # Use just the ResNet model without normalizer
478
+ except Exception as e:
479
+ print(f"Error loading model checkpoint: {e}")
480
+ # Fallback to PyTorch's pretrained model
481
+ print("Falling back to PyTorch's pretrained model")
482
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
483
+ model = nn.Sequential(self.normalizer, resnet)
484
  else:
485
  # Fallback to PyTorch's pretrained model
486
+ print("No checkpoint available, using PyTorch's pretrained model")
487
+ resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
488
+ model = nn.Sequential(self.normalizer, resnet)
489
 
490
  model = model.to(device)
491
  model.eval() # Set to evaluation mode
492
 
493
+ # Verify model integrity
494
+ self.verify_model_integrity(model, model_type)
495
+
496
  # Store the model for future use
497
  self.models[model_type] = model
498
  return model
499
+
500
  def inference(self, image, model_type, config):
501
+ """Run generative inference on the image."""
502
  # Load model if not already loaded
503
  model = self.load_model(model_type)
504
 
 
513
  image_tensor = transform(image).unsqueeze(0).to(device)
514
  image_tensor.requires_grad = True
515
 
516
+ # Check model structure
517
+ is_sequential = isinstance(model, nn.Sequential)
518
 
519
  # Get original predictions
520
  with torch.no_grad():
521
+ # If the model is sequential with a normalizer, skip the normalization step
522
+ if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
523
+ print("Model is sequential with normalization")
524
+ output_original = model(image_tensor) # Model includes normalization
525
+ # Get the core model part (typically at index 1 in Sequential)
526
+ core_model = model[1]
527
+ else:
528
+ print("Model is not sequential with normalization")
529
+ # Use manual normalization for non-sequential models
530
+ normalized_tensor = normalize_transform(image_tensor)
531
+ output_original = model(normalized_tensor)
532
+ core_model = model
533
+
534
  probs_orig = F.softmax(output_original, dim=1)
535
  conf_orig, classes_orig = torch.max(probs_orig, 1)
536
 
 
541
  infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
542
 
543
  # Storage for inference steps
544
+ # Create a new tensor that requires gradients
545
+ x = image_tensor.clone().detach().requires_grad_(True)
546
  all_steps = [image_tensor[0].detach().cpu()]
547
 
548
  # Main inference loop
 
550
  # Reset gradients
551
  x.grad = None
552
 
 
 
 
553
  # Forward pass
554
+ if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
555
+ output = model(x) # Model includes normalization
556
+ else:
557
+ # Use manual normalization for non-sequential models
558
+ normalized_x = normalize_transform(x)
559
+ output = model(normalized_x)
560
 
561
  # Calculate loss to maximize confidence for least confident classes
562
+ try:
563
+ # Get the least confident classes
564
+ num_classes = min(10, least_confident_classes.size(1))
565
+ target_classes = least_confident_classes[0, :num_classes]
566
+
567
+ # Create a combined loss (avoid accumulating in a loop)
568
+ targets = torch.tensor([idx.item() for idx in target_classes], device=device)
569
+
570
+ # Method 1: Use a single combined loss
571
+ loss = 0
572
+ for target in targets:
573
+ # Create one-hot target
574
+ one_hot = torch.zeros_like(output)
575
+ one_hot[0, target] = 1
576
+ # Use negative loss to maximize confidence
577
+ loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
578
+
579
+ # Method 2: Try direct gradient calculation
580
+ # Instead of loss.backward(), which might be failing
581
+ grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
582
+
583
+ if grad is None:
584
+ print("Warning: Direct gradient calculation failed")
585
+ # Fall back to random perturbation
586
+ random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
587
+ x = x + random_noise
588
+ else:
589
+ # Update image with gradient
590
+ step = infer_step.step(x, grad)
591
+ x = x + step
592
+
593
+ x = infer_step.project(x)
594
+
595
+ except Exception as e:
596
+ print(f"Error in gradient calculation: {e}")
597
+ # Fall back to random perturbation
598
+ random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
599
+ x = x + random_noise
600
  x = infer_step.project(x)
601
 
602
  # Store step if in iterations_to_show
603
  if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
604
  all_steps.append(x[0].detach().cpu())
605
+
606
+ # Print some info about the inference
607
+ with torch.no_grad():
608
+ if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
609
+ final_output = model(x)
610
+ else:
611
+ normalized_x = normalize_transform(x)
612
+ final_output = model(normalized_x)
613
 
614
+ final_probs = F.softmax(final_output, dim=1)
615
+ final_conf, final_classes = torch.max(final_probs, 1)
616
+ print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
617
+ print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
618
+
619
  # Return final image and all stored steps
620
  return x[0].detach().cpu(), all_steps
621