Spaces:
Build error
Build error
| import os | |
| import torch | |
| import numpy as np | |
| from torchvision import transforms | |
| from PIL import Image | |
| import time | |
| import torchvision | |
| import argparse | |
| from models.SCET import SCET | |
def inference_img(img_path, Net):
    """Run ``Net`` on a single image file and time the forward call.

    Args:
        img_path: Path of the image to read. The image is converted to RGB
            before being turned into a tensor.
        Net: Callable applied to the batched image tensor (in this script,
            the loaded model).

    Returns:
        A ``(restored, seconds)`` tuple: whatever ``Net`` returned, and the
        elapsed time in seconds of the ``Net`` call alone (image loading and
        tensor conversion are not included in the timing).
    """
    # Close the underlying file deterministically; convert() hands back an
    # independent in-memory image, so it stays usable after the handle closes.
    with Image.open(img_path) as source_image:
        low_image = source_image.convert('RGB')

    to_tensor = transforms.ToTensor()
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension before calling Net.
        batch = to_tensor(low_image).unsqueeze(0)
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        restored2 = Net(batch)
        elapsed = time.perf_counter() - start
    return restored2, elapsed
if __name__ == '__main__':
    # Command-line entry point: load an SCET checkpoint, run it on one image
    # and write the result to <save_path>/output.png.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/SRx4.pth', help='Path of the checkpoint')
    parser.add_argument('--scale', type=int, default=4, help='scale factor')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories and tolerates an
    # already-existing output directory.
    os.makedirs(opt.save_path, exist_ok=True)

    # The first SCET constructor argument is 63 for scale 3 and 64 otherwise.
    # NOTE(review): presumably a channel width tied to the scale factor --
    # confirm against models/SCET.py.
    first_dim = 63 if opt.scale == 3 else 64
    Net = SCET(first_dim, 128, opt.scale)

    # NOTE(review): torch.load unpickles the file; only point --pk_path at
    # checkpoints from a trusted source.
    Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)

    # Join with a real path separator so the file lands inside save_path even
    # when the argument has no trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))