import os, math, torch, gradio as gr
import numpy as np
from transformers import (
    AutoModelForCTC,
    AutoProcessor,         # happy path
    Wav2Vec2Processor,     # fallback
    Wav2Vec2FeatureExtractor,
    Wav2Vec2CTCTokenizer,
)

MODEL_ID = os.getenv("MODEL_ID", "Reihaneh/wav2vec2_fy_nl_best_frisian_1")
HF_TOKEN = os.getenv("HF_TOKEN")  # only needed if the model repo is private
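# Both can be overridden via environment variables, e.g. (hypothetical values):
#   MODEL_ID=someuser/some-wav2vec2-ctc-model HF_TOKEN=hf_xxx python app.py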

device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Try to load processor; if missing feature extractor, build it manually
processor = None
try:
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
except Exception as e:
    print("AutoProcessor failed, building Wav2Vec2Processor manually:", e)
    # Load tokenizer (must exist in repo: vocab.json + tokenizer_config.json)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    # Minimal, safe defaults — adjust if your training used different settings
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,          # <-- set to your training SR
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
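    # Optional: persist the rebuilt processor so later loads work through
    # AutoProcessor directly (the target path here is illustrative):
    # processor.save_pretrained("./processor_fixed")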

model = AutoModelForCTC.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device).eval()
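
# If GPU memory is tight, the checkpoint could instead be loaded in half
# precision (assumption: this model tolerates fp16), e.g.:
#   model = AutoModelForCTC.from_pretrained(
#       MODEL_ID, token=HF_TOKEN, torch_dtype=torch.float16
#   ).to(device).eval()
# The input_values.to(model.dtype) cast inside transcribe() keeps the inputs
# consistent either way.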

# Try to read SR from processor if present
target_sr = getattr(getattr(processor, "feature_extractor", None), "sampling_rate", 16000)

def _cheap_resample(wav, sr, target_sr):
    # Nearest-neighbor resampling: fast and dependency-free, but low quality.
    # Indices are clipped so upsampling never reads past the end of the array.
    if sr == target_sr:
        return wav
    ratio = target_sr / float(sr)
    n = int(math.ceil(len(wav) * ratio))
    idx = np.clip((np.arange(n) / ratio).astype(np.int64), 0, len(wav) - 1)
    return wav[idx]
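
# Higher-quality alternative, if torchaudio is available (an assumption; it is
# not a dependency of this script): bandlimited sinc resampling.
#   import torchaudio.functional as AF
#   x = AF.resample(torch.from_numpy(wav), orig_freq=sr, new_freq=target_sr).numpy()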

# Earlier, minimal version of transcribe(), kept for reference. It lacks the
# float32 sanitization and model-dtype cast that the live version below adds:
#
# def transcribe(audio):
#     if audio is None:
#         return ""
#     sr, x = audio
#     if x.ndim == 2:  # stereo -> mono
#         x = x[:, 0]
#     x = _cheap_resample(x, sr, target_sr)
#     inputs = processor(x, sampling_rate=target_sr, return_tensors="pt", padding=True)
#     with torch.inference_mode():
#         logits = model(inputs.input_values.to(device)).logits
#         ids = torch.argmax(logits, dim=-1)
#         text = processor.batch_decode(ids)[0]
#     return text


def transcribe(a):
    try:
        if a is None:
            return ""
        sr, x = a  # if you use a helper, just make sure you end up with (sr, np.ndarray)

        # 1) Mono + sanitize + force float32 (Gradio may hand back int16 or float64).
        if x.ndim == 2:
            x = x.mean(axis=1)
        x = np.nan_to_num(x).astype(np.float32)

        # 2) Cheap nearest-neighbor resample to the processor's sampling rate
        #    (module-level target_sr, read from the feature extractor above).
        x = _cheap_resample(x, sr, target_sr)

        # 3) Feature-extract, then move inputs to the device and cast them to
        #    the model dtype. The dtype cast is the key line: it prevents
        #    "Input type (double) and bias type should be the same" when the
        #    audio sneaks in as float64.
        inputs = processor(x, sampling_rate=target_sr, return_tensors="pt", padding=True)
        input_values = inputs.input_values.to(device).to(model.dtype)

        with torch.inference_mode():
            logits = model(input_values).logits
            ids = torch.argmax(logits, dim=-1)  # greedy CTC decode
            text = processor.batch_decode(ids)[0]
        return text
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return f"⚠️ Error: {e}"


with gr.Blocks(title="Frisian ASR") as demo:
    gr.Markdown("## 🎙️ Frisian ASR")
    audio = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Audio")
    out = gr.Textbox(label="Transcript")
    gr.Button("Transcribe").click(transcribe, inputs=audio, outputs=out)


demo.queue().launch()
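
# Quick local smoke test (comment out the launch() call above first): one
# second of silence should run through the pipeline and return a string.
#   print(transcribe((16000, np.zeros(16000, dtype=np.float32))))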