import gradio as gr
import torch
import uuid
import spaces  # Hugging Face Spaces SDK (provides e.g. the @spaces.GPU decorator on ZeroGPU hardware)
from supermariogpt.dataset import MarioDataset
from supermariogpt.prompter import Prompter
from supermariogpt.lm import MarioLM
from supermariogpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
from pathlib import Path
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize model
try:
    mario_lm = MarioLM()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mario_lm = mario_lm.to(device)
    logger.info(f"Model loaded successfully on {device}")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

TILE_DIR = "data/tiles"

# Ensure the static directory exists and allow Gradio to serve files from it
Path("static").mkdir(exist_ok=True)
gr.set_static_paths(paths=[Path("static").absolute()])

app = FastAPI()

def make_html_file(generated_level):
    """Write a standalone HTML page that loads the generated level into the CheerpJ Mario player."""
    try:
        # view_level() returns one string per row of tiles; join them into the level text
        level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
        unique_id = uuid.uuid4()  # random UUID so each generation gets its own HTML file
        html_filename = f"demo-{unique_id}.html"
        html_content = f'''<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <title>supermariogpt</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
  </head>
  <body>
  </body>
  <script>
    cheerpjInit().then(function () {{
        cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
    }});
    cheerpjCreateDisplay(512, 500);
    cheerpjRunJar("/app/gradio_api/file=static/mario.jar");
  </script>
</html>'''
        with open(Path("static") / html_filename, 'w', encoding='utf-8') as f:
            f.write(html_content)
        return html_filename
    except Exception as e:
        logger.error(f"Error creating HTML file: {e}")
        raise

def generate(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt="", progress=gr.Progress(track_tqdm=True)):
    """Generate a Mario level from the prompt controls and return its image and playable iframe."""
    try:
        # Clamp user-supplied values to the ranges exposed by the sliders
        temperature = max(0.1, min(2.0, float(temperature)))
        level_size = max(100, min(2799, int(level_size)))

        # If no free-text prompt was given, compose one from the radio selections
        if prompt == "":
            prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

        logger.info(f"Using prompt: {prompt}")
        logger.info(f"Using temperature: {temperature}")
        logger.info(f"Using level size: {level_size}")

        prompts = [prompt]
        generated_level = mario_lm.sample(
            prompts=prompts,
            num_steps=level_size,
            temperature=float(temperature),
            use_tqdm=True
        )

        filename = make_html_file(generated_level)
        img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]

        gradio_html = f'''<div>
    <iframe width=512 height=512 style="margin: 0 auto" src="/gradio_api/file=static/{filename}"></iframe>
    <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
</div>'''
        return [img, gradio_html]
    except Exception as e:
        logger.error(f"Error generating level: {e}")
        raise gr.Error(f"Failed to generate level: {str(e)}")

with gr.Blocks().queue() as demo:
    gr.Markdown('''# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
''')
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], value="some", label="How many blocks?")
                elevation = gr.Radio(["low", "high"], value="low", label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(value=2.0, minimum=0.1, maximum=2.0, step=0.1, label="temperature: increase for more diverse, but lower-quality, generations")
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")
    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )

# Mount static files and Gradio app
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
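
Once the app is running, the `btn.click` endpoint can also be exercised headlessly with `gradio_client`. The snippet below is a minimal sketch meant to run as a separate script, assuming the server is reachable at http://localhost:7860 and that Gradio kept its default endpoint name for `generate` ("/generate"); adjust the URL and `api_name` if your setup differs.

from gradio_client import Client

client = Client("http://localhost:7860/")  # or the public Space URL (assumption)
level_image_path, level_html = client.predict(
    "some",   # pipes
    "many",   # enemies
    "some",   # blocks
    "low",    # elevation
    1.0,      # temperature
    700,      # level_size
    "",       # prompt (empty -> composed from the radio choices)
    api_name="/generate",
)
print(level_image_path)  # path to the rendered level PNG returned by the Image output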