Spaces:
Runtime error
Runtime error
| import datasets | |
| import evaluate | |
| import re | |
| _DESCRIPTION = """ | |
| VQA accuracy is a evaluation metric which is robust to inter-human variability in phrasing the answers: | |
| $$ | |
| \\text{Acc}(ans) = \\min \\left( \\frac{\\text{# humans that said }ans}{3}, 1 \\right) | |
| $$ | |
| Where `ans` is answered by machine. In order to be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 sets of human annotators. | |
| """ | |
| _KWARGS_DESCRIPTION = """ | |
| Args: | |
| predictions (`list` of `str`): Predicted answers. | |
| references (`list` of `str` lists): Ground truth answers. | |
| answer_types (`list` of `str`, *optional*): Answer types corresponding to each questions. | |
| questions_type (`list` of `str`, *optional*): Question types corresponding to each questions. | |
| Returns: | |
| visual question answering accuracy (`float` or `int`): Accuracy accuracy. Minimum possible value is 0. Maximum possible value is 100. | |
| """ | |
| _CITATION = """ | |
| @InProceedings{{VQA}, | |
| author = {Stanislaw Antol and Aishwarya Agrawal and Jiasen Lu and Margaret Mitchell and Dhruv Batra and C. Lawrence Zitnick and Devi Parikh}, | |
| title = {{VQA}: {V}isual {Q}uestion {A}nswering}, | |
| booktitle = {International Conference on Computer Vision (ICCV)}, | |
| year = {2015}, | |
| } | |
| """ | |
# Lookup table used by `processDigitArticle` to rewrite a single word into its
# canonical contracted spelling (mostly: apostrophe-less form -> form with
# apostrophe). Lookups are exact, whole-word and case-sensitive.
#
# NOTE(review): the table looks like a verbatim copy of the reference VQA
# evaluation (the metric states it aligns with the official post-processing),
# so the quirks flagged below are kept as-is — confirm before "fixing" them,
# because changing an entry changes which answers are considered equal.
contractions = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    # The four capitalised keys below cannot match: `processDigitArticle`
    # lower-cases the text before looking words up.
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    # Identity entry: maps the word to itself, i.e. a no-op.
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    # Identity entry: maps the word to itself, i.e. a no-op.
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    # Reversed relative to its neighbours: this one *removes* the apostrophe.
    "somebody'd": "somebodyd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}
# Number words (plus "none") and the digit string each one is normalised to
# when answers are compared.
manualMap = dict(
    none="0",
    zero="0",
    one="1",
    two="2",
    three="3",
    four="4",
    five="5",
    six="6",
    seven="7",
    eight="8",
    nine="9",
    ten="10",
)

# English articles that are dropped from answers before comparison.
articles = "a an the".split()
| periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)") | |
| commaStrip = re.compile(r"(\d)(\,)(\d)") | |
| punct = [ | |
| ";", | |
| r"/", | |
| "[", | |
| "]", | |
| '"', | |
| "{", | |
| "}", | |
| "(", | |
| ")", | |
| "=", | |
| "+", | |
| "\\", | |
| "_", | |
| "-", | |
| ">", | |
| "<", | |
| "@", | |
| "`", | |
| ",", | |
| "?", | |
| "!", | |
| ] | |
| def processPunctuation(inText): | |
| outText = inText | |
| for p in punct: | |
| if (p + " " in inText or " " + p in inText) or ( | |
| re.search(commaStrip, inText) != None | |
| ): | |
| outText = outText.replace(p, "") | |
| else: | |
| outText = outText.replace(p, " ") | |
| outText = periodStrip.sub("", outText, re.UNICODE) | |
| return outText | |
def processDigitArticle(inText):
    """Normalise number words, articles and contractions in an answer.

    The text is lower-cased and split on whitespace. Each word is then
    replaced by its digit string if it is a key of `manualMap`, dropped if the
    result is one of `articles`, and finally replaced by its canonical
    spelling if it is a key of `contractions`.

    Args:
        inText (str): Answer text.

    Returns:
        str: The remaining words joined by single spaces ("" when no word is
        left).
    """
    kept = []
    for word in inText.lower().split():
        # `.get` instead of `.setdefault`: the lookup must not insert every
        # word it sees into the shared module-level `manualMap`.
        word = manualMap.get(word, word)
        if word not in articles:
            kept.append(word)
    return " ".join(contractions.get(word, word) for word in kept)
