"""Gradio + FastAPI demo app for MarioGPT text-to-level generation.

A prompt (typed, or composed from radio buttons) is sampled with ``MarioLM``;
the resulting level is rendered to a PNG and also written to a per-request
HTML page under ``static/`` that the UI embeds in an iframe.
"""

import html
import logging
import os
import uuid
from pathlib import Path

import gradio as gr
import spaces
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles

from supermariogpt.dataset import MarioDataset
from supermariogpt.lm import MarioLM
from supermariogpt.prompter import Prompter
from supermariogpt.utils import convert_level_to_png, view_level

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model once at import time so every request reuses it; a failure
# here is logged with its traceback and aborts startup.
try:
    mario_lm = MarioLM()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mario_lm = mario_lm.to(device)
    logger.info("Model loaded successfully on %s", device)
except Exception:
    logger.exception("Failed to load model")
    raise

TILE_DIR = "data/tiles"
# Directory for generated per-request HTML pages; served at /static below.
STATIC_DIR = Path("static")

# Ensure static directory exists
STATIC_DIR.mkdir(exist_ok=True)
gr.set_static_paths(paths=[STATIC_DIR.absolute()])

app = FastAPI()


def make_html_file(generated_level):
    """Write an HTML page showing *generated_level* into ``STATIC_DIR``.

    Args:
        generated_level: Level as returned by ``mario_lm.sample``; converted
            to text with ``view_level(generated_level, mario_lm.tokenizer)``.

    Returns:
        The bare filename (not a path) of the written page, of the form
        ``"demo-<uuid4>.html"``.

    Raises:
        Exception: Whatever ``view_level`` or the file write raises; it is
            logged with its traceback and re-raised unchanged.
    """
    try:
        # NOTE(review): assumes view_level returns one string per level row —
        # confirm. Rows are newline-joined so the level reads as a grid.
        level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
        # uuid4 gives a random, collision-resistant name per request.
        html_filename = f"demo-{uuid.uuid4()}.html"
        # The level text is escaped because it is embedded in markup.
        # NOTE(review): this page only displays the level as text. The
        # interactive player markup that the UI's "arrow keys" instructions
        # refer to is not present in this source and should be restored here.
        html_content = (
            "<!DOCTYPE html>\n"
            '<html lang="en">\n'
            "<head>\n"
            '    <meta charset="utf-8">\n'
            "    <title>supermariogpt</title>\n"
            "</head>\n"
            "<body>\n"
            f"<pre>{html.escape(level_text)}</pre>\n"
            "</body>\n"
            "</html>\n"
        )
        (STATIC_DIR / html_filename).write_text(html_content, encoding="utf-8")
        return html_filename
    except Exception:
        logger.exception("Error creating HTML file")
        raise


@spaces.GPU
def generate(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399,
             prompt="", progress=gr.Progress(track_tqdm=True)):
    """Generate a Mario level and return ``[level_image, player_html]``.

    Args:
        pipes, enemies, blocks, elevation: Radio-button choices used to
            compose a prompt when *prompt* is blank.
        temperature: Sampling temperature, clamped to [0.1, 2.0].
        level_size: Number of sampling steps, clamped to [100, 2799].
        prompt: Free-text prompt; blank or whitespace-only means "compose one
            from the radio choices".
        progress: Gradio progress tracker; ``track_tqdm=True`` presumably
            mirrors the tqdm bar enabled via ``use_tqdm=True`` below.

    Returns:
        A two-element list: the first image from ``convert_level_to_png`` and
        an HTML snippet embedding the page written by ``make_html_file``.

    Raises:
        gr.Error: On any failure, so the message is shown in the UI.
    """
    try:
        # Clamp to the same ranges the UI sliders allow.
        temperature = max(0.1, min(2.0, float(temperature)))
        level_size = max(100, min(2799, int(level_size)))

        prompt = (prompt or "").strip()
        if not prompt:
            prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

        logger.info("Using prompt: %s", prompt)
        logger.info("Using temperature: %s", temperature)
        logger.info("Using level size: %s", level_size)

        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=level_size,
            temperature=temperature,
            use_tqdm=True,
        )

        filename = make_html_file(generated_level)
        img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]

        # Embed the page just written; it is served by the /static mount
        # registered at the bottom of this module.
        gradio_html = (
            "<div>\n"
            f'    <iframe width="512" height="512" style="margin: 0 auto" src="static/{filename}"></iframe>\n'
            '    <p style="text-align:center">Press the arrow keys to move. '
            "Press <code>a</code> to run, <code>s</code> to jump and "
            "<code>d</code> to shoot fireflowers</p>\n"
            "</div>"
        )
        return [img, gradio_html]
    except Exception as e:
        logger.exception("Error generating level")
        raise gr.Error(f"Failed to generate level: {e}") from e


with gr.Blocks().queue() as demo:
    gr.Markdown(
        """# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
"""
    )
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], value="some", label="How many blocks?")
                elevation = gr.Radio(["low", "high"], value="low", label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(
                value="",
                label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'",
            )

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(
            value=2.0, minimum=0.1, maximum=2.0, step=0.1,
            label="temperature: Increase these for more diverse, but lower quality, generations",
        )
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")

    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()

    btn.click(
        fn=generate,
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt],
        outputs=[level_image, level_play],
    )
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )

# Mount static files first: routes match in registration order and the
# Gradio app is mounted at the root path, which would otherwise catch /static.
app.mount("/static", StaticFiles(directory=str(STATIC_DIR), html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)