# Robust_MMFM / vlm_eval / create_clip_dataset.py
# (Hugging Face file-viewer header removed: uploader "KC123hello", revision
# fc0ff8f, raw/history/blame links, 2.75 kB — not part of the source.)
import json
import torch
import numpy as np
import random
def _load_jsonl(path):
    """Read a JSON-Lines file (one JSON object per line) into a list of dicts."""
    records = []
    with open(path, 'r') as file:
        for line in file:
            records.append(json.loads(line))
    return records


def _write_jsonl(path, records):
    """Write *records* to *path* in JSON-Lines format (one object per line)."""
    with open(path, 'w') as file:
        for record in records:
            file.write(json.dumps(record) + '\n')


def main():
    """Build shuffled CLIP training sets mixing MS-COCO captions with
    counterfactual/adversarial attack samples, at three sizes ("base",
    "medium", "all"), resampled under several data seeds.
    """
    # Initialising seeds for data sampling.
    data_seeds = [i for i in range(107, 122)]
    ms_coco_base_anno_path = "./clip_train_datasets/MS_COCO/ms_coco_captions.json"
    attack_base_anno_path = "./clip_train_datasets/COCO_CF/examples.jsonl"
    data_names = ["base", "medium", "all"]

    # Both annotation files are JSON-Lines despite the ".json" extension.
    ms_coco_array = _load_jsonl(ms_coco_base_anno_path)
    attack_array = _load_jsonl(attack_base_anno_path)

    # Per dataset size: (MS-COCO samples, attack *pairs*).  Each attack record
    # expands to two (image, caption) entries below, so the effective totals
    # are base: 8705 + 2*4353, medium: 17410 + 2*13057, all: 17410 + 2*17410.
    # An unknown data_name now fails fast with KeyError instead of a NameError
    # from undefined count variables.
    sample_counts = {
        "base": (8705, 4353),
        "medium": (17410, int(0.75 * 17410)),
        "all": (17410, 17410),
    }

    for data_name in data_names:
        num_ms_coco_samples, num_attacks_samples = sample_counts[data_name]
        for data_seed in data_seeds:
            # Seed NumPy per iteration so each (name, seed) draw is reproducible.
            np.random.seed(data_seed)
            ms_coco_rand_indices = np.random.choice(len(ms_coco_array), num_ms_coco_samples, replace=False)
            attack_rand_indices = np.random.choice(len(attack_array), num_attacks_samples, replace=False)
            ms_coco_samples = [ms_coco_array[int(i)] for i in ms_coco_rand_indices]
            attack_samples = [attack_array[int(i)] for i in attack_rand_indices]
            # Each attack record carries two image/caption pairs (the original
            # and its counterfactual); expand to one entry per pair.
            attack_samples = [{"image_id": batch["id"], "image_name": batch[f"image_{i}"] + ".jpg", "caption": batch[f"caption_{i}"]} for batch in attack_samples for i in range(2)]

            # Fixed shuffle seed so the interleaving is reproducible given the draw.
            random.seed(42)
            combined_dataset = ms_coco_samples + attack_samples
            random.shuffle(combined_dataset)

            if data_name != 'all':
                out_path = f"./clip_train_datasets/MS_COCO_APGD_4/json_files/data_name_{data_name}_data_seed_{data_seed}.json"
            else:
                # NOTE(review): the 'all' split filename omits the seed, so each
                # seed iteration overwrites the same file and the last seed wins.
                # Preserved as-is — confirm this is intentional.
                out_path = f"./clip_train_datasets/MS_COCO_APGD_4/json_files/data_name_{data_name}.json"
            _write_jsonl(out_path, combined_dataset)
# Script entry point: build the datasets only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()