Spaces:
Runtime error
Runtime error
| from torch.utils.data import Dataset | |
| from torch.utils.data import DataLoader | |
| import os | |
| import json | |
| from PIL import Image | |
class MS_COCO_dataset(Dataset):
    """Image/caption dataset backed by a JSONL annotation file.

    Every non-blank line of the annotation file is parsed as one JSON object
    and kept in ``self.data``. ``__getitem__`` reads the ``image_name``,
    ``caption`` and ``image_id`` keys of the selected record.
    """

    def __init__(self, base_dir, annotation_file=None):
        """Read all annotation records into memory.

        Args:
            base_dir: Dataset root. Images are looked up under
                ``base_dir + '/images'``.
            annotation_file: Path suffix that is appended verbatim to
                ``base_dir`` (no separator is inserted) to locate the JSONL
                file. Required despite the ``None`` default.

        Raises:
            TypeError: If ``annotation_file`` is not given.
            OSError: If the annotation file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        if annotation_file is None:
            # Previously this surfaced as an opaque "can only concatenate str"
            # TypeError from the path concatenation below.
            raise TypeError(
                "MS_COCO_dataset requires annotation_file "
                "(a path suffix appended to base_dir)"
            )
        self.img_dir = base_dir + '/images'
        self.annotation_file = base_dir + annotation_file
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.annotation_file, 'r', encoding='utf-8') as file:
            # Skip blank lines (e.g. a trailing newline) instead of crashing
            # in json.loads.
            self.data = [json.loads(line) for line in file if line.strip()]

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Returns:
            ``{"id": <image_id>, "image": <PIL image>, "caption": <caption>}``.
        """
        # Extract the relevant info from the JSONL entry
        record = self.data[idx]
        img_name = os.path.join(self.img_dir, str(record['image_name']))
        caption = record['caption']
        sample_id = record['image_id']
        # Image.open is lazy and keeps the file open until the pixel data is
        # read; load() decodes now so the underlying file handle can be
        # released instead of staying open for every returned sample.
        img = Image.open(img_name)
        img.load()
        return {
            "id": sample_id,
            "image": img,
            "caption": caption,
        }
class COCO_CF_dataset(Dataset):
    """Paired image/caption dataset read from ``examples.jsonl``.

    Every non-blank line of ``base_dir + "/examples.jsonl"`` is parsed as one
    JSON object and kept in ``self.data``. ``__getitem__`` reads the ``id``,
    ``image_0``, ``image_1``, ``caption_0`` and ``caption_1`` keys of the
    selected record.
    """

    def __init__(self, base_dir):
        """Read all records into memory.

        Args:
            base_dir: Dataset root containing ``examples.jsonl``; images are
                looked up under ``base_dir + '/images'``.

        Raises:
            OSError: If the annotation file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.img_dir = base_dir + '/images'
        self.annotation_file = base_dir + "/examples.jsonl"
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.annotation_file, 'r', encoding='utf-8') as file:
            # Skip blank lines (e.g. a trailing newline) instead of crashing
            # in json.loads.
            self.data = [json.loads(line) for line in file if line.strip()]

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the image/caption pair at ``idx`` as a dict.

        Image files are resolved as ``<img_dir>/<image_N>.jpg``.

        Returns:
            A dict with keys ``id``, ``caption_0``, ``caption_1``,
            ``image_0`` and ``image_1`` (the last two are PIL images).
        """
        # Extract the relevant info from the JSONL entry
        record = self.data[idx]
        img_0_name = os.path.join(self.img_dir, f"{record['image_0']}.jpg")
        img_1_name = os.path.join(self.img_dir, f"{record['image_1']}.jpg")
        caption_0 = record['caption_0']
        caption_1 = record['caption_1']
        sample_id = record['id']
        # Image.open is lazy and keeps the file open until the pixel data is
        # read; load() decodes now so the underlying file handles can be
        # released instead of staying open for every returned sample.
        img_0 = Image.open(img_0_name)
        img_0.load()
        img_1 = Image.open(img_1_name)
        img_1.load()
        return {
            "id": sample_id,
            "caption_0": caption_0,
            "caption_1": caption_1,
            "image_0": img_0,
            "image_1": img_1,
        }
def custom_collate_fn(batch):
    """Collate a list of sample dicts into a dict of lists.

    Values are gathered as plain Python lists without any stacking or
    conversion, so items such as PIL images pass through unchanged.

    Args:
        batch: Sequence of dicts. The keys of the first dict decide which
            keys are collated; every other dict must provide them as well.

    Returns:
        A dict mapping each key to the list of that key's values in batch
        order, or an empty dict for an empty batch.

    Raises:
        KeyError: If a later sample lacks a key present in the first one.
    """
    if not batch:
        # Nothing to collate; avoid IndexError on batch[0].
        return {}
    return {key: [item[key] for item in batch] for key in batch[0]}
if __name__ == "__main__":
    # Smoke test: build the MS-COCO dataset and print the first batch.
    import argparse

    parser = argparse.ArgumentParser(
        description="Print the first batch of MS_COCO_dataset."
    )
    # MS_COCO_dataset cannot be constructed without an annotation file
    # (it concatenates base_dir and annotation_file), so it must be supplied.
    parser.add_argument(
        "annotation_file",
        help="path suffix appended verbatim to the base directory "
             "to locate the JSONL annotation file",
    )
    parser.add_argument(
        "--base-dir",
        default='/home/htc/kchitranshi/SCRATCH/MS_COCO/',
        help="dataset root directory (default: %(default)s)",
    )
    args = parser.parse_args()

    data = MS_COCO_dataset(base_dir=args.base_dir,
                           annotation_file=args.annotation_file)
    data_loader = DataLoader(data, batch_size=10, collate_fn=custom_collate_fn)
    for batch in data_loader:
        print(batch)
        break