move linear histogram matching to utils.py
- .gitignore +5 -0
- test.py +1 -28
- utils.py +27 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+#Ignore __pycache__
+/__pycache__/
+
+#Ignore results
+/results/
test.py
CHANGED
@@ -8,7 +8,7 @@ from AdaIN import AdaINNet
 from PIL import Image
 from torchvision.utils import save_image
 from torchvision.transforms import ToPILImage
-from utils import adaptive_instance_normalization, grid_image, transform, Range
+from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
 from glob import glob
 
 parser = argparse.ArgumentParser()
@@ -55,33 +55,6 @@ def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
     return decoder(mix_enc)
 
 
-def linear_histogram_matching(content_tensor, style_tensor):
-    """
-    Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
-
-    Args:
-        content_tensor (torch.FloatTensor): Content image
-        style_tensor (torch.FloatTensor): Style Image
-
-    Return:
-        style_tensor (torch.FloatTensor): histogram matched Style Image
-    """
-    #for batch
-    for b in range(len(content_tensor)):
-        std_ct = []
-        std_st = []
-        mean_ct = []
-        mean_st = []
-        #for channel
-        for c in range(len(content_tensor[b])):
-            std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
-            mean_ct.append(torch.mean(content_tensor[b][c]))
-            std_st.append(torch.var(style_tensor[b][c],unbiased = False))
-            mean_st.append(torch.mean(style_tensor[b][c]))
-            style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
-    return style_tensor
-
-
 def main():
     # Read content images and style images
     if args.content_image:
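For context, a minimal usage sketch (not part of this commit): after this change, test.py can call the helper it now imports from utils before running style transfer. The tensor and model names below are illustrative assumptions borrowed from the style_transfer signature shown in the hunk header above.

# Hypothetical call site in test.py; every name other than
# linear_histogram_matching and style_transfer is an illustrative assumption.
style_tensor = linear_histogram_matching(content_tensor, style_tensor)
out_tensor = style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0)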
utils.py
CHANGED
@@ -74,6 +74,33 @@ def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
     plt.savefig(save_pth)
 
 
+def linear_histogram_matching(content_tensor, style_tensor):
+    """
+    Given content_tensor and style_tensor, transform style_tensor histogram to that of content_tensor.
+
+    Args:
+        content_tensor (torch.FloatTensor): Content image
+        style_tensor (torch.FloatTensor): Style Image
+
+    Return:
+        style_tensor (torch.FloatTensor): histogram matched Style Image
+    """
+    #for batch
+    for b in range(len(content_tensor)):
+        std_ct = []
+        std_st = []
+        mean_ct = []
+        mean_st = []
+        #for channel
+        for c in range(len(content_tensor[b])):
+            std_ct.append(torch.var(content_tensor[b][c],unbiased = False))
+            mean_ct.append(torch.mean(content_tensor[b][c]))
+            std_st.append(torch.var(style_tensor[b][c],unbiased = False))
+            mean_st.append(torch.mean(style_tensor[b][c]))
+            style_tensor[b][c] = (style_tensor[b][c] - mean_st[c]) * std_ct[c] / std_st[c] + mean_ct[c]
+    return style_tensor
+
+
 class TrainSet(Dataset):
     """
     Build Training dataset
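The moved function loops over batch and channel; a vectorized equivalent is sketched below as a reference point (an assumption, not part of the commit; the function name and the (B, C, H, W) layout are illustrative). Note that torch.var(..., unbiased=False) returns a variance, so despite the std_* variable names the scaling factor in the loop is a ratio of variances; the sketch reproduces that behavior rather than switching to standard deviations, and it returns a new tensor instead of modifying style_tensor in place.

import torch

def linear_histogram_matching_vectorized(content_tensor, style_tensor):
    # Per-channel mean and biased variance over the spatial dimensions.
    mean_c = content_tensor.mean(dim=(2, 3), keepdim=True)
    mean_s = style_tensor.mean(dim=(2, 3), keepdim=True)
    var_c = content_tensor.var(dim=(2, 3), unbiased=False, keepdim=True)
    var_s = style_tensor.var(dim=(2, 3), unbiased=False, keepdim=True)
    # Same scaling as the loop above: shift the style statistics toward the content statistics.
    return (style_tensor - mean_s) * var_c / var_s + mean_c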