add linear histogram matching to test_video.py
Browse files- test_video.py +11 -3
test_video.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch
|
|
| 4 |
from pathlib import Path
|
| 5 |
from AdaIN import AdaINNet
|
| 6 |
from PIL import Image
|
| 7 |
-
from utils import transform, adaptive_instance_normalization, Range
|
| 8 |
import cv2
|
| 9 |
import imageio
|
| 10 |
import numpy as np
|
|
@@ -17,6 +17,7 @@ parser.add_argument('--style_image', type=str, required=True, help='Style image
|
|
| 17 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
| 18 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
| 19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
|
|
|
| 20 |
args = parser.parse_args()
|
| 21 |
|
| 22 |
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
|
@@ -67,8 +68,10 @@ def main():
|
|
| 67 |
# Prepare output video writer
|
| 68 |
out_dir = './results_video/'
|
| 69 |
os.makedirs(out_dir, exist_ok=True)
|
| 70 |
-
out_pth =
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
writer = imageio.get_writer(out_pth, mode='I', fps=fps)
|
| 73 |
|
| 74 |
# Load AdaIN model
|
|
@@ -82,6 +85,7 @@ def main():
|
|
| 82 |
style_tensor = t(style_image).unsqueeze(0).to(device)
|
| 83 |
|
| 84 |
|
|
|
|
| 85 |
while content_video.isOpened():
|
| 86 |
ret, content_image = content_video.read()
|
| 87 |
# Failed to read a frame
|
|
@@ -90,6 +94,10 @@ def main():
|
|
| 90 |
|
| 91 |
content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
with torch.no_grad():
|
| 94 |
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
|
| 95 |
, model.decoder, args.alpha).cpu().detach().numpy()
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from AdaIN import AdaINNet
|
| 6 |
from PIL import Image
|
| 7 |
+
from utils import transform, adaptive_instance_normalization,linear_histogram_matching, Range
|
| 8 |
import cv2
|
| 9 |
import imageio
|
| 10 |
import numpy as np
|
|
|
|
| 17 |
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
|
| 18 |
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
|
| 19 |
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
|
| 20 |
+
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
|
| 21 |
args = parser.parse_args()
|
| 22 |
|
| 23 |
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
|
|
|
|
| 68 |
# Prepare output video writer
|
| 69 |
out_dir = './results_video/'
|
| 70 |
os.makedirs(out_dir, exist_ok=True)
|
| 71 |
+
out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
|
| 72 |
+
if args.color_control: out_pth += '_colorcontrol'
|
| 73 |
+
out_pth += content_video_pth.suffix
|
| 74 |
+
out_pth = Path(out_pth)
|
| 75 |
writer = imageio.get_writer(out_pth, mode='I', fps=fps)
|
| 76 |
|
| 77 |
# Load AdaIN model
|
|
|
|
| 85 |
style_tensor = t(style_image).unsqueeze(0).to(device)
|
| 86 |
|
| 87 |
|
| 88 |
+
|
| 89 |
while content_video.isOpened():
|
| 90 |
ret, content_image = content_video.read()
|
| 91 |
# Failed to read a frame
|
|
|
|
| 94 |
|
| 95 |
content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
|
| 96 |
|
| 97 |
+
# Linear Histogram Matching if needed
|
| 98 |
+
if args.color_control:
|
| 99 |
+
style_tensor = linear_histogram_matching(content_tensor,style_tensor)
|
| 100 |
+
|
| 101 |
with torch.no_grad():
|
| 102 |
out_tensor = style_transfer(content_tensor, style_tensor, model.encoder
|
| 103 |
, model.decoder, args.alpha).cpu().detach().numpy()
|