Ashton99 commited on
Commit
3ac36b2
·
verified ·
1 Parent(s): cbe8bc1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr, subprocess, tempfile, sys, os, shutil
2
+ from PIL import Image
3
+ from huggingface_hub import snapshot_download
4
+ import spaces, torch
5
+
6
+ MODEL_REPO = "Skywork/Matrix-Game-2.0"
7
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
8
+ print("Device:", DEVICE)
9
+
10
+ # ----- one-time model + code download -----
11
+ @spaces.cached
12
+ def setup():
13
+ print("‣ downloading weights …")
14
+ model_dir = snapshot_download(MODEL_REPO, cache_dir="model_cache")
15
+ if not os.path.exists("Matrix-Game"):
16
+ subprocess.check_call(["git", "clone",
17
+ "https://github.com/SkyworkAI/Matrix-Game.git"])
18
+ return model_dir
19
+ # -----------------------------------------
20
+
21
+ @spaces.GPU(duration=120)
22
+ def run(img, frames, seed):
23
+ if img is None:
24
+ return None, "Upload an image first!"
25
+ model_dir = setup()
26
+
27
+ tmp = tempfile.mkdtemp()
28
+ inp = os.path.join(tmp, "input.jpg")
29
+ outd = os.path.join(tmp, "outputs")
30
+ os.makedirs(outd, exist_ok=True)
31
+
32
+ # down-size to <=512 px to keep VRAM happy
33
+ if max(img.size) > 512:
34
+ r = 512 / max(img.size)
35
+ img = img.resize((int(img.size[0]*r), int(img.size[1]*r)),
36
+ Image.Resampling.LANCZOS)
37
+ img.save(inp)
38
+
39
+ m2 = os.path.join("Matrix-Game", "Matrix-Game-2")
40
+ cmd = [sys.executable, os.path.join(m2, "inference.py"),
41
+ "--img_path", inp,
42
+ "--output_folder", outd,
43
+ "--num_output_frames", str(frames),
44
+ "--seed", str(seed),
45
+ "--pretrained_model_path", model_dir]
46
+
47
+ print("‣ running:", " ".join(cmd))
48
+ proc = subprocess.run(cmd, capture_output=True, text=True, cwd=m2)
49
+ print(proc.stdout or proc.stderr)
50
+
51
+ # grab first video file we find
52
+ for root, _, files in os.walk(outd):
53
+ for f in files:
54
+ if f.lower().endswith((".mp4", ".webm", ".mov")):
55
+ final = os.path.join(root, f)
56
+ shutil.move(final, "result.mp4")
57
+ shutil.rmtree(tmp, ignore_errors=True)
58
+ return "result.mp4", "✔ Done!"
59
+ return None, "Generation failed – see logs"
60
+
61
+ with gr.Blocks() as demo:
62
+ gr.Markdown("# Matrix-Game 2.0 demo")
63
+ with gr.Row():
64
+ with gr.Column():
65
+ img = gr.Image(label="Start frame (jpg/png)", type="pil")
66
+ nfrm = gr.Slider(25, 150, 60, step=1, label="Frames")
67
+ s = gr.Number(42, label="Seed")
68
+ go = gr.Button("Generate")
69
+ with gr.Column():
70
+ vid = gr.Video(label="Output")
71
+ stat = gr.Textbox(label="Status")
72
+ go.click(run, [img, nfrm, s], [vid, stat])
73
+
74
+ if __name__ == "__main__":
75
+ demo.launch()