File size: 4,634 Bytes
f7dbbd9
7449d44
 
 
 
 
 
f7dbbd9
7449d44
 
 
 
f7dbbd9
7449d44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
import numpy as np
from PIL import Image
import os
import argparse
from inference import GenerativeInferenceModel, get_inference_configs

# Parse command line arguments.
# parse_known_args() rather than parse_args(): hosts that inject their own
# argv (e.g. a Jupyter kernel's "-f <connection file>") would otherwise make
# argparse exit with "unrecognized arguments" as soon as this module is
# imported. Unrecognized flags are reported instead of silently dropped.
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
args, _unrecognized_args = parser.parse_known_args()
if _unrecognized_args:
    print(f"Ignoring unrecognized arguments: {' '.join(_unrecognized_args)}")

# Create the "models" and "stimuli" working directories if they don't exist.
# NOTE(review): presumably "models" holds model weights and "stimuli" holds the
# example images referenced by the UI below — confirm.
os.makedirs("models", exist_ok=True)
os.makedirs("stimuli", exist_ok=True)

# Initialize the model once at import time; run_inference uses this single
# module-level instance for every request.
model = GenerativeInferenceModel()

def run_inference(image, model_type, illusion_type, eps_value, num_iterations):
    """Run generative inference on ``image`` and collect per-step frames for display.

    Args:
        image: Input from the ``gr.Image(type="pil")`` component; ``None`` when
            the user has not uploaded or selected an image.
        model_type: Model name forwarded to ``model.inference``
            (the UI offers ``"robust_resnet50"`` and ``"standard_resnet50"``).
        illusion_type: Illusion label chosen in the UI. Not used by the
            inference call; accepted so the handler matches the UI inputs.
        eps_value: Perturbation size; converted to ``float``.
        num_iterations: Number of inference iterations; converted to ``int``.

    Returns:
        ``(final_output, frames)``: the first value returned by
        ``model.inference`` (shown in the final-image component) and a list of
        PIL images, one per intermediate step, for the gallery.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a traceback from deep
        # inside the model code.
        raise gr.Error("Please upload an image or select an example first.")

    # Build the inference configuration from the UI slider values.
    config = get_inference_configs(eps=float(eps_value), n_itr=int(num_iterations))

    # Run generative inference
    output_images, all_steps = model.inference(image, model_type, config)

    # Convert every intermediate step into a displayable frame.
    frames = [_step_to_pil(step_image) for step_image in all_steps]

    # Return the plain list instead of gr.Gallery.update(value=...): the
    # Component.update helpers were removed in Gradio 4, while a list of
    # images is accepted as a Gallery output in both Gradio 3.x and 4.x.
    return output_images, frames


def _step_to_pil(step_image):
    """Convert one intermediate step tensor into an 8-bit RGB PIL image."""
    # NOTE(review): assumes step_image is a (C, H, W) float tensor with values
    # nominally in [0, 1] (implied by permute(1, 2, 0) and * 255) — confirm.
    # detach(): the step may still be attached to the autograd graph, and
    # .numpy() refuses tensors that require grad.
    # clamp(0, 1): out-of-range values would otherwise wrap around when cast
    # to uint8, producing speckled artifacts.
    hwc = step_image.detach().permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    return Image.fromarray((hwc * 255).astype(np.uint8))

# Define the interface.
# Components are created inside nested gr.Row / gr.Column context managers;
# their creation order and nesting here determine the on-page layout, so
# statements in this block should not be reordered casually.
with gr.Blocks(title="Generative Inference Demo") as demo:
    gr.Markdown("# Generative Inference Demo")
    gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")
    
    with gr.Row():
        # Left column (scale=1): all user inputs plus the run button.
        with gr.Column(scale=1):
            # Inputs
            # type="pil": the click handler receives a PIL image rather than a numpy array.
            image_input = gr.Image(label="Upload Image or Select an Illusion", type="pil")
            
            with gr.Row():
                # Model name forwarded to run_inference as model_type.
                model_choice = gr.Dropdown(
                    choices=["robust_resnet50", "standard_resnet50"], 
                    value="robust_resnet50", 
                    label="Model"
                )
                
                # NOTE: run_inference accepts this value but does not pass it to
                # the model, so it is informational only.
                illusion_type = gr.Dropdown(
                    choices=["Kanizsa", "Face-Vase", "Neon-Color", "Figure-Ground"], 
                    value="Kanizsa", 
                    label="Illusion Type"
                )
            
            with gr.Row():
                # Perturbation budget and iteration count, fed to get_inference_configs.
                eps_slider = gr.Slider(minimum=0.01, maximum=3.0, value=0.5, step=0.01, label="Epsilon (Perturbation Size)")
                iterations_slider = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Number of Iterations")
            
            run_button = gr.Button("Run Inference")
            
        # Right column (scale=2, i.e. twice the width of the input column): results.
        with gr.Column(scale=2):
            # Outputs
            output_image = gr.Image(label="Final Inferred Image")
            # One gallery tile per intermediate inference step.
            output_frames = gr.Gallery(label="Inference Steps", columns=4, rows=2)
    
    # Set up example images
    # Each row's columns line up positionally with the `inputs` list passed to
    # gr.Examples below: image path, model, illusion type, epsilon, iterations.
    # The image files must already exist under stimuli/ — the os.makedirs call
    # at module level only creates the directory.
    examples = [
        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "Kanizsa", 0.5, 50],
        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "Face-Vase", 0.5, 50],
        [os.path.join("stimuli", "figure_ground.png"), "robust_resnet50", "Figure-Ground", 0.7, 100],
        [os.path.join("stimuli", "NeonColorSaeedi.jpg"), "robust_resnet50", "Neon-Color", 0.3, 80]
    ]
    
    # No fn/outputs are given, so selecting an example only fills in the input
    # components; the user still has to click "Run Inference".
    gr.Examples(examples=examples, inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider])
    
    # Set up event handler
    # The inputs are passed positionally, in the order of run_inference's
    # parameters; its two return values fill output_image and output_frames.
    run_button.click(
        fn=run_inference,
        inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider],
        outputs=[output_image, output_frames]
    )
    
    # Include a description of the technique
    # NOTE(review): the text inside this string is indented by 4 spaces; this
    # relies on gr.Markdown dedenting it (otherwise Markdown would render it
    # as a code block) — confirm against the installed Gradio version.
    gr.Markdown("""
    ## About Generative Inference
    
    Generative inference is a technique that reveals how neural networks perceive visual stimuli by optimizing the input
    to increase the network's confidence in its predictions. This process can reveal emergent perception of contours,
    figure-ground separation, and other visual phenomena similar to human perception.
    
    This demo allows you to:
    1. Upload your own images or select from example illusions
    2. Choose between robust or standard models
    3. Adjust parameters like perturbation size (epsilon) and number of iterations
    4. Visualize how the perception emerges over time
    """)

# Script entry point: start the Gradio server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # accept connections on every network interface
        "server_port": args.port,  # value of --port (default 7860)
        "share": False,            # do not request a public share link
        "debug": True,
    }
    print(f"Starting server on port {args.port}")
    demo.launch(**launch_options)