updated to check for existence of style transferred image
Browse files
test.py
CHANGED
|
@@ -19,6 +19,7 @@ parser.add_argument('--style_dir', type=str, help='Content image folder path')
|
|
| 19 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
| 20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
| 21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
|
|
| 22 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
| 23 |
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
| 24 |
args = parser.parse_args()
|
|
@@ -71,13 +72,12 @@ def main():
|
|
| 71 |
assert len(style_pths) > 0, 'Failed to load style image'
|
| 72 |
|
| 73 |
# Prepare directory for saving results
|
| 74 |
-
|
| 75 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 76 |
|
| 77 |
# Load AdaIN model
|
| 78 |
-
vgg = torch.load('vgg_normalized.pth')
|
| 79 |
model = AdaINNet(vgg).to(device)
|
| 80 |
-
model.decoder.load_state_dict(torch.load(args.decoder_weight))
|
| 81 |
model.eval()
|
| 82 |
|
| 83 |
# Prepare image transform
|
|
@@ -95,14 +95,27 @@ def main():
|
|
| 95 |
|
| 96 |
for content_pth in content_pths:
|
| 97 |
content_img = Image.open(content_pth)
|
|
|
|
|
|
|
| 98 |
content_tensor = t(content_img).unsqueeze(0).to(device)
|
| 99 |
|
| 100 |
if args.grid_pth:
|
| 101 |
imgs.append(content_img)
|
| 102 |
|
| 103 |
for style_pth in style_pths:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
style_tensor = t(
|
| 106 |
|
| 107 |
# Linear Histogram Matching if needed
|
| 108 |
if args.color_control:
|
|
@@ -122,9 +135,6 @@ def main():
|
|
| 122 |
times.append(toc-tic)
|
| 123 |
|
| 124 |
# Save image
|
| 125 |
-
out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha)
|
| 126 |
-
if args.color_control: out_pth += '_colorcontrol'
|
| 127 |
-
out_pth += content_pth.suffix
|
| 128 |
save_image(out_tensor, out_pth)
|
| 129 |
|
| 130 |
if args.grid_pth:
|
|
|
|
| 19 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
| 20 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
| 21 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
| 22 |
+
parser.add_argument('--output_dir', type=str, default="results")
|
| 23 |
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
|
| 24 |
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
| 25 |
args = parser.parse_args()
|
|
|
|
| 72 |
assert len(style_pths) > 0, 'Failed to load style image'
|
| 73 |
|
| 74 |
# Prepare directory for saving results
|
| 75 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
| 76 |
|
| 77 |
# Load AdaIN model
|
| 78 |
+
vgg = torch.load('vgg_normalized.pth', weights_only=False)
|
| 79 |
model = AdaINNet(vgg).to(device)
|
| 80 |
+
model.decoder.load_state_dict(torch.load(args.decoder_weight, weights_only=False))
|
| 81 |
model.eval()
|
| 82 |
|
| 83 |
# Prepare image transform
|
|
|
|
| 95 |
|
| 96 |
for content_pth in content_pths:
|
| 97 |
content_img = Image.open(content_pth)
|
| 98 |
+
if not content_img.mode == "RGB":
|
| 99 |
+
content_img = content_img.convert("RGB")
|
| 100 |
content_tensor = t(content_img).unsqueeze(0).to(device)
|
| 101 |
|
| 102 |
if args.grid_pth:
|
| 103 |
imgs.append(content_img)
|
| 104 |
|
| 105 |
for style_pth in style_pths:
|
| 106 |
+
|
| 107 |
+
# check if style transferred image exists already
|
| 108 |
+
out_pth = os.path.join(args.output_dir, content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix)
|
| 109 |
+
if os.path.isfile(out_pth):
|
| 110 |
+
print("Skipping existing file")
|
| 111 |
+
continue
|
| 112 |
+
|
| 113 |
+
style_img = Image.open(style_pth)
|
| 114 |
+
|
| 115 |
+
if not style_img.mode == "RGB":
|
| 116 |
+
style_img = style_img.convert("RGB")
|
| 117 |
|
| 118 |
+
style_tensor = t(style_img).unsqueeze(0).to(device)
|
| 119 |
|
| 120 |
# Linear Histogram Matching if needed
|
| 121 |
if args.color_control:
|
|
|
|
| 135 |
times.append(toc-tic)
|
| 136 |
|
| 137 |
# Save image
|
|
|
|
|
|
|
|
|
|
| 138 |
save_image(out_tensor, out_pth)
|
| 139 |
|
| 140 |
if args.grid_pth:
|