class VQAAccuracy(evaluate.Metric):
    """VQA accuracy: agreement of a predicted answer with the human answers.

    For every question the prediction is compared with each leave-one-out
    subset of the reference answers; a subset scores
    ``min(1, number of matching answers / 3)`` and the question's score is the
    mean over all subsets. Scores are reported as percentages.
    """

    def _info(self):
        """Return the metric card (description, citation, input features)."""
        # NOTE(review): `answer_types` / `question_types` are declared as
        # input features although `_compute` treats them as optional —
        # confirm that `evaluate` accepts calls that omit them.
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="sequence"), id="references"
                    ),
                    "answer_types": datasets.Value("string", id="sequence"),
                    "question_types": datasets.Value("string", id="sequence"),
                }
            ),
            reference_urls=[
                "https://visualqa.org/evaluation.html",
                "https://github.com/GT-Vision-Lab/VQA/blob/master",
            ],
        )

    @staticmethod
    def _normalize_answer(answer):
        """Apply the answer post-processing used before comparing answers."""
        answer = answer.replace("\n", " ").replace("\t", " ").strip()
        return processDigitArticle(processPunctuation(answer))

    @staticmethod
    def _question_accuracy(pred, gts):
        """Return the leave-one-out VQA accuracy (0..1) of one prediction.

        `pred` and every item of `gts` must already be normalised.
        """
        if not gts:
            # No human answer to agree with.
            return 0.0
        # Leaving answer i out lowers the number of matches by one exactly
        # when answer i itself matches, so one count replaces the nested scan.
        n_match = sum(gt == pred for gt in gts)
        scores = [min(1, (n_match - (gt == pred)) / 3) for gt in gts]
        return sum(scores) / len(scores)

    @staticmethod
    def _mean_percent(scores):
        """Return the mean of `scores` scaled to a percentage."""
        return 100 * sum(scores) / len(scores)

    def _compute(self, predictions, references, answer_types=None, question_types=None):
        """Compute the overall and per-type VQA accuracy.

        Args:
            predictions (list of str): Predicted answer per question.
            references (list of list of str): Human answers per question.
            answer_types (list of str, optional): Answer type per question;
                `None` entries are not grouped.
            question_types (list of str, optional): Question type per
                question; `None` entries are not grouped.

        Returns:
            dict: ``"overall"`` (percentage over all questions) plus
            ``"perAnswerType"`` / ``"perQuestionType"`` (dict of percentage
            per type) when at least one type of that kind was given.

        Raises:
            ValueError: If the inputs differ in length or are empty.
        """
        if answer_types is None:
            answer_types = [None] * len(predictions)
        if question_types is None:
            question_types = [None] * len(predictions)

        # `references` is part of the check too: `zip` below would otherwise
        # silently drop the questions without a counterpart.
        if not (
            len(predictions)
            == len(references)
            == len(answer_types)
            == len(question_types)
        ):
            raise ValueError(
                "The length of predictions, references, answer_types and question_types doesn't match."
            )
        if len(predictions) == 0:
            raise ValueError("At least one prediction is required to compute the accuracy.")

        total, ans_type_dict, ques_type_dict = [], {}, {}
        for pred, gts, ans_type, ques_type in zip(
            predictions, references, answer_types, question_types
        ):
            # Predictions and ground truths go through the same
            # post-processing so that identical raw answers always compare
            # equal.
            pred = self._normalize_answer(pred)
            gts = [self._normalize_answer(gt_ans) for gt_ans in gts]

            vqa_acc = self._question_accuracy(pred, gts)
            total.append(vqa_acc)

            if ans_type is not None:
                ans_type_dict.setdefault(ans_type, []).append(vqa_acc)
            if ques_type is not None:
                ques_type_dict.setdefault(ques_type, []).append(vqa_acc)

        # The following key names follow the naming of the official
        # evaluation results.
        result = {"overall": self._mean_percent(total)}
        if ans_type_dict:
            result["perAnswerType"] = {
                ans_type: self._mean_percent(accuracy_list)
                for ans_type, accuracy_list in ans_type_dict.items()
            }
        if ques_type_dict:
            result["perQuestionType"] = {
                ques_type: self._mean_percent(accuracy_list)
                for ques_type, accuracy_list in ques_type_dict.items()
            }
        return result