import torch.nn as nn


def unwrap_model(model):
    """
    Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
    """
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    else:
        return model


def get_label(lang_x, tokenizer, mode='colon'):
    """
    Build language-modelling labels from ``lang_x`` by masking context tokens
    with -100 so that only the answer tokens contribute to the loss.

    ``mode='colon'`` masks everything up to and including the last ':' token.
    An integer ``mode`` masks everything before position
    (index of the last '<image>' token + ``mode``).
    """
    eoc_token = '<|endofchunk|>'
    media_token = '<image>'
    colon_token_id = tokenizer.encode(':')[0]
    eoc_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index(eoc_token)
    ]
    media_token_id = tokenizer.additional_special_tokens_ids[
        tokenizer.additional_special_tokens.index(media_token)
    ]
    label = lang_x.clone()
    # Compute the context length for each sequence in the batch.
    for idx in range(len(label)):
        if mode == 'colon':
            # Get a tensor of True/False values, use nonzero() to get the
            # indices of the ':' token, then take the last occurrence.
            indices = (label[idx] == colon_token_id).nonzero().flatten()
            end_of_context = indices[-1].item() + 1  # +1 to include the colon token
        elif isinstance(mode, int):
            # Negative index: mask everything before `mode` tokens past the
            # last '<image>' token.
            end_of_context = -label[idx].tolist()[::-1].index(media_token_id) - 1 + mode
        else:
            raise ValueError(f"Unsupported mode: {mode!r}")
        label[idx, : end_of_context] = -100
    # Also ignore padding, the first token, and the special media /
    # end-of-chunk tokens.
    label[label == tokenizer.pad_token_id] = -100
    label[:, 0] = -100
    label[label == media_token_id] = -100
    label[label == eoc_token_id] = -100
    return label
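

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged illustration of get_label. It assumes a HuggingFace
# tokenizer with '<image>' and '<|endofchunk|>' registered as additional
# special tokens; the "gpt2" checkpoint and the example prompt below are
# arbitrary choices for illustration only.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')  # hypothetical choice
    tokenizer.add_special_tokens(
        {'additional_special_tokens': ['<|endofchunk|>', '<image>']}
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    text = '<image>Question: what is shown? Answer: a cat<|endofchunk|>'
    lang_x = tokenizer(text, return_tensors='pt').input_ids
    labels = get_label(lang_x, tokenizer, mode='colon')
    # Everything up to and including the last ':' is set to -100, so the loss
    # would only be computed on the answer tokens.
    print(labels)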