Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr, subprocess, tempfile, sys, os, shutil
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from huggingface_hub import snapshot_download
|
| 4 |
+
import spaces, torch
|
| 5 |
+
|
| 6 |
+
MODEL_REPO = "Skywork/Matrix-Game-2.0"
|
| 7 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 8 |
+
print("Device:", DEVICE)
|
| 9 |
+
|
| 10 |
+
# ----- one-time model + code download -----
|
| 11 |
+
@spaces.cached
|
| 12 |
+
def setup():
|
| 13 |
+
print("‣ downloading weights …")
|
| 14 |
+
model_dir = snapshot_download(MODEL_REPO, cache_dir="model_cache")
|
| 15 |
+
if not os.path.exists("Matrix-Game"):
|
| 16 |
+
subprocess.check_call(["git", "clone",
|
| 17 |
+
"https://github.com/SkyworkAI/Matrix-Game.git"])
|
| 18 |
+
return model_dir
|
| 19 |
+
# -----------------------------------------
|
| 20 |
+
|
| 21 |
+
@spaces.GPU(duration=120)
|
| 22 |
+
def run(img, frames, seed):
|
| 23 |
+
if img is None:
|
| 24 |
+
return None, "Upload an image first!"
|
| 25 |
+
model_dir = setup()
|
| 26 |
+
|
| 27 |
+
tmp = tempfile.mkdtemp()
|
| 28 |
+
inp = os.path.join(tmp, "input.jpg")
|
| 29 |
+
outd = os.path.join(tmp, "outputs")
|
| 30 |
+
os.makedirs(outd, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
# down-size to <=512 px to keep VRAM happy
|
| 33 |
+
if max(img.size) > 512:
|
| 34 |
+
r = 512 / max(img.size)
|
| 35 |
+
img = img.resize((int(img.size[0]*r), int(img.size[1]*r)),
|
| 36 |
+
Image.Resampling.LANCZOS)
|
| 37 |
+
img.save(inp)
|
| 38 |
+
|
| 39 |
+
m2 = os.path.join("Matrix-Game", "Matrix-Game-2")
|
| 40 |
+
cmd = [sys.executable, os.path.join(m2, "inference.py"),
|
| 41 |
+
"--img_path", inp,
|
| 42 |
+
"--output_folder", outd,
|
| 43 |
+
"--num_output_frames", str(frames),
|
| 44 |
+
"--seed", str(seed),
|
| 45 |
+
"--pretrained_model_path", model_dir]
|
| 46 |
+
|
| 47 |
+
print("‣ running:", " ".join(cmd))
|
| 48 |
+
proc = subprocess.run(cmd, capture_output=True, text=True, cwd=m2)
|
| 49 |
+
print(proc.stdout or proc.stderr)
|
| 50 |
+
|
| 51 |
+
# grab first video file we find
|
| 52 |
+
for root, _, files in os.walk(outd):
|
| 53 |
+
for f in files:
|
| 54 |
+
if f.lower().endswith((".mp4", ".webm", ".mov")):
|
| 55 |
+
final = os.path.join(root, f)
|
| 56 |
+
shutil.move(final, "result.mp4")
|
| 57 |
+
shutil.rmtree(tmp, ignore_errors=True)
|
| 58 |
+
return "result.mp4", "✔ Done!"
|
| 59 |
+
return None, "Generation failed – see logs"
|
| 60 |
+
|
| 61 |
+
with gr.Blocks() as demo:
|
| 62 |
+
gr.Markdown("# Matrix-Game 2.0 demo")
|
| 63 |
+
with gr.Row():
|
| 64 |
+
with gr.Column():
|
| 65 |
+
img = gr.Image(label="Start frame (jpg/png)", type="pil")
|
| 66 |
+
nfrm = gr.Slider(25, 150, 60, step=1, label="Frames")
|
| 67 |
+
s = gr.Number(42, label="Seed")
|
| 68 |
+
go = gr.Button("Generate")
|
| 69 |
+
with gr.Column():
|
| 70 |
+
vid = gr.Video(label="Output")
|
| 71 |
+
stat = gr.Textbox(label="Status")
|
| 72 |
+
go.click(run, [img, nfrm, s], [vid, stat])
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
demo.launch()
|