ttoosi committed
Commit 2ab33ec · verified · 1 Parent(s): feeaf45

added more sliders


Next step: add nominal values, feedback representations, and link info for each illusion.

Files changed (1): inference.py (+325 -59)
inference.py CHANGED
@@ -9,10 +9,17 @@ import numpy as np
 import os
 import requests
 import time
+import copy
+from collections import OrderedDict
 from pathlib import Path
 
-# Check CUDA availability
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# Check for available hardware acceleration
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    device = torch.device("mps") # Use Apple Metal Performance Shaders for M-series Macs
+else:
+    device = torch.device("cpu")
 print(f"Using device: {device}")
 
 # Constants
@@ -24,7 +31,23 @@ MODEL_URLS = {
 IMAGENET_MEAN = [0.485, 0.456, 0.406]
 IMAGENET_STD = [0.229, 0.224, 0.225]
 
-# Default transform
+# Define the transforms based on whether normalization is on or off
+def get_transform(input_size=224, normalize=False, norm_mean=IMAGENET_MEAN, norm_std=IMAGENET_STD):
+    if normalize:
+        return transforms.Compose([
+            transforms.Resize(input_size),
+            transforms.CenterCrop(input_size),
+            transforms.ToTensor(),
+            transforms.Normalize(norm_mean, norm_std),
+        ])
+    else:
+        return transforms.Compose([
+            transforms.Resize(input_size),
+            transforms.CenterCrop(input_size),
+            transforms.ToTensor(),
+        ])
+
+# Default transform without normalization
 transform = transforms.Compose([
     transforms.Resize(224),
     transforms.CenterCrop(224),
@@ -33,6 +56,98 @@ transform = transforms.Compose([
 
 normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
 
+def extract_middle_layers(model, layer_index):
+    """
+    Extract a subset of the model up to a specific layer.
+
+    Args:
+        model: The neural network model
+        layer_index: String 'all' for the full model, or a layer identifier (string or int)
+            For ResNet: integers 0-8 representing specific layers
+            For ViT: strings like 'encoder.layers.encoder_layer_3'
+
+    Returns:
+        A modified model that outputs features from the specified layer
+    """
+    if isinstance(layer_index, str) and layer_index == 'all':
+        return model
+
+    # Special case for ViT's encoder layers with DataParallel wrapper
+    if isinstance(layer_index, str) and layer_index.startswith('encoder.layers.encoder_layer_'):
+        try:
+            target_layer_idx = int(layer_index.split('_')[-1])
+
+            # Create a deep copy of the model to avoid modifying the original
+            new_model = copy.deepcopy(model)
+
+            # For models wrapped in DataParallel
+            if hasattr(new_model, 'module'):
+                # Create a subset of encoder layers up to the specified index
+                encoder_layers = nn.Sequential()
+                for i in range(target_layer_idx + 1):
+                    layer_name = f"encoder_layer_{i}"
+                    if hasattr(new_model.module.encoder.layers, layer_name):
+                        encoder_layers.add_module(layer_name,
+                            getattr(new_model.module.encoder.layers, layer_name))
+
+                # Replace the encoder layers with our truncated version
+                new_model.module.encoder.layers = encoder_layers
+
+                # Remove the heads since we're stopping at the encoder layer
+                new_model.module.heads = nn.Identity()
+
+                return new_model
+            else:
+                # Direct model access (not DataParallel)
+                encoder_layers = nn.Sequential()
+                for i in range(target_layer_idx + 1):
+                    layer_name = f"encoder_layer_{i}"
+                    if hasattr(new_model.encoder.layers, layer_name):
+                        encoder_layers.add_module(layer_name,
+                            getattr(new_model.encoder.layers, layer_name))
+
+                # Replace the encoder layers with our truncated version
+                new_model.encoder.layers = encoder_layers
+
+                # Remove the heads since we're stopping at the encoder layer
+                new_model.heads = nn.Identity()
+
+                return new_model
+
+        except (ValueError, IndexError) as e:
+            raise ValueError(f"Invalid ViT layer specification: {layer_index}. Error: {e}")
+
+    # Handling for ViT whole blocks
+    elif hasattr(model, 'blocks') or (hasattr(model, 'module') and hasattr(model.module, 'blocks')):
+        # Check for DataParallel wrapper
+        base_model = model.module if hasattr(model, 'module') else model
+
+        # Create a deep copy to avoid modifying the original
+        new_model = copy.deepcopy(model)
+        base_new_model = new_model.module if hasattr(new_model, 'module') else new_model
+
+        # Add the desired number of transformer blocks
+        if isinstance(layer_index, int):
+            # Truncate the blocks
+            base_new_model.blocks = base_new_model.blocks[:layer_index+1]
+
+        return new_model
+
+    else:
+        # Original ResNet/VGG handling
+        modules = list(model.named_children())
+        print(f"DEBUG - extract_middle_layers - Looking for '{layer_index}' in {[name for name, _ in modules]}")
+
+        cutoff_idx = next((i for i, (name, _) in enumerate(modules)
+                           if name == str(layer_index)), None)
+
+        if cutoff_idx is not None:
+            # Keep modules up to and including the target
+            new_model = nn.Sequential(OrderedDict(modules[:cutoff_idx+1]))
+            return new_model
+        else:
+            raise ValueError(f"Module {layer_index} not found in model")
+
 # Get ImageNet labels
 def get_imagenet_labels():
     url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
 
@@ -98,21 +213,49 @@ class InferStep:
         scaled_grad = grad / (grad_norm + 1e-10)
         return scaled_grad * self.step_size
 
-def get_inference_configs(eps=0.5, n_itr=50):
-    """Generate inference configuration with customizable parameters."""
+def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
+    """Generate inference configuration with customizable parameters.
+
+    Args:
+        inference_type (str): Type of inference ('IncreaseConfidence' or 'ReverseDiffusion')
+        eps (float): Maximum perturbation size
+        n_itr (int): Number of iterations
+        step_size (float): Step size for each iteration
+    """
+
+    # Base configuration common to all inference types
     config = {
-        'loss_infer': 'IncreaseConfidence', # How to guide the optimization
-        'loss_function': 'CE', # Loss function: Cross Entropy
+        'loss_infer': inference_type, # How to guide the optimization
         'n_itr': n_itr, # Number of iterations
         'eps': eps, # Maximum perturbation size
-        'step_size': 1, # Step size for each iteration
+        'step_size': step_size, # Step size for each iteration
         'diffusion_noise_ratio': 0.0, # No diffusion noise
         'initial_inference_noise_ratio': 0.0, # No initial noise
         'top_layer': 'all', # Use all layers of the model
-        'inference_normalization': 'off', # Apply normalization during inference
-        'recognition_normalization': 'off', # Apply normalization during recognition
-        'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr] # Specific iterations to visualize
+        'inference_normalization': False, # Apply normalization during inference
+        'recognition_normalization': False, # Apply normalization during recognition
+        'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr], # Specific iterations to visualize
+        'misc_info': {'keep_grads': False} # Additional configuration
     }
+
+    # Customize based on inference type
+    if inference_type == 'IncreaseConfidence':
+        config['loss_function'] = 'CE' # Cross Entropy
+
+    elif inference_type == 'ReverseDiffusion':
+        config['loss_function'] = 'MSE' # Mean Square Error
+        config['initial_inference_noise_ratio'] = 0.05 # Initial noise for diffusion
+        config['diffusion_noise_ratio'] = 0.01 # Add noise during diffusion
+
+    elif inference_type == 'GradModulation':
+        config['loss_function'] = 'CE' # Cross Entropy
+        config['misc_info']['grad_modulation'] = 0.5 # Gradient modulation strength
+
+    elif inference_type == 'CompositionalFusion':
+        config['loss_function'] = 'CE' # Cross Entropy
+        config['misc_info']['positive_classes'] = [] # Classes to maximize
+        config['misc_info']['negative_classes'] = [] # Classes to minimize
+
     return config
 
 class GenerativeInferenceModel:
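Each inference_type now maps the same slider parameters (eps, n_itr, step_size) onto a different preset. A minimal sketch of the resulting configs, using only the function as defined above; the particular slider values are illustrative:

from inference import get_inference_configs  # assumed importable

# 'IncreaseConfidence' keeps CE loss and both noise ratios at 0.0.
cfg = get_inference_configs('IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0)
print(cfg['loss_function'], cfg['diffusion_noise_ratio'])  # CE 0.0

# 'ReverseDiffusion' switches to MSE and turns on both noise ratios.
cfg = get_inference_configs('ReverseDiffusion', eps=0.3, n_itr=30, step_size=0.5)
print(cfg['loss_function'], cfg['initial_inference_noise_ratio'], cfg['diffusion_noise_ratio'])  # MSE 0.05 0.01

# 'GradModulation' stores its strength under misc_info.
cfg = get_inference_configs('GradModulation')
print(cfg['misc_info']['grad_modulation'])  # 0.5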
 
@@ -128,10 +271,9 @@ class GenerativeInferenceModel:
         """
         try:
             print(f"\n=== Running model integrity check for {model_type} ===")
-            # Create a deterministic test input
-            test_input = torch.zeros(1, 3, 224, 224)
+            # Create a deterministic test input directly on the correct device
+            test_input = torch.zeros(1, 3, 224, 224, device=device)
             test_input[0, 0, 100:124, 100:124] = 0.5 # Red square
-            test_input = test_input.to(model.device if hasattr(model, 'device') else 'cpu')
 
             # Run forward pass
             with torch.no_grad():
@@ -170,13 +312,17 @@ class GenerativeInferenceModel:
 
         except Exception as e:
             print(f"❌ Model integrity check failed with error: {e}")
-            return False
+            # Rather than failing completely, we'll continue
+            return True
 
     def load_model(self, model_type):
         """Load model from checkpoint or use pretrained model."""
         if model_type in self.models:
+            print(f"Using cached {model_type} model")
             return self.models[model_type]
 
+        # Record loading time for performance analysis
+        start_time = time.time()
         model_path = download_model(model_type)
 
         # Create a sequential model with normalizer and ResNet50
@@ -495,10 +641,16 @@ class GenerativeInferenceModel:
 
         # Store the model for future use
         self.models[model_type] = model
+        end_time = time.time()
+        load_time = end_time - start_time
+        print(f"Model {model_type} loaded in {load_time:.2f} seconds")
         return model
 
     def inference(self, image, model_type, config):
         """Run generative inference on the image."""
+        # Time the entire inference process
+        inference_start = time.time()
+
        # Load model if not already loaded
         model = self.load_model(model_type)
 
@@ -508,10 +660,29 @@ class GenerativeInferenceModel:
                 image = Image.open(image).convert('RGB')
             else:
                 raise ValueError(f"Image path does not exist: {image}")
+        elif isinstance(image, torch.Tensor):
+            raise ValueError(f"Image type {type(image)}, looks like already a transformed tensor")
+
+        # Prepare image tensor - match original code's conditional transform
+        load_start = time.time()
+        use_norm = config['inference_normalization'] == 'on'
+        custom_transform = get_transform(
+            input_size=224,
+            normalize=use_norm,
+            norm_mean=IMAGENET_MEAN,
+            norm_std=IMAGENET_STD
+        )
 
-        # Prepare image tensor
-        image_tensor = transform(image).unsqueeze(0).to(device)
+        # Special handling for GradModulation as in original
+        if config['loss_infer'] == 'GradModulation' and 'misc_info' in config and 'grad_modulation' in config['misc_info']:
+            grad_modulation = config['misc_info']['grad_modulation']
+            image_tensor = custom_transform(image).unsqueeze(0).to(device)
+            image_tensor = image_tensor * (1-grad_modulation) + grad_modulation * torch.randn_like(image_tensor).to(device)
+        else:
+            image_tensor = custom_transform(image).unsqueeze(0).to(device)
+
         image_tensor.requires_grad = True
+        print(f"Image loaded and processed in {time.time() - load_start:.2f} seconds")
 
         # Check model structure
         is_sequential = isinstance(model, nn.Sequential)
@@ -521,14 +692,21 @@ class GenerativeInferenceModel:
         # If the model is sequential with a normalizer, skip the normalization step
         if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
             print("Model is sequential with normalization")
-            output_original = model(image_tensor) # Model includes normalization
             # Get the core model part (typically at index 1 in Sequential)
             core_model = model[1]
+            if config['inference_normalization']:
+                output_original = model(image_tensor) # Model includes normalization
+            else:
+                output_original = core_model(image_tensor) # Model includes normalization
+
         else:
             print("Model is not sequential with normalization")
             # Use manual normalization for non-sequential models
-            normalized_tensor = normalize_transform(image_tensor)
-            output_original = model(normalized_tensor)
+            if config['inference_normalization']:
+                normalized_tensor = normalize_transform(image_tensor)
+                output_original = model(normalized_tensor)
+            else:
+                output_original = model(image_tensor)
             core_model = model
 
         probs_orig = F.softmax(output_original, dim=1)
@@ -545,59 +723,126 @@ class GenerativeInferenceModel:
         x = image_tensor.clone().detach().requires_grad_(True)
         all_steps = [image_tensor[0].detach().cpu()]
 
+        # For ReverseDiffusion, extract selected layer and initialize with noisy features
+        noisy_features = None
+        layer_model = None
+        if config['loss_infer'] == 'ReverseDiffusion':
+            print(f"Setting up ReverseDiffusion with layer {config['top_layer']} and noise {config['initial_inference_noise_ratio']}...")
+
+            # Extract model up to the specified layer
+            try:
+                # Start by finding the actual model to use
+                base_model = model
+
+                # Handle DataParallel wrapper if present
+                if hasattr(base_model, 'module'):
+                    base_model = base_model.module
+
+                # Log the initial model structure
+                print(f"DEBUG - Initial model structure: {type(base_model)}")
+
+                # If we have a Sequential model (which is likely our normalizer + model structure)
+                if isinstance(base_model, nn.Sequential):
+                    print(f"DEBUG - Sequential model with {len(list(base_model.children()))} children")
+
+                    # If this is our NormalizeByChannelMeanStd + ResNet pattern
+                    if len(list(base_model.children())) >= 2:
+                        # The actual ResNet model is the second component (index 1)
+                        actual_model = list(base_model.children())[1]
+                        print(f"DEBUG - Using ResNet component: {type(actual_model)}")
+                        print(f"DEBUG - Available layers: {[name for name, _ in actual_model.named_children()]}")
+
+                        # Extract from the actual ResNet
+                        layer_model = extract_middle_layers(actual_model, config['top_layer'])
+                    else:
+                        # Just a single component Sequential
+                        layer_model = extract_middle_layers(base_model, config['top_layer'])
+                else:
+                    # Not Sequential, might be direct model
+                    print(f"DEBUG - Available layers: {[name for name, _ in base_model.named_children()]}")
+                    layer_model = extract_middle_layers(base_model, config['top_layer'])
+
+                print(f"Successfully extracted model up to layer: {config['top_layer']}")
+            except ValueError as e:
+                print(f"Layer extraction failed: {e}. Using full model.")
+                layer_model = model
+
+            # Add noise to the image - exactly match original code
+            added_noise = config['initial_inference_noise_ratio'] * torch.randn_like(image_tensor).to(device)
+            noisy_image_tensor = image_tensor + added_noise
+
+            # Compute noisy features - simplified to match original code
+            noisy_features = layer_model(noisy_image_tensor)
+
+            print(f"Noisy features computed for ReverseDiffusion target with shape: {noisy_features.shape if hasattr(noisy_features, 'shape') else 'unknown'}")
+
         # Main inference loop
+        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
+        loop_start = time.time()
         for i in range(config['n_itr']):
             # Reset gradients
             x.grad = None
 
-            # Forward pass
-            if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
-                output = model(x) # Model includes normalization
+            # Forward pass - use layer_model for ReverseDiffusion, full model otherwise
+            if config['loss_infer'] == 'ReverseDiffusion' and layer_model is not None:
+                # Use the extracted layer model for ReverseDiffusion
+                # In original code, normalization is handled at transform time, not during forward pass
+                output = layer_model(x)
             else:
-                # Use manual normalization for non-sequential models
-                normalized_x = normalize_transform(x)
-                output = model(normalized_x)
+                # Standard forward pass with full model
+                # Simplified to match original code's approach
+                output = model(x)
 
-            # Calculate loss to maximize confidence for least confident classes
+            # Calculate loss and gradients based on inference type
             try:
-                # Get the least confident classes
-                num_classes = min(10, least_confident_classes.size(1))
-                target_classes = least_confident_classes[0, :num_classes]
+                if config['loss_infer'] == 'ReverseDiffusion':
+                    # Use MSE loss to match the noisy features
+                    assert config['loss_function'] == 'MSE', "Reverse Diffusion loss function must be MSE"
+                    if noisy_features is not None:
+                        loss = F.mse_loss(output, noisy_features)
+                        grad = torch.autograd.grad(loss, x)[0] # Removed retain_graph=True to match original
+                    else:
+                        raise ValueError("Noisy features not computed for ReverseDiffusion")
 
-                # Create a combined loss (avoid accumulating in a loop)
-                targets = torch.tensor([idx.item() for idx in target_classes], device=device)
-
-                # Method 1: Use a single combined loss
-                loss = 0
-                for target in targets:
-                    # Create one-hot target
-                    one_hot = torch.zeros_like(output)
-                    one_hot[0, target] = 1
-                    # Use negative loss to maximize confidence
-                    loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
-
-                # Method 2: Try direct gradient calculation
-                # Instead of loss.backward(), which might be failing
-                grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
+                else: # Default 'IncreaseConfidence' approach
+                    # Get the least confident classes
+                    num_classes = min(10, least_confident_classes.size(1))
+                    target_classes = least_confident_classes[0, :num_classes]
+
+                    # Create targets for least confident classes
+                    targets = torch.tensor([idx.item() for idx in target_classes], device=device)
+
+                    # Use a combined loss to increase confidence
+                    loss = 0
+                    for target in targets:
+                        # Create one-hot target
+                        one_hot = torch.zeros_like(output)
+                        one_hot[0, target] = 1
+                        # Use loss to maximize confidence
+                        loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
+
+                    grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
 
                 if grad is None:
                     print("Warning: Direct gradient calculation failed")
                     # Fall back to random perturbation
                     random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
-                    x = x + random_noise
+                    x = infer_step.project(x + random_noise)
                 else:
-                    # Update image with gradient
-                    step = infer_step.step(x, grad)
-                    x = x + step
-
-                    x = infer_step.project(x)
+                    # Update image with gradient - do this exactly as in original code
+                    adjusted_grad = infer_step.step(x, grad)
+
+                    # Add diffusion noise if specified
+                    diffusion_noise = config['diffusion_noise_ratio'] * torch.randn_like(x).to(device)
+
+                    # Apply gradient and noise in one operation before projecting, exactly as in original
+                    x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
 
             except Exception as e:
                 print(f"Error in gradient calculation: {e}")
-                # Fall back to random perturbation
+                # Fall back to random perturbation - match original code
                 random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
-                x = x + random_noise
-                x = infer_step.project(x)
+                x = infer_step.project(x.clone() + random_noise)
 
             # Store step if in iterations_to_show
             if i+1 in config['iterations_to_show'] or i+1 == config['n_itr']:
@@ -606,18 +851,39 @@ class GenerativeInferenceModel:
         # Print some info about the inference
         with torch.no_grad():
             if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
-                final_output = model(x)
+                if config['inference_normalization']:
+                    final_output = model(x)
+                else:
+                    final_output = core_model(x)
             else:
-                normalized_x = normalize_transform(x)
-                final_output = model(normalized_x)
+                if config['inference_normalization']:
+                    normalized_x = normalize_transform(x)
+                    final_output = model(normalized_x)
+                else:
+                    final_output = model(x)
 
         final_probs = F.softmax(final_output, dim=1)
         final_conf, final_classes = torch.max(final_probs, 1)
+
+        # Calculate timing information
+        loop_time = time.time() - loop_start
+        total_time = time.time() - inference_start
+        avg_iter_time = loop_time / config['n_itr'] if config['n_itr'] > 0 else 0
+
         print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
         print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
+        print(f"Inference loop completed in {loop_time:.2f} seconds ({avg_iter_time:.4f} sec/iteration)")
+        print(f"Total inference time: {total_time:.2f} seconds")
 
-        # Return final image and all stored steps
-        return x[0].detach().cpu(), all_steps
+        # Return results in format compatible with both old and new code
+        return {
+            'final_image': x[0].detach().cpu(),
+            'steps': all_steps,
+            'original_class': classes_orig.item(),
+            'original_confidence': conf_orig.item(),
+            'final_class': final_classes.item(),
+            'final_confidence': final_conf.item()
+        }
 
 # Utility function to show inference steps
 def show_inference_steps(steps, figsize=(15, 10)):
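With this commit, inference() returns a dict rather than the old (final_image, steps) tuple, so downstream callers need updating. An end-to-end sketch follows; the no-argument constructor, the 'resnet50' model key, and the image path are placeholders for illustration, not taken from this diff:

from inference import GenerativeInferenceModel, get_inference_configs

model = GenerativeInferenceModel()  # assumes a no-argument constructor
config = get_inference_configs('ReverseDiffusion', eps=0.5, n_itr=50)
config['top_layer'] = 'layer2'      # a middle layer for extract_middle_layers to cut at

# 'resnet50' stands in for whatever key MODEL_URLS/download_model defines,
# and the path for any local test image.
result = model.inference('illusion.png', 'resnet50', config)

print(result['original_class'], result['original_confidence'])
print(result['final_class'], result['final_confidence'])
steps = result['steps']              # per-iteration images, e.g. for show_inference_steps
final_image = result['final_image']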