import os
import spaces
import pickle
import subprocess
import torch
import torch.nn as nn
import gradio as gr
from dataclasses import asdict
from transformers import T5Tokenizer
from huggingface_hub import hf_hub_download
from time import time_ns
from uuid import uuid4
from transformer_model import Transformer
from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
# Model/artifacts from HF Hub
REPO_ID = "amaai-lab/text2midi"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
TOKENIZER_PATH = hf_hub_download(repo_id=REPO_ID, filename="vocab_remi.pkl")
# Optional: soundfont used by the MIDI -> WAV preview helper below (not wired into the UI yet)
SOUNDFONT_PATH = hf_hub_download(repo_id=REPO_ID, filename="soundfont.sf2")
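# Note: hf_hub_download stores each file in the local Hugging Face cache and
# returns its filesystem path, so restarts reuse the cached artifacts instead
# of downloading them again.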
# (Optional) MIDI -> WAV
def save_wav(midi_path: str) -> str:
    """Render a MIDI file to WAV with FluidSynth (requires the fluidsynth binary)."""
    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    wav_filepath = os.path.join(directory, f"{stem}.wav")
    cmd = (
        f"fluidsynth -r 16000 {SOUNDFONT_PATH} -g 1.0 --quiet --no-shell "
        f"{midi_path} -T wav -F {wav_filepath} > /dev/null"
    )
    subprocess.run(cmd, shell=True, check=False)
    return wav_filepath
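# Usage sketch (not called by the app; assumes a `fluidsynth` binary is installed
# in the Space and that the MIDI file already exists):
#   wav_path = save_wav("/tmp/t2m_example.mid")  # hypothetical path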
# Helpers
def _unique_path(ext: str) -> str:
    """Create a unique file path in /tmp to avoid naming collisions."""
    return os.path.join("/tmp", f"t2m_{time_ns()}_{uuid4().hex[:8]}{ext}")
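# Example of the shape this produces (illustrative values only):
#   /tmp/t2m_1712345678901234567_a1b2c3d4.mid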
# Core Text -> MIDI
def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load REMI vocab/tokenizer (pickle dict used by the provided model)
    with open(TOKENIZER_PATH, "rb") as f:
        r_tokenizer = pickle.load(f)
    vocab_size = len(r_tokenizer)

    model = Transformer(
        vocab_size,    # vocab size
        768,           # d_model
        8,             # nhead
        2048,          # dim_feedforward
        18,            # nlayers
        1024,          # max_seq_len
        False,         # use_rotary
        8,             # rotary_dim
        device=device,
    )
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = nn.utils.rnn.pad_sequence(inputs.input_ids, batch_first=True, padding_value=0).to(device)
    attention_mask = nn.utils.rnn.pad_sequence(inputs.attention_mask, batch_first=True, padding_value=0).to(device)

    with torch.no_grad():
        output = model.generate(input_ids, attention_mask, max_len=max_len, temperature=temperature)
    output_list = output[0].tolist()

    generated_midi = r_tokenizer.decode(output_list)
    midi_path = _unique_path(".mid")
    generated_midi.dump_midi(midi_path)
    return midi_path
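# Quick local check (hedged sketch; kept commented out so the Space only runs
# generation through the HARP endpoint below):
#   midi_file = generate_midi("a gentle piano lullaby with soft strings",
#                             temperature=0.9, max_len=500)
#   print(midi_file)  # -> /tmp/t2m_<timestamp>_<id>.mid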
# HARP process function
# Return JSON first, MIDI second
@spaces.GPU(duration=120)
def process_fn(prompt: str, temperature: float, max_length: int):
    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()  # add MidiLabel entries here if you have metadata
        return asdict(labels), midi_path
    except Exception as e:
        # On error: return JSON with error message, and no file
        return {"message": f"Error: {e}"}, None
# HARP Model Card
model_card = ModelCard(
    name="Text2MIDI Generation",
    description=(
        "Turn your musical ideas into playable MIDI notes.\n"
        "Input: Describe what you'd like to hear. For example: a gentle piano lullaby with soft strings.\n"
        "Output: This model will generate a matching MIDI sequence for playback or editing.\n"
        "Use the sliders to control the amount of creativity and length."
    ),
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    tags=["text-to-music", "midi", "generation"],
)
# Gradio + HARP UI
with gr.Blocks() as demo:
    gr.Markdown("## 🎶 text2midi")

    # Inputs
    prompt_in = gr.Textbox(
        label="Describe Your Music",
        info="Type a short phrase like 'calm piano with flowing arpeggios'",
    ).harp_required(True)
    temperature_in = gr.Slider(
        minimum=0.8, maximum=1.1, value=0.9, step=0.1,
        label="Creativity",
        info=(
            "Adjusts how much freedom the model takes while composing.\n"
            "Lower = safer and more predictable (structured), "
            "Higher = more varied and expressive."
        ),
        interactive=True,
    )
    maxlen_in = gr.Slider(
        minimum=500, maximum=1500, step=100, value=500,
        label="Composition Length",
        info=(
            "Determines how long the generated piece is in musical tokens.\n"
            "Higher values produce longer phrases (roughly more measures of music)."
        ),
    )

    # Outputs (JSON FIRST for HARP, then MIDI)
    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")

    # Build HARP endpoint
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],  # JSON first
        process_fn=process_fn,
    )

# Launch App
demo.launch(share=True, show_error=True, debug=True)
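# Remote usage sketch (assumptions: the Space id and the endpoint name that
# pyharp's build_endpoint registers; both may differ in practice):
#   from gradio_client import Client
#   client = Client("<user>/<space-name>")           # hypothetical Space id
#   labels, midi = client.predict(
#       "calm piano with flowing arpeggios", 0.9, 500,
#       api_name="/process",                         # assumed endpoint name
#   )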