# Adapted from https://github.com/bootphon/spokenlm-phoneme

from typing import Iterable

import torch
import torchaudio
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torchaudio.models.wav2vec2 import components
from torchaudio.pipelines import HUBERT_BASE


class Tokenizer:
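    """Mapping between phone labels and integer ids: 39 stress-less ARPABET phones plus SIL."""
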
    # fmt:off
    PHONEMES = {
        "SIL": 0, "AA": 1, "AE": 2, "AH": 3, "AO": 4, "AW": 5, "AY": 6, "B": 7,
        "CH": 8, "D": 9, "DH": 10, "EH": 11, "ER": 12, "EY": 13, "F": 14, "G": 15,
        "HH": 16, "IH": 17, "IY": 18, "JH": 19, "K": 20, "L": 21, "M": 22, "N": 23,
        "NG": 24, "OW": 25, "OY": 26, "P": 27, "R": 28, "S": 29, "SH": 30, "T": 31,
        "TH": 32, "UH": 33, "UW": 34, "V": 35, "W": 36, "Y": 37, "Z": 38, "ZH": 39,
    }
    # fmt:on

    def __init__(self, with_blank: bool = False) -> None:
        self.token_to_id = self.PHONEMES | {"<pad>": self.pad_id}
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}
        self.with_blank = with_blank

    @property
    def vocab_size(self) -> int:
        if self.with_blank:
            return len(self.PHONEMES) + 1
        return len(self.PHONEMES)

    @property
    def silence_id(self) -> int:
        return self.PHONEMES["SIL"]

    @property
    def pad_id(self) -> int:
        return len(self.PHONEMES)

    def encode(self, phones: "list[str] | str") -> torch.LongTensor:
        """Map a list of phones, or a space-separated phone string, to ids."""
        if isinstance(phones, str):
            phones = phones.split(" ")
        return torch.LongTensor([self.token_to_id[phone] for phone in phones])

    def decode(self, tokens: Iterable[int]) -> str:
        """Map ids back to a space-separated phone string, dropping pad/blank and silence ids."""
        return " ".join(
            self.id_to_token[token_id]
            for token_id in map(int, tokens)
            if token_id < self.pad_id and token_id != self.silence_id
        )
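

# Example round trip (hypothetical usage; ids follow the PHONEMES map above):
#
#     tok = Tokenizer()
#     ids = tok.encode("HH AH L OW")  # tensor([16,  3, 21, 25])
#     tok.decode(ids.tolist())        # -> "HH AH L OW"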


FINETUNING_HUBERT_CONFIG = {
    "encoder_projection_dropout": 0,
    "encoder_attention_dropout": 0,
    "encoder_ff_interm_dropout": 0.1,
    "encoder_dropout": 0,
    "encoder_layer_drop": 0.1,  # In torchaudio: 0.05
    "mask_prob": 0.75,  # In torchaudio: 0.65
    "mask_channel_prob": 0.5,
    "mask_channel_length": 10,  # In torchaudio and fairseq: 64. This is the value for pretraining.
    "num_classes": 500,  # Number of classes during HuBERT pretraining.
}


class HuBERTPhoneme(nn.Module, PyTorchModelHubMixin):
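    """HuBERT base encoder with a linear phoneme-classification head on top."""
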
    def __init__(self, freeze_encoder: bool = True, ctc_training: bool = False) -> None:
        """Initialize the model.

        Parameters
        ----------
        freeze_encoder : bool, optional
            Whether to freeze the Transformer encoder of HuBERT, by default True.
            The convolutional layers are always frozen.
        """
        super().__init__()
        self.model = torchaudio.models.hubert_pretrain_base(**FINETUNING_HUBERT_CONFIG)
        self.model.wav2vec2.load_state_dict(HUBERT_BASE.get_model().state_dict())
        self.aux = nn.Linear(
            HUBERT_BASE._params["encoder_embed_dim"],
            Tokenizer(with_blank=ctc_training).vocab_size,
        )
        self.freeze_encoder = freeze_encoder
        self.ctc_training = ctc_training

    def forward(
        self, waveforms: Tensor, lengths: "Tensor | None" = None
    ) -> "tuple[Tensor, Tensor | None]":
        """Extract logits during training, with masking applied to the features."""
        if self.freeze_encoder:
            # Frozen encoder: the whole pipeline runs without gradients, so only
            # the linear head `aux` receives updates.
            with torch.no_grad():
                x, out_len = self.model.wav2vec2.feature_extractor(waveforms, lengths)
                padding_mask = components._get_padding_mask(x, out_len)
                x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)  # type: ignore
                x, _ = self.model.mask_generator(x, padding_mask)
                x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)  # type: ignore
        else:
            # Trainable encoder: only the convolutional feature extractor stays
            # frozen, as in the standard wav2vec 2.0 / HuBERT fine-tuning recipe.
            with torch.no_grad():
                x, out_len = self.model.wav2vec2.feature_extractor(waveforms, lengths)
                padding_mask = components._get_padding_mask(x, out_len)
            x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)  # type: ignore
            x, _ = self.model.mask_generator(x, padding_mask)
            x = self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)  # type: ignore
        logits = self.aux(x)
        return logits, out_len
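
    # Hypothetical CTC training step built on the masked forward pass above
    # (a sketch, not from the original repo; it assumes ctc_training=True so
    # that the extra last class, index vocab_size - 1, acts as the CTC blank):
    #
    #     logits, out_len = model(waveforms, lengths)
    #     log_probs = logits.log_softmax(-1).transpose(0, 1)  # (T, B, C)
    #     ctc_loss = nn.CTCLoss(blank=tokenizer.vocab_size - 1, zero_infinity=True)
    #     loss = ctc_loss(log_probs, targets, out_len, target_lengths)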

    def inference(
        self, waveforms: Tensor, lengths: "Tensor | None" = None
    ) -> "tuple[Tensor, Tensor | None]":
        """Extract logits during inference. No masking is applied."""
        x, out_len = self.model.wav2vec2(waveforms, lengths)
        logits = self.aux(x)
        return logits, out_len

    @torch.jit.export
    def extract_features(
        self, waveforms: Tensor, lengths: "Tensor | None" = None
    ) -> "tuple[list[Tensor], Tensor | None]":
        """Extract features from intermediate layers. No masking is applied."""
        x, out_len = self.model.wav2vec2.extract_features(waveforms, lengths)
        # Append the phoneme logits on top of the returned layer features.
        x.append(self.aux(x[-1]))
        return x, out_len

    def train(self, mode: bool = True) -> "HuBERTPhoneme":
        """Override `train` to keep the encoder in eval mode when it is frozen."""
        self.training = mode
        if self.freeze_encoder:
            self.model.wav2vec2.eval()
        else:
            self.model.wav2vec2.train(mode)
        self.aux.train(mode)
        return self
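

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original repo). Runs a
    # random one-second clip through the model and greedily decodes the
    # frame-level argmax; on noise the output is meaningless, this only
    # exercises the shapes. Building the model downloads HUBERT_BASE weights.
    model = HuBERTPhoneme(freeze_encoder=True, ctc_training=False).eval()
    waveform = torch.randn(1, 16000)  # one clip, 1 s at 16 kHz
    with torch.inference_mode():
        logits, _ = model.inference(waveform)
    print(logits.shape)  # torch.Size([1, num_frames, 40])
    print(Tokenizer().decode(logits.argmax(dim=-1)[0].tolist()))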