import gradio as gr
import torch
import uuid
import spaces
from supermariogpt.dataset import MarioDataset
from supermariogpt.prompter import Prompter
from supermariogpt.lm import MarioLM
from supermariogpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
from pathlib import Path
import logging
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load MarioGPT once at import time, placing it on CUDA when available.
# A failure is logged together with its traceback and re-raised, so the app
# never starts without a model.
try:
    mario_lm = MarioLM()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mario_lm = mario_lm.to(device)
    logger.info("Model loaded successfully on %s", device)
except Exception as e:
    # logger.exception (unlike logger.error) records the full traceback.
    logger.exception("Failed to load model: %s", e)
    raise
# Directory of tile images handed to convert_level_to_png when rendering previews.
TILE_DIR = "data/tiles"

# Generated level pages are written into ./static. Create it up front and also
# register it with Gradio as an allowed static path.
_static_dir = Path("static")
_static_dir.mkdir(exist_ok=True)
gr.set_static_paths(paths=[_static_dir.absolute()])

# Outer ASGI application that static files and the Gradio UI get mounted onto.
app = FastAPI()
def make_html_file(generated_level):
    """Write an HTML page containing the generated level into ``static/``.

    The level is turned into text with ``view_level`` and embedded,
    HTML-escaped, inside a ``<pre>`` element of a standalone page.

    Args:
        generated_level: Level returned by ``mario_lm.sample``; it is passed to
            ``view_level`` together with ``mario_lm.tokenizer``.

    Returns:
        The bare file name (``demo-<uuid4>.html``) of the page written inside
        the ``static`` directory.

    Raises:
        Exception: Whatever rendering or writing raised; it is logged with its
            traceback and re-raised unchanged.
    """
    import html  # stdlib; local import keeps this block self-contained

    try:
        # Join the strings produced by view_level into one newline-separated text.
        level_text = "\n".join(view_level(generated_level, mario_lm.tokenizer))
        # uuid4 is random, so concurrent requests practically never collide on a name.
        html_filename = f"demo-{uuid.uuid4()}.html"
        # The level text is model output: escape it so characters such as '<'
        # or '&' are displayed literally instead of being parsed as markup.
        html_content = (
            "<!DOCTYPE html>\n"
            '<html lang="en">\n'
            "<head>\n"
            '<meta charset="utf-8">\n'
            "<title>supermariogpt</title>\n"
            "</head>\n"
            "<body>\n"
            f"<pre>{html.escape(level_text)}</pre>\n"
            "</body>\n"
            "</html>\n"
        )
        # NOTE(review): the previous template contained only the title text and
        # never used level_text. If an in-browser player is expected here (the UI
        # mentions keyboard controls), its markup must be restored in this page.
        (Path("static") / html_filename).write_text(html_content, encoding="utf-8")
        return html_filename
    except Exception as e:
        logger.exception("Error creating HTML file: %s", e)
        raise
@spaces.GPU
def generate(pipes, enemies, blocks, elevation, temperature=2.0, level_size=1399, prompt="", progress=gr.Progress(track_tqdm=True)):
    """Generate a Mario level and return its preview image and page HTML.

    Args:
        pipes, enemies, blocks, elevation: Attribute words (e.g. "some",
            "low") used to compose the prompt when ``prompt`` is empty.
        temperature: Sampling temperature, clamped to [0.1, 2.0].
        level_size: Value passed as ``num_steps`` to ``mario_lm.sample``,
            clamped to [100, 2799].
        prompt: Free-text prompt. When empty, whitespace-only or None, the
            prompt is built from the four attribute arguments.
        progress: Gradio progress tracker. With ``track_tqdm=True``, Gradio
            follows the tqdm bar that ``mario_lm.sample(use_tqdm=True)`` drives.

    Returns:
        ``[img, gradio_html]``: the first image from ``convert_level_to_png``
        and an HTML snippet embedding the page written by ``make_html_file``.

    Raises:
        gr.Error: If validation, sampling, rendering or file writing fails.
    """
    try:
        # Clamp to the same ranges as the UI sliders, so direct/API callers
        # cannot request out-of-range values.
        temperature = max(0.1, min(2.0, float(temperature)))
        level_size = max(100, min(2799, int(level_size)))

        # Fall back to the composed prompt for blank input (not only for "").
        if not prompt or not prompt.strip():
            prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
        logger.info("Using prompt: %s", prompt)
        logger.info("Using temperature: %s", temperature)
        logger.info("Using level size: %s", level_size)

        generated_level = mario_lm.sample(
            prompts=[prompt],
            num_steps=level_size,
            temperature=temperature,  # already a float after clamping
            use_tqdm=True,
        )

        filename = make_html_file(generated_level)
        img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]

        # Embed the page just written to ./static (served under /static) so the
        # returned HTML actually shows this level; previously filename was unused.
        gradio_html = (
            "<div>"
            f'<iframe src="static/{filename}" width="512" height="512"></iframe>'
            "<p>Press the arrow keys to move. Press a to run, s to jump and d to shoot fireflowers</p>"
            "</div>"
        )
        return [img, gradio_html]
    except Exception as e:
        logger.exception("Error generating level: %s", e)
        # Chain the cause so server logs keep the original traceback.
        raise gr.Error(f"Failed to generate level: {e}") from e
# Build the Gradio interface. Components render in creation order within their
# enclosing context managers, so statement order and nesting define the layout.
# queue() returns the Blocks object itself, which is why it can be used directly
# as the context manager bound to `demo`.
with gr.Blocks().queue() as demo:
    # Header: title, subtitle and project links.
    gr.Markdown('''# MarioGPT
### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
''')
    with gr.Tabs():
        # Tab 1: compose the prompt from four multiple-choice attributes.
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], value="some", label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], value="some", label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], value="some", label="How many blocks?")
                elevation = gr.Radio(["low", "high"], value="low", label="Elevation?")
        # Tab 2: free-text prompt. generate() uses it instead of the composed
        # attributes whenever it is non-empty.
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    # Sampling controls. The slider ranges match the clamping done in generate().
    # NOTE(review): label says "Increase these"; "Increase this" is probably intended.
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(value=2.0, minimum=0.1, maximum=2.0, step=0.1, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Slider(value=1399, minimum=100, maximum=2799, step=1, label="level_size")
    btn = gr.Button("Generate level")
    # Output area: HTML (page snippet) and the rendered level image.
    with gr.Row():
        with gr.Group():
            level_play = gr.HTML()
            level_image = gr.Image()
    # inputs follow generate()'s positional parameter order; outputs follow its
    # return order [img, gradio_html].
    btn.click(fn=generate, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt], outputs=[level_image, level_play])
    # Examples set only the four attributes; generate() fills the rest from its
    # defaults. cache_examples=True makes Gradio precompute these outputs by
    # running generate() (timing depends on the Gradio version — presumably at startup).
    gr.Examples(
        examples=[
            ["many", "many", "some", "high"],
            ["no", "some", "many", "high"],
            ["many", "many", "little", "low"],
            ["no", "no", "many", "high"],
        ],
        inputs=[pipes, enemies, blocks, elevation],
        outputs=[level_image, level_play],
        fn=generate,
        cache_examples=True,
    )
# Serve ./static (generated level pages) under /static. This is registered
# before the Gradio app claims "/", so /static requests reach StaticFiles first.
static_files = StaticFiles(directory="static", html=True)
app.mount("/static", static_files, name="static")
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)