Javedalam committed on
Commit
5d5100f
·
verified ·
1 Parent(s): 69841d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from snac import SNAC
5
+ import soundfile as sf
6
+ import tempfile
7
+ import spaces
8
+
9
# --- global handles (lazy-loaded) ---
# Populated on first use by load_models(); None means "not loaded yet".
model = None
tokenizer = None
snac_model = None


def load_models(device: str) -> "SNAC | None":
    """Load Maya1 and the SNAC decoder once, with a device-aware dtype.

    The Maya1 model/tokenizer pair and the SNAC decoder are cached in the
    module-level handles ``model``, ``tokenizer`` and ``snac_model``; calls
    made after they are populated skip the corresponding load.

    Args:
        device: ``"cuda"`` selects bfloat16 weights and ``device_map="auto"``;
            any other value selects float32 with no device map.

    Returns:
        The SNAC decoder on the call that created it, otherwise ``None``.
        The decoder is also stored in the global ``snac_model``, so callers
        may ignore the return value. It is not moved to ``device`` here;
        the caller does that once the GPU is available.
    """
    global model, tokenizer, snac_model

    if tokenizer is None or model is None:
        dtype = torch.bfloat16 if device == "cuda" else torch.float32
        print(f"[load_models] loading Maya1 (dtype={dtype}, device={device})")

        # device_map only on CUDA; on CPU keep None to avoid accelerate errors
        device_map = "auto" if device == "cuda" else None

        loaded_model = AutoModelForCausalLM.from_pretrained(
            "maya-research/maya1",
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        loaded_tokenizer = AutoTokenizer.from_pretrained(
            "maya-research/maya1",
            trust_remote_code=True,
        )
        # generate() needs a pad token id; fall back to EOS when none is set.
        if loaded_tokenizer.pad_token is None:
            loaded_tokenizer.pad_token = loaded_tokenizer.eos_token

        # Publish both handles together so a failure in either load cannot
        # leave the globals half-initialised.
        model, tokenizer = loaded_model, loaded_tokenizer

    if snac_model is not None:
        return None

    print("[load_models] loading SNAC 24kHz decoder")
    fresh_snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    # Cache here rather than relying on the caller to store the return value;
    # the move to the GPU happens later inside the handler (after ZeroGPU alloc).
    snac_model = fresh_snac
    return fresh_snac
44
+
45
# Token-id window and framing used when extracting SNAC codes from the
# generated sequence (values taken from the original inline literals).
_SNAC_TOKEN_MIN = 128266
_SNAC_TOKEN_MAX = 156937
_SNAC_FRAME_SIZE = 7
_SNAC_CODEBOOK_SIZE = 4096
_SNAC_SAMPLE_RATE = 24000


def _unpack_snac_frames(snac_tokens):
    """Split a flat list of SNAC token ids into three per-level code lists.

    Tokens are consumed in frames of ``_SNAC_FRAME_SIZE``; a trailing partial
    frame is dropped. Within a frame, slot 0 goes to level 0, slots 1 and 4
    to level 1, and slots 2, 3, 5 and 6 to level 2. Each id is shifted by
    ``_SNAC_TOKEN_MIN`` and reduced modulo ``_SNAC_CODEBOOK_SIZE``.
    """
    level_0, level_1, level_2 = [], [], []
    usable = len(snac_tokens) - len(snac_tokens) % _SNAC_FRAME_SIZE
    for start in range(0, usable, _SNAC_FRAME_SIZE):
        s = [
            (token - _SNAC_TOKEN_MIN) % _SNAC_CODEBOOK_SIZE
            for token in snac_tokens[start:start + _SNAC_FRAME_SIZE]
        ]
        level_0.append(s[0])
        level_1.extend((s[1], s[4]))
        level_2.extend((s[2], s[3], s[5], s[6]))
    return [level_0, level_1, level_2]


@spaces.GPU(duration=180)
def generate_speech(
    text: str,
    voice_description: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
) -> str:
    """Synthesize ``text`` with Maya1 and return the path of a WAV file.

    Args:
        text: Text to speak; must contain non-whitespace characters.
        voice_description: Natural-language voice description. When empty or
            missing, a neutral default description is used.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        Path to a temporary 24 kHz ``.wav`` file. The file is created with
        ``delete=False`` and is not removed by this function.

    Raises:
        gr.Error: If ``text`` is empty, or if generation yields fewer SNAC
            tokens than one full frame.
    """
    global snac_model

    # API callers may send None instead of a string; treat it like "".
    if not text or not text.strip():
        raise gr.Error("Enter some text.")
    if not voice_description or not voice_description.strip():
        voice_description = "Realistic voice with neutral tone and conversational pacing."

    # ZeroGPU gives us CUDA during this call
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load / ensure models exist; load_models returns SNAC only when created
    snac_fresh = load_models(device)
    if snac_fresh is not None:
        snac_model = snac_fresh

    # move models to the active device (ZeroGPU alloc happened)
    if device == "cuda":
        model.to(device)
        snac_model.to(device)

    # prompt exactly like the model card
    prompt = f'<description="{voice_description}"> {text}'
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            # NOTE(review): eos_token_id is explicitly None here; confirm
            # whether disabling an EOS stop condition is intended.
            eos_token_id=None,
            repetition_penalty=1.1,
        )

    # Keep only the continuation, then only ids inside the SNAC token window.
    # One .tolist() avoids a tensor comparison plus .item() per token.
    generated_ids = outputs[0, prompt_length:].tolist()
    snac_tokens = [
        token_id
        for token_id in generated_ids
        if _SNAC_TOKEN_MIN <= token_id <= _SNAC_TOKEN_MAX
    ]
    if len(snac_tokens) < _SNAC_FRAME_SIZE:
        raise gr.Error("No SNAC tokens generated. Try longer text and max_tokens=1200–1500.")

    codes = _unpack_snac_frames(snac_tokens)
    codes_tensor = [
        torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0)
        for level in codes
    ]
    with torch.inference_mode():
        audio = snac_model.decoder(snac_model.quantizer.from_codes(codes_tensor))[0, 0].cpu().numpy()

    # Reserve a unique path, close the handle, then let soundfile write to it;
    # writing by name while the handle is still open is not portable.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
        wav_path = handle.name
    sf.write(wav_path, audio, _SNAC_SAMPLE_RATE)
    # return filepath for gr.Audio(type="filepath")
    return wav_path
