| from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred | |
| from torch.utils.data import DataLoader | |
# Registry used by ``data_provider``: maps the value of ``args.data`` to the
# Dataset class that should be instantiated for it. The two 'ETTh*' names
# share one class, as do the two 'ETTm*' names. ``Dataset_Pred`` is
# deliberately absent; ``data_provider`` selects it directly for flag 'pred'.
data_dict = {
    **dict.fromkeys(('ETTh1', 'ETTh2'), Dataset_ETT_hour),
    **dict.fromkeys(('ETTm1', 'ETTm2'), Dataset_ETT_minute),
    'custom': Dataset_Custom,
}
def data_provider(args, flag):
    """Build the dataset and its DataLoader for one split.

    Args:
        args: Experiment configuration object. Attributes read here:
            ``data``, ``embed``, ``train_only``, ``batch_size``, ``freq``,
            ``root_path``, ``data_path``, ``seq_len``, ``label_len``,
            ``pred_len``, ``features``, ``target`` and ``num_workers``.
        flag: Split name, also forwarded to the dataset constructor.
            ``'test'``: unshuffled loader that keeps the final partial batch,
            batch size ``args.batch_size``.
            ``'pred'``: always uses ``Dataset_Pred``, unshuffled, keeps the
            final partial batch, batch size 1.
            Any other value (e.g. ``'train'``, ``'val'``): shuffled loader
            that drops the final incomplete batch, batch size
            ``args.batch_size``.

    Returns:
        A ``(data_set, data_loader)`` tuple.

    Raises:
        KeyError: if ``flag`` is not ``'pred'`` and ``args.data`` is not a
            key of ``data_dict``.
    """
    # timeenc is 1 only for the 'timeF' embedding, 0 otherwise; its meaning
    # is defined by the dataset classes.
    timeenc = 1 if args.embed == 'timeF' else 0

    if flag == 'pred':
        # Prediction always uses the dedicated dataset class, so args.data
        # does not need to be registered in data_dict for this split.
        Data = Dataset_Pred
        shuffle_flag = False
        drop_last = False
        batch_size = 1
    else:
        try:
            Data = data_dict[args.data]
        except KeyError:
            raise KeyError(
                'Unknown dataset {!r}; expected one of {}'.format(
                    args.data, sorted(data_dict)
                )
            ) from None
        # Only the training-style splits are shuffled and drop the last
        # incomplete batch; 'test' keeps every sample in order.
        is_test = flag == 'test'
        shuffle_flag = not is_test
        drop_last = not is_test
        batch_size = args.batch_size

    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=args.freq,
        train_only=args.train_only,
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader