# SuperMarioGPT / app.py
# Hugging Face Space source by DarkDriftz — commit de878f4 ("Update app.py", verified)
import gradio as gr
import torch
import uuid
import spaces
from supermariogpt.dataset import MarioDataset
from supermariogpt.prompter import Prompter
from supermariogpt.lm import MarioLM
from supermariogpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
from pathlib import Path
import logging
# Setup logging: configure the root logger at INFO and get a module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize model: load MarioGPT once at import time and move it to CUDA when
# torch reports a GPU, otherwise keep it on the CPU. A load failure is logged
# together with its traceback and re-raised, so the app never starts without a
# model.
try:
    mario_lm = MarioLM()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mario_lm = mario_lm.to(device)
    logger.info("Model loaded successfully on %s", device)
except Exception as e:
    # logger.exception (unlike logger.error) records the traceback, which is
    # the only useful diagnostic for a startup crash.
    logger.exception("Failed to load model: %s", e)
    raise
# Directory of tile sprites handed to convert_level_to_png for level previews.
TILE_DIR = "data/tiles"

# Generated level pages live in ./static. Create the directory if it is
# missing, then register its absolute path with Gradio so the
# "/gradio_api/file=static/..." URLs used by the iframe can be served.
_static_root = Path("static")
_static_root.mkdir(exist_ok=True)
gr.set_static_paths(paths=[_static_root.absolute()])

# Host FastAPI application; the static mount and Gradio UI are attached to it
# further down in this module.
app = FastAPI()
def make_html_file(generated_level):
    """Generate HTML file for level visualization.

    Renders ``generated_level`` to text rows with ``view_level`` and writes a
    page into ``static/``. The page loads the CheerpJ runtime, registers the
    level text as the in-browser file ``/str/mylevel.txt`` and runs the Mario
    jar served from ``static/mario.jar``.

    Args:
        generated_level: Output of ``mario_lm.sample``. It is decoded with the
            model's tokenizer.

    Returns:
        The bare file name (``demo-<uuid4>.html``) of the page written under
        ``static/``.

    Raises:
        Exception: any rendering or I/O error is logged with its traceback
            and re-raised unchanged.
    """
    import json  # local import: only needed to escape the level for JS

    try:
        level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
        # json.dumps yields a valid JavaScript string literal: quotes,
        # backslashes, backticks-in-context and newlines are all safe. "</" is
        # additionally broken up so the level text can never terminate the
        # surrounding <script> element.
        level_js = json.dumps(level_text).replace("</", "<\\/")
        # uuid4 is random, so concurrent requests get distinct names that
        # cannot be guessed from the host's clock or MAC address.
        html_filename = f"demo-{uuid.uuid4()}.html"
        html_content = f'''<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>supermariogpt</title>
<script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>
<body>
</body>
<script>
cheerpjInit().then(function () {{
cheerpjAddStringFile("/str/mylevel.txt", {level_js});
}});
cheerpjCreateDisplay(512, 500);
cheerpjRunJar("/app/gradio_api/file=static/mario.jar");
</script>
</html>'''
        (Path("static") / html_filename).write_text(html_content, encoding="utf-8")
        return html_filename
    except Exception as e:
        logger.exception("Error creating HTML file: %s", e)
        raise
@spaces.GPU
def generate(
    pipes,
    enemies,
    blocks,
    elevation,
    temperature=2.0,
    level_size=1399,
    prompt="",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate Mario level based on parameters.

    Args:
        pipes, enemies, blocks, elevation: Radio selections. They are used to
            compose a prompt only when ``prompt`` is empty or blank.
        temperature: Sampling temperature, clamped to [0.1, 2.0].
        level_size: Number of sampling steps, clamped to [100, 2799].
        prompt: Free-text prompt. ``None``, empty or whitespace-only values
            fall back to the composed prompt.
        progress: Gradio progress tracker. With ``track_tqdm=True`` it mirrors
            the tqdm bar from ``mario_lm.sample``.

    Returns:
        ``[img, gradio_html]``. ``img`` is the first element returned by
        ``convert_level_to_png`` (presumably a PIL image; confirm).
        ``gradio_html`` is an iframe embedding the playable page written by
        ``make_html_file``.

    Raises:
        gr.Error: wraps any failure so that Gradio shows it to the user. The
            original exception is chained and logged with its traceback.
    """
    try:
        # Clamp to the ranges exposed by the UI sliders, so that direct API
        # calls cannot request out-of-range sampling settings.
        temperature = max(0.1, min(2.0, float(temperature)))
        level_size = max(100, min(2799, int(level_size)))

        # A blank typed prompt (e.g. "   ") must not reach the model. Fall back
        # to the prompt composed from the radio selections instead.
        prompt = (prompt or "").strip()
        if not prompt:
            prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"

        logger.info("Using prompt: %s", prompt)
        logger.info("Using temperature: %s", temperature)
        logger.info("Using level size: %s", level_size)

        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=level_size,
            temperature=temperature,
            use_tqdm=True,
        )
        filename = make_html_file(generated_level)
        img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]
        gradio_html = f'''<div>
<iframe width=512 height=512 style="margin: 0 auto" src="/gradio_api/file=static/{filename}"></iframe>
<p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
</div>'''
        return [img, gradio_html]
    except Exception as e:
        logger.exception("Error generating level: %s", e)
        raise gr.Error(f"Failed to generate level: {str(e)}") from e
# Gradio UI contents:
# - prompt controls;
# - advanced sampling sliders;
# - a generate button;
# - two outputs: the playable page (HTML iframe) and the rendered level image.
# Component creation order and nesting define the on-screen layout.
with gr.Blocks().queue() as demo:
    gr.Markdown('''# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
''')
    # There are two ways to specify the level. generate() uses the typed
    # prompt when it is non-empty; otherwise it builds one from the radios.
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], value="some", label="How many blocks?")
                elevation = gr.Radio(["low", "high"], value="low", label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    # Slider ranges mirror the clamping that generate() applies to these inputs.
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(value=2.0, minimum=0.1, maximum=2.0, step=0.1, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")
    btn = gr.Button("Generate level")
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    # The inputs are passed positionally. Their order must match generate's
    # parameters (pipes, enemies, blocks, elevation, temperature, level_size,
    # prompt). The outputs match its [img, gradio_html] return value.
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    # Examples fill only the four radio inputs, so generate() runs with its
    # default temperature, level_size and prompt. cache_examples=True asks
    # Gradio to precompute and store these outputs. Exactly when that happens
    # depends on the Gradio version (TODO confirm).
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )
# Expose ./static directly under /static using Starlette's StaticFiles in HTML
# mode. Then mount the Gradio UI at the root of the same FastAPI app;
# mount_gradio_app returns the app that should be served.
_static_files = StaticFiles(directory="static", html=True)
app.mount("/static", _static_files, name="static")
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Listen on all interfaces, port 7860 (presumably the port the hosting
    # Space expects; confirm).
    uvicorn.run(app, host="0.0.0.0", port=7860)