110
+
111
+ # ------------------- UI -------------------
112
# Preset label -> voice description text. "Custom" maps to an empty string so
# that selecting it clears the description box.
voice_presets = {
    "Male - American": "Realistic male voice in the 30s age with american accent. Normal pitch, warm timbre, conversational pacing.",
    "Female - British": "Clear female voice in the 20s age with British accent. Pleasant tone, articulate delivery, moderate pacing.",
    "Male - Deep": "Deep male voice with authoritative tone. Low pitch, resonant timbre, steady pacing.",
    "Female - Energetic": "Energetic female voice with enthusiastic tone. Higher pitch, bright timbre, upbeat pacing.",
    "Neutral - Professional": "Professional neutral voice with clear articulation. Balanced pitch, crisp tone, measured pacing.",
    "Custom": "",
}


def update_voice_description(preset: str) -> str:
    """Return the description text for ``preset``.

    Args:
        preset: A key of ``voice_presets``.

    Returns:
        The matching description, or ``""`` when ``preset`` is not a known
        key (``"Custom"`` also yields ``""``).
    """
    return voice_presets.get(preset, "")
122
+
123
# Build the Gradio interface: inputs on the left, generated audio on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Maya1 Text-to-Speech") as demo:
    gr.HTML("""
    <div style="text-align:center;padding:16px">
    <h1>🎙️ Maya1 Text-to-Speech</h1>
    <p style="color:#666">Generate emotional, realistic speech with natural-language voice design</p>
    <p style="font-size:12px;color:#28a745">⚡ ZeroGPU inference</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text to Speak",
                value="Hello! This is Maya1 <laugh> the best open source voice AI model with emotions.",
                lines=5,
            )
            voice_preset = gr.Dropdown(
                choices=list(voice_presets.keys()),
                value="Male - American",
                label="Voice Preset",
            )
            voice_description = gr.Textbox(
                label="Voice Description",
                value=voice_presets["Male - American"],
                lines=3,
            )
            with gr.Accordion("Advanced", open=False):
                temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
                max_tokens = gr.Slider(500, 2000, value=1000, step=100, label="Max tokens")

            generate_btn = gr.Button("🎤 Generate Speech", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Speech", type="filepath", interactive=False)

    # Choosing a preset overwrites the free-text description box.
    voice_preset.change(
        fn=update_voice_description,
        inputs=[voice_preset],
        outputs=[voice_description],
    )
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, voice_description, temperature, top_p, max_tokens],
        outputs=[audio_output],
    )

    # Register an explicit API endpoint so Spaces never shows “No API found”.
    # gr.api takes the endpoint name through `api_name` (it has no `name`
    # keyword) and must be called inside the Blocks context.
    gr.api(fn=generate_speech, api_name="generate_speech")

if __name__ == "__main__":
    demo.queue()
    demo.launch(show_error=True)