CycleGAN Implementation with TensorFlow 2.0

CycleGAN is a variant of the GAN model. In many cases we cannot obtain, or can only obtain with great difficulty, paired training data. The problem CycleGAN solves is, in the paper's own words, to "seek an algorithm that can learn to translate between domains without paired input-output examples".

The CycleGAN network structure is as follows:

It consists of two generators, G and F, and two discriminators, Dx and Dy. Generator G transforms an input image x into an image $\hat y$; generator F transforms an input image y into an image $\hat x$. Discriminator Dy distinguishes the images $\hat y$ produced by G from real images y; discriminator Dx distinguishes the images $\hat x$ produced by F from real images x.
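In the paper's notation, the two mappings and the cycle-consistency requirement can be summarized as:

$$G: X \rightarrow Y, \qquad F: Y \rightarrow X, \qquad F(G(x)) \approx x, \qquad G(F(y)) \approx y$$

Here X is the horse domain and Y the zebra domain in our experiment.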

The paper reproduced in this post: https://arxiv.org/abs/1703.10593

1. Dataset Preprocessing

Dataset used: https://people.eecs.berkeley.edu/%7Etaesung_park/CycleGAN/datasets/ (horse<-->zebra)

trainA contains horse photos and trainB contains zebra photos, each of size 256*256*3. Preprocessing is the same as in the previous posts: upscaling, random cropping, random mirroring, and normalization. The code:

import os
import random
import time

import matplotlib.pyplot as plt
import tensorflow as tf

def load_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)  # force 3 channels in case of grayscale JPEGs
    return image  # pixel values in [0, 255]

def random_crop(image):
    cropped_image = tf.image.random_crop(image, size=[256, 256, 3])
    return cropped_image

def random_jitter(image):
    # resize 256*256*3 --> 286*286*3
    image = tf.image.resize(image, size=[286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # randomly crop back to 256*256*3
    image = random_crop(image)
    # random mirroring
    image = tf.image.random_flip_left_right(image)
    return image

def normalize(image):
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1  # scale pixel values to [-1, 1]
    return image

def preprocess_train_image(image_path):
    image = load_image(image_path)
    image = random_jitter(image)
    image = normalize(image)
    return image

def preprocess_test_image(image_path):
    image = load_image(image_path)
    image = normalize(image)
    return image
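A quick sanity check (my own addition, not in the original post) that the pipeline produces what the models expect:

# Pick a random training image and verify shape and value range.
sample_path = os.path.join("dataset/horse2zebra/trainA",
                           random.choice(os.listdir("dataset/horse2zebra/trainA")))
sample = preprocess_train_image(sample_path)
print(sample.shape)  # (256, 256, 3)
print(float(tf.reduce_min(sample)), float(tf.reduce_max(sample)))  # within [-1, 1]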
2. Defining the Generator

The generator is the same U-Net used in pix2pix. The paper replaces batch normalization with instance normalization, but since we train with batch size = 1 (where batch normalization during training is equivalent to instance normalization), the pix2pix code is carried over without any changes. The downsample and upsample building blocks come from the earlier pix2pix post; a sketch of them is given first, followed by the generator code.
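For reference, a minimal sketch of those helpers, consistent with the pix2pix post and with how they are called below (if your downsample / upsample definitions differ, use those instead):

def downsample(filters, kernel_size, apply_batchnorm=True):
    # Conv2D with stride 2 halves the spatial resolution.
    initializer = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2D(filters, kernel_size, strides=2, padding='same',
                                     kernel_initializer=initializer, use_bias=False))
    if apply_batchnorm:
        block.add(tf.keras.layers.BatchNormalization())
    block.add(tf.keras.layers.LeakyReLU())
    return block

def upsample(filters, kernel_size, apply_dropout=False):
    # Conv2DTranspose with stride 2 doubles the spatial resolution.
    initializer = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()
    block.add(tf.keras.layers.Conv2DTranspose(filters, kernel_size, strides=2, padding='same',
                                              kernel_initializer=initializer, use_bias=False))
    block.add(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        block.add(tf.keras.layers.Dropout(0.5))
    block.add(tf.keras.layers.ReLU())
    return block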

def Generator():
    inputs = tf.keras.layers.Input(shape=(256, 256, 3))
    down_stack = [
        downsample(64, (4, 4), apply_batchnorm=False),  # [batch, 128, 128, 64]
        downsample(128, (4, 4)),  # [batch, 64, 64, 128]
        downsample(256, (4, 4)),  # [batch, 32, 32, 256]
        downsample(512, (4, 4)),  # [batch, 16, 16, 512]
        downsample(512, (4, 4)),  # [batch, 8, 8, 512]
        downsample(512, (4, 4)),  # [batch, 4, 4, 512]
        downsample(512, (4, 4)),  # [batch, 2, 2, 512]
        downsample(512, (4, 4)),  # [batch, 1, 1, 512]
    ]
    up_stack = [
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 2, 2, 512]
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 4, 4, 512]
        upsample(512, (4, 4), apply_dropout=True),  # [batch, 8, 8, 512]
        upsample(512, (4, 4)),  # [batch, 16, 16, 512]
        upsample(256, (4, 4)),  # [batch, 32, 32, 256]
        upsample(128, (4, 4)),  # [batch, 64, 64, 128]
        upsample(64, (4, 4)),   # [batch, 128, 128, 64]
    ]
    last_layer = tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=(4, 4), strides=2, padding='same',
                                                 kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                                 activation='tanh')  # [batch, 256, 256, 3]
    x = inputs  # tensor of shape [batch, 256, 256, 3]
    down_outputs = []
    for down_layer in down_stack:
        x = down_layer(x)  # invokes the layer (its call method)
        down_outputs.append(x)
    # Skip connections: reverse all but the innermost output
    # (2-->4-->8-->16-->32-->64-->128, 7 levels in total).
    down_outputs = reversed(down_outputs[:-1])
    for up_layer, down_output in zip(up_stack, down_outputs):
        x = up_layer(x)
        x = tf.concat([x, down_output], axis=3)
    x = last_layer(x)  # [batch, 256, 256, 3]
    return tf.keras.Model(inputs=inputs, outputs=x)
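A quick shape check (a sanity-check sketch of my own, not from the original post):

generator = Generator()
print(generator(tf.random.normal([1, 256, 256, 3])).shape)  # (1, 256, 256, 3)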
3. Defining the Discriminator

The discriminator also has the same structure as in pix2pix (PatchGAN), with one difference: its input is a single image, not an image plus a conditioning image as before. The code:

def Discriminator():
    input_image = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
    x = input_image  # [batch, 256, 256, 3]

    down_sample1 = downsample(filters=64, kernel_size=(4, 4), apply_batchnorm=False)(x)  # [batch, 128, 128, 64]
    down_sample2 = downsample(128, (4, 4))(down_sample1)  # [batch, 64, 64, 128]
    down_sample3 = downsample(256, (4, 4))(down_sample2)  # [batch, 32, 32, 256]

    zero_pad1 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(down_sample3)  # [batch, 34, 34, 256]
    conv = tf.keras.layers.Conv2D(filters=512, kernel_size=(4, 4), strides=1, padding='valid',
                                  kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                  use_bias=False)(zero_pad1)  # [batch, 31, 31, 512]
    batchnorm = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm)

    zero_pad2 = tf.keras.layers.ZeroPadding2D(padding=(1, 1))(leaky_relu)  # [batch, 33, 33, 512]
    output = tf.keras.layers.Conv2D(filters=1, kernel_size=(4, 4), strides=1,
                                    kernel_initializer=tf.random_normal_initializer(0., 0.02))(zero_pad2)
    # [batch, 30, 30, 1] -- one logit per image patch
    return tf.keras.Model(inputs=input_image, outputs=output)
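The 30*30*1 patch output can be verified directly (again a sanity-check sketch, not from the original post):

discriminator = Discriminator()
print(discriminator(tf.random.normal([1, 256, 256, 3])).shape)  # (1, 30, 30, 1)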
4. Defining the Loss Functions

The generators must satisfy the following requirements (the combined objective is written out right after this list):

  • The generated image should not be recognized as fake by the discriminator, i.e., the 30*30*1 matrix the discriminator outputs for a generated image should be close to an all-ones matrix.

  • For generator G, a horse input should produce a zebra, but a zebra input should still produce the same zebra; the deviation from this identity mapping is same_loss (the identity loss). The same holds for F.

  • Cycle consistency: for x --> G(x) --> F(G(x)), the difference between F(G(x)) and the original x should be as small as possible; this is cycle_loss. Likewise for y --> F(y) --> G(F(y)) and the original y.
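Putting the pieces together, the full objective from the paper (with the weights used in the code below: $\lambda = 10$ for the cycle term and $0.5\lambda = 5$ for the identity term) is:

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda \, \mathcal{L}_{cyc}(G, F) + 0.5\lambda \, \mathcal{L}_{identity}(G, F)$$

where

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{x}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y}\big[\lVert G(F(y)) - y \rVert_1\big]$$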

The code:

def generator_loss(disc_generated_output):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # push the discriminator's output on generated images towards 1
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    return gan_loss

def identity_loss(real_image, same_image):
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return loss

def cycle_loss(real_image, cycle_image):
    loss = tf.reduce_mean(tf.abs(real_image - cycle_image))
    return loss

For the discriminator, it is enough to tell real images from fake ones: its output on real images is compared against all ones, and its output on generated images against all zeros. The code:

def discriminator_loss(disc_real_output, disc_generated_output):
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)  # real images should score close to 1
    generate_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)  # generated images close to 0
    return (real_loss + generate_loss) * 0.5
5. Optimizers and the Training Step

Each training step follows the sequence: generate images --> compute losses --> compute gradients --> update parameters.

Gx = Generator()
Fy = Generator()
Dx = Discriminator()
Dy = Discriminator()
# Optimizers. Gx maps x-->y, Fy maps y-->x;
# Dx judges F(y) against real x, Dy judges G(x) against real y.
Gx_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
Fy_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
Dx_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
Dy_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

@tf.function
def train_step(real_x, real_y):
    # persistent=True allows tape.gradient to be called multiple times
    with tf.GradientTape(persistent=True) as tape:
        fake_y = Gx(real_x, training=True)   # x--G(x)-->y
        cycle_x = Fy(fake_y, training=True)  # x--F(G(x))-->x
        fake_x = Fy(real_y, training=True)   # y--F(y)-->x
        cycle_y = Gx(fake_x, training=True)  # y--G(F(y))-->y
        same_x = Fy(real_x, training=True)   # x--F(x)-->x
        same_y = Gx(real_y, training=True)   # y--G(y)-->y

        disc_real_x_output = Dx(real_x, training=True)
        disc_real_y_output = Dy(real_y, training=True)
        disc_fake_x_output = Dx(fake_x, training=True)
        disc_fake_y_output = Dy(fake_y, training=True)

        # compute losses
        total_cycle_loss = cycle_loss(real_x, cycle_x) + cycle_loss(real_y, cycle_y)
        Gx_loss = generator_loss(disc_fake_y_output) + identity_loss(real_y, same_y) * 5 + total_cycle_loss * 10
        Fy_loss = generator_loss(disc_fake_x_output) + identity_loss(real_x, same_x) * 5 + total_cycle_loss * 10
        Dx_loss = discriminator_loss(disc_real_x_output, disc_fake_x_output)
        Dy_loss = discriminator_loss(disc_real_y_output, disc_fake_y_output)
    # compute gradients
    Gx_gradients = tape.gradient(Gx_loss, Gx.trainable_variables)
    Fy_gradients = tape.gradient(Fy_loss, Fy.trainable_variables)
    Dx_gradients = tape.gradient(Dx_loss, Dx.trainable_variables)
    Dy_gradients = tape.gradient(Dy_loss, Dy.trainable_variables)
    # apply the updates with the optimizers
    Gx_optimizer.apply_gradients(zip(Gx_gradients, Gx.trainable_variables))
    Fy_optimizer.apply_gradients(zip(Fy_gradients, Fy.trainable_variables))
    Dx_optimizer.apply_gradients(zip(Dx_gradients, Dx.trainable_variables))
    Dy_optimizer.apply_gradients(zip(Dy_gradients, Dy.trainable_variables))
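One piece the original post never shows: train() in the next section calls checkpoint.save, but the checkpoint object is not defined anywhere. A minimal sketch, assuming we want to capture all four models and their optimizers:

# Assumed definition (not shown in the original post): bundle the models
# and optimizers so that checkpoint.save() below has state to serialize.
checkpoint = tf.train.Checkpoint(Gx=Gx, Fy=Fy, Dx=Dx, Dy=Dy,
                                 Gx_optimizer=Gx_optimizer, Fy_optimizer=Fy_optimizer,
                                 Dx_optimizer=Dx_optimizer, Dy_optimizer=Dy_optimizer)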
6. Training

Training requires no paired data: at each step we randomly pick one image from trainA and one from trainB and run train_step.

def train():
    train_X_dir = "dataset/horse2zebra/trainA"
    train_Y_dir = "dataset/horse2zebra/trainB"
    train_x_images_path = os.listdir(train_X_dir)
    train_y_images_path = os.listdir(train_Y_dir)
    start = time.time()
    for step in range(70000):
        horse = preprocess_train_image(os.path.join(train_X_dir, random.choice(train_x_images_path)))
        zebra = preprocess_train_image(os.path.join(train_Y_dir, random.choice(train_y_images_path)))
        train_step(horse[tf.newaxis, ...], zebra[tf.newaxis, ...])  # add the batch dimension
        if (step + 1) % 10000 == 0:
            checkpoint.save("training_checkpoint/cycleGAN")
        if (step + 1) % 700 == 0:
            print("step:", step, ", time:", time.time() - start)
            show_image(step)
            start = time.time()
7. Results

During training we saved 7 checkpoints. Using the test set, we can load each one in turn and inspect the model at the different stages (a checkpoint-restore sketch follows the code below).

def show_image(step):
    horse_files = os.listdir("dataset/horse2zebra/testA")
    zebra_files = os.listdir("dataset/horse2zebra/testB")
    real_horse_image = preprocess_test_image(os.path.join("dataset/horse2zebra/testA", random.choice(horse_files)))
    real_zebra_image = preprocess_test_image(os.path.join("dataset/horse2zebra/testB", random.choice(zebra_files)))
    fake_zebra_image = Gx(real_horse_image[tf.newaxis, ...], training=False)[0]
    fake_horse_image = Fy(real_zebra_image[tf.newaxis, ...], training=False)[0]
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.title("real_horse")
    plt.imshow(real_horse_image * 0.5 + 0.5)  # map [-1, 1] back to [0, 1] for display
    plt.subplot(2, 2, 2)
    plt.title("fake_zebra")
    plt.imshow(fake_zebra_image * 0.5 + 0.5)

    plt.subplot(2, 2, 3)
    plt.title("real_zebra")
    plt.imshow(real_zebra_image * 0.5 + 0.5)
    plt.subplot(2, 2, 4)
    plt.title("fake_horse")
    plt.imshow(fake_horse_image * 0.5 + 0.5)
    plt.savefig("cycleGAN_image_save/" + str(step // 700) + ".png")
    plt.show()
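To inspect one of the seven saved stages, restore its checkpoint before calling show_image. A minimal sketch (tf.train.Checkpoint.save numbers the files cycleGAN-1 through cycleGAN-7):

# Restore the weights of, e.g., the last saved stage and render one sample.
checkpoint.restore("training_checkpoint/cycleGAN-7")
show_image(0)  # the step argument here only determines the output filename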