import ctypes
import json
import multiprocessing as mp
import os
import pickle
import random
import time
from os.path import isfile, join
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import (
    CenterCrop,
    Compose,
    GaussianBlur,
    RandomCrop,
    RandomHorizontalFlip,
    ToTensor,
)
from tqdm import tqdm

def normalize(lat, lon):
    """Wraps latitude into [-90, 90] and longitude into [-180, 180]."""
    lat = (lat + 90) % 360 - 90
    if lat > 90:
        lat = 180 - lat
        lon += 180
    lon = (lon + 180) % 360 - 180
    return lat, lon
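
# Examples (illustrative):
#   normalize(0, 190) -> (0, -170)   longitude wrapped into [-180, 180]
#   normalize(91, 0)  -> (89, -180)  crossing the pole flips the longitude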

def collate_fn(batch):
    """Collate function for the dataloader.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    keys = list(batch[0].keys())
    if "weight" in keys:
        keys.remove("weight")
    output = {}
    # List-valued fields are collected as plain Python lists.
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "text",
    ]:
        if key in keys:
            output[key] = [x[key] for x in batch]
            keys.remove(key)
    # PIL images cannot be stacked; keep them as a list.
    if "img" in keys and isinstance(batch[0]["img"], Image.Image):
        output["img"] = [x["img"] for x in batch]
        keys.remove("img")
    # Remaining tensor fields are stacked along a new batch dimension.
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([x[key] for x in batch])
    return output
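
# Illustrative behavior: given batch = [{"idx": 0, "gps": t0}, {"idx": 1, "gps": t1}]
# with t0, t1 tensors of shape (2,), collate_fn returns
# {"idx": [0, 1], "gps": <tensor of shape (2, 2)>}.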

def collate_fn_streetclip(batch):
    """Collate function for StreetCLIP: images are kept as a list of PIL
    images instead of being stacked into a tensor.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    keys = list(batch[0].keys())
    if "weight" in keys:
        keys.remove("weight")
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "img",
        "text",
    ]:
        if key in keys:
            output[key] = [x[key] for x in batch]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([x[key] for x in batch])
    return output

def collate_fn_density(batch):
    """Density-aware collate function: resamples the batch (with replacement)
    according to the per-sample inverse-density weights.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "weight" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    keys = list(batch[0].keys())
    if "weight" in keys:
        keys.remove("weight")
    # Sample indices based on the weights.
    weights = np.array([x["weight"] for x in batch])
    normalized_weights = weights / np.sum(weights)
    sampled_indices = np.random.choice(
        len(batch), size=len(batch), p=normalized_weights, replace=True
    )
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "text",
    ]:
        if key in keys:
            output[key] = [batch[i][key] for i in sampled_indices]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([batch[i][key] for i in sampled_indices])
    return output
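
# Toy illustration of the resampling above: with three samples whose weights
# are [0.1, 0.1, 0.8], np.random.choice(3, size=3, p=[0.1, 0.1, 0.8], replace=True)
# returns e.g. array([2, 2, 0]); the heavily weighted (sparsely photographed)
# sample tends to appear several times per batch.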

def collate_fn_streetclip_density(batch):
    """Density-aware collate function for StreetCLIP: resamples the batch
    according to the weights and keeps images as a list of PIL images.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "weight" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    keys = list(batch[0].keys())
    if "weight" in keys:
        keys.remove("weight")
    # Sample indices based on the weights.
    weights = np.array([x["weight"] for x in batch])
    normalized_weights = weights / np.sum(weights)
    sampled_indices = np.random.choice(
        len(batch), size=len(batch), p=normalized_weights, replace=True
    )
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "img",
        "text",
    ]:
        if key in keys:
            output[key] = [batch[i][key] for i in sampled_indices]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([batch[i][key] for i in sampled_indices])
    return output

def collate_fn_contrastive(batch):
    """Collate function for contrastive training: same as `collate_fn`, plus
    a stacked tensor of positive images under the key "pos_img".

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "pos_img" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx", "pos_img" and optionally "label"
    """
    output = collate_fn(batch)
    output["pos_img"] = torch.stack([x["pos_img"] for x in batch])
    return output

def collate_fn_contrastive_density(batch):
    """Density-aware collate function for contrastive training: resamples the
    batch according to the weights; "pos_img" is stacked with the other
    tensor fields.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "pos_img", "weight" and optionally "label"
    Returns:
        dict: dictionary with keys "img", "gps", "idx" and optionally "label"
    """
    keys = list(batch[0].keys())
    if "weight" in keys:
        keys.remove("weight")
    # Sample indices based on the weights.
    weights = np.array([x["weight"] for x in batch])
    normalized_weights = weights / np.sum(weights)
    sampled_indices = np.random.choice(
        len(batch), size=len(batch), p=normalized_weights, replace=True
    )
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
    ]:
        if key in keys:
            output[key] = [batch[i][key] for i in sampled_indices]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([batch[i][key] for i in sampled_indices])
    return output

class iNaturalist(Dataset):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        output_type="image",
        embedding_name="dinov2",
    ):
        super().__init__()
        self.split = split
        with open(Path(path) / f"{split}.json", "r") as f:
            self.metadata = json.load(f)
        # Keep only geotagged images.
        self.metadata = [
            datapoint
            for datapoint in self.metadata["images"]
            if "latitude" in datapoint and datapoint["latitude"] is not None
        ]
        self.path = path
        self.transforms = transforms
        self.output_type = output_type
        self.embedding_name = embedding_name
        self.collate_fn = collate_fn

    def __getitem__(self, i):
        output = {}
        if "image" in self.output_type:
            image_path = Path(self.path) / "images" / self.metadata[i]["file_name"]
            output["img"] = self.transforms(Image.open(image_path))
        if "emb" in self.output_type:
            emb_path = (
                Path(self.path)
                / "embeddings"
                / self.embedding_name
                / self.metadata[i]["file_name"].replace(".jpg", ".npy")
            )
            output["emb"] = torch.tensor(np.load(emb_path))
        lat, lon = normalize(
            self.metadata[i]["latitude"], self.metadata[i]["longitude"]
        )
        output["gps"] = torch.tensor(
            [np.radians(lat), np.radians(lon)], dtype=torch.float
        )
        output["idx"] = i
        output["img_idx"] = self.metadata[i]["id"]
        return output

    def __len__(self):
        return len(self.metadata)

class OSV5M(Dataset):
    csv_dtype = {"category": str, "country": str, "city": str}  # Don't remove.

    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=[],
        is_baseline=False,
        areas=["country", "region", "sub-region", "city"],
        streetclip=False,
        suff="",
        blur=False,
        output_type="image",
        embedding_name="dinov2",
    ):
        """Initializes the dataset.

        Args:
            path (str): path to the dataset
            transforms (torchvision.transforms): transforms to apply to the images
            split (str): split to use (train, val, test)
            class_name (str): category to use (e.g. "city")
            aux_data (list of str): auxiliary data to use
            is_baseline (bool): if True, also load the test categories
            areas (list of str): administrative areas on which accuracy is reported
            streetclip (bool): if the model is StreetCLIP, do not apply transforms
            suff (str): suffix of the train csv
            blur (bool): whether to blur the bottom of the images
            output_type (str): type of output (image or emb)
            embedding_name (str): name of the precomputed embeddings to load
        """
        self.suff = suff
        self.path = path
        self.aux = len(aux_data) > 0
        self.aux_list = aux_data
        self.split = split
        self.df = self.load_split(split)
        if split == "select":
            # The "select" split reuses the test images on disk.
            split = "test"
        self.split = split
        if "image" in output_type:
            self.image_data_folder = join(
                path,
                "images",
                ("train" if split == "val" else split),
            )
            self.image_dict_names = {}
            for root, _, files in os.walk(self.image_data_folder):
                for file in files:
                    self.image_dict_names[file] = os.path.join(root, file)
        if "emb" in output_type:
            self.emb_data_folder = join(
                path,
                "embeddings",
                embedding_name,
                ("train" if split == "val" else split),
            )
            self.emb_dict_names = {}
            for root, _, files in os.walk(self.emb_data_folder):
                for file in files:
                    self.emb_dict_names[file] = os.path.join(root, file)
        self.output_type = output_type
        self.is_baseline = is_baseline
        if self.aux:
            self.aux_data = {}
            for col in self.aux_list:
                if col in ["land_cover", "climate", "soil"]:
                    self.aux_data[col] = pd.get_dummies(self.df[col], dtype=float)
                    if col == "climate":
                        # Pad the climate one-hot encoding to a fixed set of 31
                        # columns, then drop column 20 to match the expected order.
                        for i in range(31):
                            if i not in list(self.aux_data[col].columns):
                                self.aux_data[col][i] = 0
                        desired_order = [i for i in range(31) if i != 20]
                        self.aux_data[col] = self.aux_data[col][desired_order]
                else:
                    self.aux_data[col] = self.df[col].apply(lambda x: [x])
        self.areas = ["_".join(["unique", area]) for area in areas]
        if class_name is None or "quadtree" in class_name:
            self.class_name = class_name
        else:
            self.class_name = "_".join(["unique", class_name])
        ex = self.extract_classes(self.class_name)
        self.df = self.df[
            ["id", "latitude", "longitude", "weight"] + self.areas + ex
        ].fillna("NaN")
        if self.class_name in self.areas:
            # Avoid a duplicate column name when the classification target is
            # also one of the accuracy areas.
            self.df.columns = list(self.df.columns)[:-1] + [self.class_name + "_2"]
        self.transforms = transforms
        self.collate_fn = collate_fn
        self.collate_fn_density = collate_fn_density
        self.blur = blur
        self.streetclip = streetclip
        if self.streetclip:
            self.collate_fn = collate_fn_streetclip
            self.collate_fn_density = collate_fn_streetclip_density
    @staticmethod
    def compute_density_weights(df, num_bins=100, power=0.75):
        """Adds a normalized 'weight' column inversely proportional to the
        local image density, estimated on a 2D longitude/latitude histogram."""
        longitude = df["longitude"].values
        latitude = df["latitude"].values
        # Create bins.
        lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins)
        lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins)
        # Compute density per bin.
        hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins])
        # Assign each sample to its bin (clipped so boundary values stay in range).
        lon_bin = np.clip(np.digitize(longitude, lon_bins) - 1, 0, num_bins - 2)
        lat_bin = np.clip(np.digitize(latitude, lat_bins) - 1, 0, num_bins - 2)
        weights = 1.0 / np.power(hist[lon_bin, lat_bin], power)
        df["weight"] = weights / np.sum(weights)

    def load_split(self, split):
        """Returns the dataframe for the given split, with density weights."""
        start_time = time.time()
        if split in ("test", "select"):
            df = pd.read_csv(join(self.path, f"{split}.csv"), dtype=self.csv_dtype)
            self.compute_density_weights(df)
            return df
        if len(self.suff) == 0:
            df = pd.read_csv(join(self.path, "train.csv"), dtype=self.csv_dtype)
        else:
            df = pd.read_csv(
                join(self.path, "train" + "_" + self.suff + ".csv"),
                dtype=self.csv_dtype,
            )
        self.compute_density_weights(df)
        # Hold out 10% of train (sampled with the density weights) as val.
        test_df = df.sample(
            n=int(0.1 * len(df)),
            weights=df["weight"],
            replace=False,
            random_state=42,
        )
        end_time = time.time()
        print(f"Loading {split} dataset took {(end_time - start_time):.2f} seconds")
        if split == "val":
            return test_df
        return df.drop(test_df.index)
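    # Weighting sketch (illustrative): a sample whose (lon, lat) bin contains
    # c images gets unnormalized weight 1 / c**0.75, so dense regions are
    # down-weighted sublinearly. For bin counts [1, 10, 1000] the weights are
    # roughly [1.0, 0.178, 0.0056] before normalization.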
    def extract_classes(self, tag=None):
        """Extracts the categories from the dataset."""
        if tag is None:
            self.has_labels = False
            return []
        splits = ["train", "test"] if self.is_baseline else ["train"]
        print(f"Loading categories from {splits}")
        # Concatenate all categories from the relevant splits to find the unique ones.
        self.categories = sorted(
            pd.concat(
                [pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits]
            )
            .fillna("NaN")
            .unique()
            .tolist()
        )
        if "NaN" in self.categories:
            self.categories.remove("NaN")
            if self.split != "test":
                self.df = self.df.dropna(subset=[tag])
        # Compute the total number of categories - this name is fixed and will
        # be used as a lookup during init.
        self.num_classes = len(self.categories)
        # Create a mapping from category to index.
        self.category_to_index = {
            category: i for i, category in enumerate(self.categories)
        }
        self.has_labels = True
        return [tag]
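    # Example (illustrative): with categories ["France", "Japan", "Spain"],
    # `category_to_index` is {"France": 0, "Japan": 1, "Spain": 2} and
    # `num_classes` is 3; labels outside this mapping are encoded as -1 in
    # `__getitem__` below.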
    def __getitem__(self, i):
        """Returns an item from the dataset.

        Args:
            i (int): index of the item
        Returns:
            dict: dictionary with keys "img", "gps", "idx" and optionally "label"
        """
        x = list(self.df.iloc[i])  # id, latitude, longitude, weight, areas..., {category}
        output = {}
        if "image" in self.output_type:
            if self.streetclip:
                img = Image.open(self.image_dict_names[f"{int(x[0])}.jpg"])
            elif self.blur:
                img = transforms.ToTensor()(
                    Image.open(self.image_dict_names[f"{int(x[0])}.jpg"])
                )
                # Blur the bottom 14 pixel rows of the image.
                gaussian_blur = GaussianBlur(kernel_size=13, sigma=2.0)
                bottom_part = img[:, -14:, :].unsqueeze(0)
                img[:, -14:, :] = gaussian_blur(bottom_part).squeeze()
                img = self.transforms(transforms.ToPILImage()(img))
            else:
                img = self.transforms(
                    Image.open(self.image_dict_names[f"{int(x[0])}.jpg"])
                )
            output["img"] = img
        if "emb" in self.output_type:
            output["emb"] = torch.FloatTensor(
                np.load(self.emb_dict_names[f"{int(x[0])}.npy"])
            )
        lat, lon = normalize(x[1], x[2])
        gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)
        output.update(
            {
                "gps": gps,
                "idx": i,
                "img_idx": int(x[0]),
                "weight": x[3],
            }
        )
        for count, area in enumerate(self.areas):
            # 'unique_country': x[4], 'unique_region': x[5],
            # 'unique_sub-region': x[6], 'unique_city': x[7]
            output[area] = x[count + 4]
        if self.has_labels:
            if x[-1] in self.categories:
                output["label"] = torch.LongTensor(
                    [self.category_to_index[x[-1]]]
                ).squeeze(-1)
            else:
                # Unknown category: flag with -1.
                output["label"] = torch.LongTensor([-1]).squeeze(-1)
        if self.aux:
            for col in self.aux_list:
                output[col] = torch.FloatTensor(self.aux_data[col].iloc[i])
        return output

    def __len__(self):
        return len(self.df)

class ContrastiveOSV5M(OSV5M):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=[],
        class_name2=None,
        blur=False,
    ):
        """
        class_name2 (str): if not None, contrastive learning uses this class
            instead of the one specified for classification
        """
        super().__init__(
            path,
            transforms,
            split=split,
            class_name=class_name,
            aux_data=aux_data,
            blur=blur,
        )
        self.add_label = False
        if class_name2 is not None and split != "test" and split != "select":
            self.add_label = True
            self.class_name = class_name2
            self.extract_classes_contrastive(tag=class_name2)
        self.df = self.df.reset_index(drop=True)
        # Map each class value to the list of row indices that share it.
        self.dict_classes = {
            value: indices.tolist()
            for value, indices in self.df.groupby(self.class_name).groups.items()
        }
        self.collate_fn = collate_fn_contrastive
        self.random_crop = RandomCrop(224)  # used when no positive image is available
    def sample_positive(self, i):
        """Samples a positive image from the same class (e.g. same city or
        country) when one is available; otherwise returns a different random
        crop of the anchor image."""
        x = self.df.iloc[i]  # id, latitude, longitude, {category}
        class_name = x[self.class_name]
        # Copy the index list so the shared lookup table is not mutated.
        idxs = [j for j in self.dict_classes[class_name] if j != i]
        if len(idxs) > 0:
            idx = random.choice(idxs)
            x = self.df.iloc[idx]
            pos_img = self.transforms(
                Image.open(self.image_dict_names[f"{int(x['id'])}.jpg"])
            )
        else:
            pos_img = self.random_crop(
                self.transforms(
                    Image.open(self.image_dict_names[f"{int(x['id'])}.jpg"])
                )
            )
        return pos_img
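    # Illustrative: with dict_classes = {"Paris": [0, 3, 7]}, sample_positive(3)
    # picks the image of row 0 or 7; if "Paris" mapped to a single row, the
    # anchor image itself would be re-cropped instead.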
    def extract_classes_contrastive(self, tag=None):
        """Extracts the contrastive categories from the dataset."""
        if tag is None:
            self.has_labels = False
            return []
        splits = ["train", "test"] if self.is_baseline else ["train"]
        print(f"Loading categories from {splits}")
        # Concatenate all categories from the relevant splits to find the unique ones.
        categories = sorted(
            pd.concat(
                [pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits]
            )
            .fillna("NaN")
            .unique()
            .tolist()
        )
        # Create a mapping from category to index.
        self.contrastive_category_to_index = {
            category: i for i, category in enumerate(categories)
        }

    def __getitem__(self, i):
        output = super().__getitem__(i)
        output["pos_img"] = self.sample_positive(i)
        if self.add_label:
            output["label_contrastive"] = torch.LongTensor(
                [self.contrastive_category_to_index[self.df[self.class_name].iloc[i]]]
            ).squeeze(-1)
        return output

class TextContrastiveOSV5M(OSV5M):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=[],
        blur=False,
    ):
        super().__init__(
            path,
            transforms,
            split=split,
            class_name=class_name,
            aux_data=aux_data,
            blur=blur,
        )
        self.df = self.df.reset_index(drop=True)

    def get_text(self, i):
        """Builds a textual description of the sample's location from its
        city, sub-region, region and country, skipping missing levels."""
        x = self.df.iloc[i]
        names = [
            name.split("_")[-1]
            for name in [
                x["unique_city"],
                x["unique_sub-region"],
                x["unique_region"],
                x["unique_country"],
            ]
        ]
        pre = False
        sentence = "An image of "
        if names[0] != "NaN":
            sentence += "the city of " + names[0]
            pre = True
        if names[1] != "NaN":
            if pre:
                sentence += ", in "
            sentence += "the area of " + names[1]
            pre = True
        if names[2] != "NaN":
            if pre:
                sentence += ", in "
            sentence += "the region of " + names[2]
            pre = True
        if names[3] != "NaN":
            if pre:
                sentence += ", in "
            sentence += names[3]
        return sentence
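    # Example (illustrative): a sample with city "Paris", no sub-region,
    # region "Ile-de-France" and country "France" yields
    #   "An image of the city of Paris, in the region of Ile-de-France, in France"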
    def __getitem__(self, i):
        output = super().__getitem__(i)
        output["text"] = self.get_text(i)
        return output

class Baseline(Dataset):
    def __init__(
        self,
        path,
        which,
        transforms,
    ):
        """Initializes the dataset.

        Args:
            path (str): path to the dataset
            which (str): which baseline to use (im2gps, im2gps3k, yfcc4k)
            transforms (torchvision.transforms): transforms to apply to the images
        """
        baselines = {
            "im2gps": self.load_im2gps,
            "im2gps3k": self.load_im2gps,
            "yfcc4k": self.load_yfcc4k,
        }
        self.path = path
        self.samples = baselines[which]()
        self.transforms = transforms
        self.collate_fn = collate_fn
        self.class_name = which

    def load_im2gps(self):
        json_path = join(self.path, "info.json")
        with open(json_path) as f:
            data = json.load(f)
        samples = []
        for fname in os.listdir(join(self.path, "images")):
            if len(data[fname]):
                lat = float(data[fname][-4].replace("latitude: ", ""))
                lon = float(data[fname][-3].replace("longitude: ", ""))
                samples.append((fname, lat, lon))
        return samples

    def load_yfcc4k(self):
        samples = []
        with open(join(self.path, "info.txt")) as f:
            lines = f.readlines()
        for line in lines:
            x = line.split("\t")
            fname, lon, lat = x[1], x[12], x[13]
            samples.append((fname + ".jpg", float(lat), float(lon)))
        return samples

    def __getitem__(self, i):
        """Returns an item from the dataset.

        Args:
            i (int): index of the item
        Returns:
            dict: dictionary with keys "img", "gps" and "idx"
        """
        img_path, lat, lon = self.samples[i]
        img = self.transforms(
            Image.open(join(self.path, "images", img_path)).convert("RGB")
        )
        lat, lon = normalize(lat, lon)
        gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)
        return {
            "img": img,
            "gps": gps,
            "idx": i,
        }

    def __len__(self):
        return len(self.samples)


null_transform = lambda x: x  # identity transform
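
# Minimal usage sketch (illustrative; the dataset root is a placeholder and
# assumes the OSV-5M directory layout expected by OSV5M.__init__):
#
#   from torch.utils.data import DataLoader
#   tfm = Compose([CenterCrop(224), ToTensor()])
#   dataset = OSV5M("/path/to/osv5m", tfm, split="train", class_name="country")
#   loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
#   batch = next(iter(loader))  # keys include "img", "gps", "idx", "img_idx", "label"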