-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathstyle_transfer.py
78 lines (60 loc) · 2.81 KB
/
style_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import torch
from generator import Generator
from utils import image_preprocessing, vgg_normalization
class StyleTransfer(Generator):
    """Style transfer model.

    Extends ``Generator`` with a content image: during synthesis, the
    features transported onto the source features are blended towards the
    content image's features before being decoded back to image space.
    """

    def __init__(self, style_image, content_image, observed_layers, n_bins=128):
        """
        :param style_image: style image, forwarded to ``Generator.__init__``
            (presumably what ``Generator`` builds its ``source_*`` tensors
            from -- not visible here)
        :param content_image: content image
        :type content_image: PIL Image object
        :param observed_layers: forwarded to ``Generator.__init__``;
            ``transfer`` later reads ``self.observed_layers`` as a mapping
            from layer name to a dict holding a ``'decoder'`` entry
        :param n_bins: forwarded unchanged to ``Generator.__init__``,
            defaults to 128
        :type n_bins: int, optional
        """
        super().__init__(style_image, observed_layers, n_bins=n_bins)
        # content image as a tensor, before VGG normalization
        self.content_tensor = image_preprocessing(content_image)
        # VGG-normalized content image with a leading batch axis; this is
        # what the encoder is run on in ``transfer``
        self.normalized_content_batch = vgg_normalization(
            self.content_tensor).unsqueeze(0)
        # not read in this class (may be used elsewhere)
        self.content_batch = self.content_tensor.unsqueeze(0)

    def _encode_observed(self, batch):
        """Encode ``batch`` once and snapshot every observed layer's features.

        The encoder publishes its activations into ``self.encoder_layers``
        as a side effect of the forward call. The snapshots are detached
        from the autograd graph because nothing here back-propagates
        through them.

        :param batch: normalized image batch accepted by ``self.encoder``
        :return: layer name -> feature tensor
        :rtype: dict
        """
        self.encoder(batch)
        return {name: self.encoder_layers[name].detach()
                for name in self.observed_layers}

    def transfer(self, n_passes=5, content_strength=0.5):
        """Style transfer.

        Starts from Gaussian noise shaped like ``self.source_tensor``. In each
        global pass it visits every observed layer in order:

        1. encode the current image;
        2. map its features onto the source features with
           ``optimal_transport``;
        3. blend the result towards the content features;
        4. decode it back to image space with that layer's decoder.

        :param n_passes: number of global passes, defaults to 5
        :type n_passes: int, optional
        :param content_strength: content strength, defaults to 0.5. The
            transported features ``t`` become
            ``t + content_strength * (c - t)``, where ``c`` are the content
            features: 0 keeps ``t`` unchanged, 1 replaces it with ``c``.
        :type content_strength: float, optional
        :return: the generated image after each global pass, as numpy arrays
            with the tensor's first axis moved last (presumably H x W x C)
        :rtype: list
        """
        self.n_passes = n_passes
        pass_generated_images = []
        # initialize with noise
        self.target_tensor = torch.randn_like(self.source_tensor)
        # Source and content features never change during the transfer, so
        # encode each image once instead of at every layer of every pass.
        source_features = self._encode_observed(self.normalized_source_batch)
        content_features = self._encode_observed(self.normalized_content_batch)
        for global_pass in range(n_passes):
            print(f'global pass {global_pass}')
            for layer_name, layer_information in self.observed_layers.items():
                print(f'layer {layer_name}')
                source_layer = source_features[layer_name]
                content_layer = content_features[layer_name]
                # forward pass on the current target image
                target_batch = vgg_normalization(
                    self.target_tensor).unsqueeze(0)
                self.encoder(target_batch)
                target_layer = self.encoder_layers[layer_name]
                # transport the target features onto the source features
                target_layer = self.optimal_transport(
                    layer_name, source_layer.squeeze(), target_layer.squeeze())
                target_layer = target_layer.view_as(source_layer)
                # pull the transported features towards the content features
                target_layer += content_strength * \
                    (content_layer - target_layer)
                # Decode. Detach so the autograd graph does not keep growing
                # as the decoded image is fed back into the encoder.
                decoder = layer_information['decoder']
                self.target_tensor = decoder(target_layer).squeeze().detach()
            # .cpu() makes .numpy() work for GPU tensors. .copy() gives an
            # array that no longer shares memory with the tensor.
            generated_image = np.transpose(
                self.target_tensor.cpu().numpy(), (1, 2, 0)).copy()
            pass_generated_images.append(generated_image)
        return pass_generated_images