|
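"""text2midi HARP Space: generate a MIDI file from a natural-language prompt.

A Gradio + pyharp wrapper around the amaai-lab/text2midi checkpoint.
"""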
import os
import pickle
import subprocess
from dataclasses import asdict
from time import time_ns
from uuid import uuid4

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from transformers import T5Tokenizer

from pyharp.core import ModelCard, build_endpoint
from pyharp.labels import LabelList
from transformer_model import Transformer
|
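# Model weights, the pickled REMI vocabulary, and a soundfont (for optional
# audio rendering via save_wav) are fetched from the Hugging Face Hub.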
REPO_ID = "amaai-lab/text2midi"
MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="pytorch_model.bin")
TOKENIZER_PATH = hf_hub_download(repo_id=REPO_ID, filename="vocab_remi.pkl")
SOUNDFONT_PATH = hf_hub_download(repo_id=REPO_ID, filename="soundfont.sf2")

def save_wav(midi_path: str) -> str: |
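    """Render a MIDI file to a 16 kHz WAV with FluidSynth; return the WAV path."""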
    directory = os.path.dirname(midi_path) or "."
    stem = os.path.splitext(os.path.basename(midi_path))[0]
    wav_filepath = os.path.join(directory, f"{stem}.wav")
    # Call fluidsynth directly (no shell) so paths containing spaces or shell
    # metacharacters are handled safely.
    cmd = [
        "fluidsynth", "-r", "16000", SOUNDFONT_PATH, "-g", "1.0",
        "--quiet", "--no-shell",
        midi_path, "-T", "wav", "-F", wav_filepath,
    ]
    subprocess.run(cmd, check=False, stdout=subprocess.DEVNULL)
    return wav_filepath

def _unique_path(ext: str) -> str: |
    """Create a unique file path in /tmp to avoid naming collisions."""
    return os.path.join("/tmp", f"t2m_{time_ns()}_{uuid4().hex[:8]}{ext}")

def generate_midi(prompt: str, temperature: float = 0.9, max_len: int = 500) -> str: |
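    """Generate a MIDI file from `prompt` and return its path.

    Note: the checkpoint and both tokenizers are reloaded on every call, which
    keeps the function self-contained for stateless GPU workers at the cost of
    per-request latency.
    """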
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the pickled REMI tokenizer used to decode generated token ids.
    with open(TOKENIZER_PATH, "rb") as f:
        r_tokenizer = pickle.load(f)

    vocab_size = len(r_tokenizer)
    # Positional hyperparameters follow transformer_model.Transformer's
    # signature and must match the released checkpoint's architecture.
    model = Transformer(
        vocab_size, 768, 8, 2048, 18, 1024, False, 8, device=device
    )
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()

    # Encode the text prompt with FLAN-T5's tokenizer; for a single prompt it
    # already returns padded, batched tensors.
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Sample a REMI token sequence, then decode it back into a MIDI object.
    with torch.no_grad():
        output = model.generate(
            input_ids, attention_mask, max_len=max_len, temperature=temperature
        )

    output_list = output[0].tolist()
    generated_midi = r_tokenizer.decode(output_list)

    midi_path = _unique_path(".mid")
    generated_midi.dump_midi(midi_path)
    return midi_path

|
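# @spaces.GPU requests a ZeroGPU device for up to `duration` seconds per call.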
@spaces.GPU(duration=120) |
def process_fn(prompt: str, temperature: float, max_length: int): |
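    """HARP processing callback: return (labels, MIDI file path) or an error payload."""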
    try:
        midi_path = generate_midi(prompt, float(temperature), int(max_length))
        labels = LabelList()
        return asdict(labels), midi_path
    except Exception as e:
        # Report the failure to the client instead of raising inside the UI.
        return {"message": f"Error: {e}"}, None

model_card = ModelCard( |
    name="Text2MIDI Generation",
    description=(
        "Turn your musical ideas into playable MIDI notes.\n"
        "Input: describe what you'd like to hear, e.g. a gentle piano lullaby "
        "with soft strings.\n"
        "Output: a matching MIDI sequence for playback or editing.\n"
        "Use the sliders to control the amount of creativity and the length."
    ),
    author="Keshav Bhandari, Abhinaba Roy, Kyra Wang, Geeta Puri, Simon Colton, Dorien Herremans",
    tags=["text-to-music", "midi", "generation"],
)

with gr.Blocks() as demo: |
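    # These components define the web demo UI and, through build_endpoint
    # below, the typed inputs/outputs exposed to HARP clients.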
    gr.Markdown("## 🎶 text2midi")

    prompt_in = gr.Textbox(
        label="Describe Your Music",
        info="Type a short phrase like 'calm piano with flowing arpeggios'",
    ).harp_required(True)
    temperature_in = gr.Slider(
        minimum=0.8, maximum=1.1, value=0.9, step=0.1,
        label="Creativity",
        info=(
            "Adjusts how much freedom the model takes while composing. "
            "Lower = safer and more structured; higher = more varied and expressive."
        ),
        interactive=True,
    )
    maxlen_in = gr.Slider(
        minimum=500, maximum=1500, value=500, step=100,
        label="Composition Length",
        info=(
            "Sets the length of the generated piece in musical tokens. "
            "Higher values produce longer phrases (roughly more measures of music)."
        ),
    )

    labels_out = gr.JSON(label="Labels / Metadata")
    midi_out = gr.File(label="Generated MIDI", file_types=[".mid", ".midi"], type="filepath")
|
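    # Register process_fn and the components above as a HARP-compatible endpoint.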
    _ = build_endpoint(
        model_card=model_card,
        input_components=[prompt_in, temperature_in, maxlen_in],
        output_components=[labels_out, midi_out],
        process_fn=process_fn,
    )

demo.launch(share=True, show_error=True, debug=True)