import gradio as gr
import subprocess
import os
import tempfile


def generate_caption(image, epsilon, sparsity, attack_algo, num_iters):
    """
    Generate captions for the uploaded image using the model in RobustMMFMEnv.

    Args:
        image: The uploaded image from Gradio.
        epsilon: Maximum perturbation magnitude on the 0-255 scale.
        sparsity: L1 budget of the perturbation (used by SAIF only).
        attack_algo: Adversarial attack algorithm ("APGD" or "SAIF").
        num_iters: Number of attack iterations.

    Returns:
        tuple: (original_caption, adversarial_caption, original_image,
                adversarial_image, perturbation_image)
    """
    if image is None:
        return "Please upload an image first.", "", None, None, None

    try:
        # Save the uploaded image to a temporary file
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.jpg', delete=False) as tmp_file:
            tmp_image_path = tmp_file.name

        # Save the image
        from PIL import Image
        import numpy as np

        if isinstance(image, np.ndarray):
            img = Image.fromarray(image)
            img.save(tmp_image_path)
        else:
            image.save(tmp_image_path)

        # Prepare the command to run in RobustMMFMEnv.
        # This is a placeholder - you'll need to create the actual script
        # (its expected output contract is sketched at the end of this file).
        conda_env = "RobustMMFMEnv"
        script_path = os.path.join(os.path.dirname(__file__), "run_caption.py")

        # Run the caption generation script in the RobustMMFMEnv conda environment
        cmd = [
            "conda", "run", "-n", conda_env,
            "python", script_path,
            "--image_path", tmp_image_path,
            "--epsilon", str(epsilon),
            "--num_iters", str(num_iters),
            "--sparsity", str(sparsity),
            "--attack_algo", attack_algo
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=60  # 60 seconds timeout
        )

        # Clean up the temporary input image
        os.unlink(tmp_image_path)

        if result.returncode == 0:
            output = result.stdout.strip()
            try:
                # The script prints a Python dict literal; parse it safely
                import ast
                result_dict = ast.literal_eval(output)

                original = result_dict.get('original_caption', '').strip()
                adversarial = result_dict.get('adversarial_caption', '').strip()
                orig_img_path = result_dict.get('original_image_path')
                adv_img_path = result_dict.get('adversarial_image_path')
                pert_img_path = result_dict.get('perturbation_image_path')

                orig_image = None
                adv_image = None
                pert_image = None

                # Load each returned image (if present), then delete its temp file
                if orig_img_path and os.path.exists(orig_img_path):
                    orig_image = np.array(Image.open(orig_img_path))
                    try:
                        os.unlink(orig_img_path)
                    except OSError:
                        pass

                if adv_img_path and os.path.exists(adv_img_path):
                    adv_image = np.array(Image.open(adv_img_path))
                    try:
                        os.unlink(adv_img_path)
                    except OSError:
                        pass

                if pert_img_path and os.path.exists(pert_img_path):
                    pert_image = np.array(Image.open(pert_img_path))
                    try:
                        os.unlink(pert_img_path)
                    except OSError:
                        pass

                return original, adversarial, orig_image, adv_image, pert_image  # Return 5 values

            except (ValueError, SyntaxError) as e:
                # If parsing fails, surface the parse error
                print(f"Failed to parse output: {e}", flush=True)
                return f"Parse error: {str(e)}", "", None, None, None
        else:
            error_msg = result.stderr.strip()
            return f"Error generating caption: {error_msg}", "", None, None, None

    except subprocess.TimeoutExpired:
        return "Error: Caption generation timed out (>60s)", "", None, None, None
    except Exception as e:
        return f"Error: {str(e)}", "", None, None, None


# Create the Gradio interface
with gr.Blocks(title="Image Captioning") as demo:
    gr.Markdown("# Evaluating Robustness of Multimodal Models Against Adversarial Perturbations")
    gr.Markdown("Upload an image to generate the adversarial image and caption using the APGD/SAIF algorithm.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image",
                type="numpy"
            )
            attack_algo = gr.Dropdown(
                choices=["APGD", "SAIF"],
                value="APGD",
                label="Adversarial Attack Algorithm",
                interactive=True
            )
            epsilon = gr.Slider(
                minimum=1,
                maximum=255,
                value=8,
                step=1,
                interactive=True,
                label="Epsilon (max perturbation, 0-255 scale)"
            )
            sparsity = gr.Slider(
                minimum=0,
                maximum=10000,
                value=0,
                step=100,
                interactive=True,
                label="Sparsity (L1 norm of the perturbation, for SAIF only)"
            )
            num_iters = gr.Slider(
                minimum=1,
                maximum=100,
                value=8,
                step=1,
                interactive=True,
                label="Number of Iterations"
            )

    with gr.Row():
        with gr.Column():
            generate_btn = gr.Button("Generate Captions", variant="primary")

    with gr.Row():
        with gr.Column():
            orig_image_output = gr.Image(label="Original Image")
            orig_caption_output = gr.Textbox(
                label="Generated Original Caption",
                lines=5,
                placeholder="Caption will appear here..."
            )
        with gr.Column():
            pert_image_output = gr.Image(label="Perturbation (10x magnified)")
        with gr.Column():
            adv_image_output = gr.Image(label="Adversarial Image")
            adv_caption_output = gr.Textbox(
                label="Generated Adversarial Caption",
                lines=5,
                placeholder="Caption will appear here..."
            )

    # Set up the button click event
    generate_btn.click(
        fn=generate_caption,
        inputs=[image_input, epsilon, sparsity, attack_algo, num_iters],
        outputs=[orig_caption_output, adv_caption_output,
                 orig_image_output, adv_image_output, pert_image_output]
    )

if __name__ == "__main__":
    # Use the port from the environment, falling back to 7861
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7861"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=True,
        debug=True,
        show_error=True
    )
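
# ---------------------------------------------------------------------------
# Output contract expected from run_caption.py (a hedged sketch, not the
# actual script, which is not part of this file). generate_caption() parses
# the subprocess stdout with ast.literal_eval, so run_caption.py is assumed
# to print a single Python dict literal carrying the keys read above. A
# minimal stub that satisfies the contract, with the attack itself elided and
# hypothetical placeholder values, could look like:
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--image_path", required=True)
#     parser.add_argument("--epsilon", type=float, default=8)
#     parser.add_argument("--num_iters", type=int, default=8)
#     parser.add_argument("--sparsity", type=int, default=0)
#     parser.add_argument("--attack_algo", default="APGD")
#     args = parser.parse_args()
#
#     # ... load the multimodal model, caption args.image_path, run the
#     # APGD/SAIF attack, and write the three result images to disk ...
#
#     print({
#         "original_caption": "a caption for the clean image",
#         "adversarial_caption": "a caption for the attacked image",
#         "original_image_path": "/tmp/original.png",
#         "adversarial_image_path": "/tmp/adversarial.png",
#         "perturbation_image_path": "/tmp/perturbation.png",
#     })
#
# Anything else written to stdout (progress bars, warnings) would break the
# ast.literal_eval parse, so the real script should keep its logging on
# stderr.
# ---------------------------------------------------------------------------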