Reproducing the CatDCGAN Project: A First Look at Adversarial Networks

Author: CarpVexing | Date: 100521 | Do not repost without permission



Introduction

GANs (generative adversarial networks) are, by any measure, not a new technology. But what I have learned has long lagged behind the times, and only in the stretch after securing my postgraduate recommendation have I been able to study ML properly. Today I am writing down the steps I took to reproduce the CatDCGAN project, together with the reference links I used.

About the CatDCGAN Project

This is a project that uses an adversarial network to generate cat-head avatars. It is non-commercial, essentially a toy project for practicing GAN applications. Its author is simoninithomas; here is the article he published about it on freeCodeCamp in 2018:

(https://www.freecodecamp.org/news/how-ai-can-learn-to-generate-pictures-of-cats-ba692cb6eae4/)

The basic idea is not hard to grasp: two neural networks, a Generator and a Discriminator, compete against each other in a game. The details of the loss functions, the optimizers, and the training procedure can be found in the original article and code, so I won't repeat them here. The end result is a model trained on roughly 3 GB of cat pictures that can generate random cat-head images. __Note that "random cat heads" only means the image is judged to contain a cat head; the code gives us no control whatsoever over the cat's features.__ In other words, it implements only the basic functionality of a GAN, with no further innovation.
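To make the adversarial game concrete, here is a minimal sketch of the two loss terms, written in the same tf.compat.v1 style as the project code quoted later in this post. The arguments d_logits_real and d_logits_fake stand for the discriminator's raw outputs on real and generated images; the formulation mirrors the model_loss function in the full code below.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def gan_losses(d_logits_real, d_logits_fake):
    # Discriminator: push real images toward label 1 and fakes toward label 0
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
    # Generator: try to make the discriminator label the fakes as real (1)
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))
    return d_loss_real + d_loss_fake, g_loss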

Preparing to Reproduce the Project

First, download the project source code, the dataset, and the pretrained model. The source code is in an .ipynb notebook.

(https://github.com/simoninithomas/CatDCGAN)

(https://www.kaggle.com/crawford/cat-dataset)

(https://drive.google.com/drive/folders/1zdZZ91fjOUiOsIdAQKZkUTATXzqy7hiz?usp=sharing)

In my experience, it is difficult to get jupyter notebook to run inside a designated conda virtual Python environment. There are all sorts of fixes floating around online, but none of them worked for me. For convenience, I simply copied the whole notebook into a .py file and ran and tested it with the chosen interpreter.

The project was written for TensorFlow 1, so you need to set up a TensorFlow environment accordingly; either version 1 or 2 works (2 requires the compatibility shim described below). Pay close attention to the version correspondence among TensorFlow_gpu, python, CUDA, and cudnn. For installation details, see:

(https://blog.csdn.net/weixin_43877139/article/details/100544065)

Version correspondence table (for Windows):

(https://tensorflow.google.cn/install/source_windows)

I used the combination tensorflow_gpu-2.6.0 / python3.8 / cudnn8.1 / cuda11.2. Once TensorFlow reliably detects the GPU, you can move on to modifying the dataset and the source code.
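Before going further, a quick sanity check is worthwhile. This is a minimal sketch using the same tf.test call that the project code itself relies on:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

print("TensorFlow version:", tf.__version__)
gpu = tf.test.gpu_device_name()
# An empty string means TensorFlow cannot see the GPU; recheck the CUDA/cudnn versions
print("GPU device:", gpu if gpu else "none found")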

Dataset Preprocessing

According to the author, you first need to discard images where the cat face is off-center, corrupted, badly cropped, upside down, not recognizable as a cat, or occluded. He provides a shell script for this, but it has a few small problems. First, extract the downloaded dataset into ./ and name the folder "cat_dataset", then move the ./cat_dataset/cats folder out into ./. Next, download "preprocess_cat_dataset.py" into ./; it comes from the following project:

(https://github.com/AlexiaJM/relativistic-f-divergences)

Finally, save the following script in ./ and run it from Git Bash.

mv cat_dataset/CAT_00/* cat_dataset
rmdir cat_dataset/CAT_00
mv cat_dataset/CAT_01/* cat_dataset
rmdir cat_dataset/CAT_01
mv cat_dataset/CAT_02/* cat_dataset
rmdir cat_dataset/CAT_02
mv cat_dataset/CAT_03/* cat_dataset
rmdir cat_dataset/CAT_03
mv cat_dataset/CAT_04/* cat_dataset
rmdir cat_dataset/CAT_04
mv cat_dataset/CAT_05/* cat_dataset
rmdir cat_dataset/CAT_05
mv cat_dataset/CAT_06/* cat_dataset
rmdir cat_dataset/CAT_06
## Error correction
rm cat_dataset/00000003_019.jpg.cat
mv 00000003_015.jpg.cat cat_dataset/00000003_015.jpg.cat
## Removing outliers
# Corrupted, drawings, badly cropped, inverted, impossible to tell it's a cat, blocked face
cd cat_dataset
rm 00000004_007.jpg 00000007_002.jpg 00000045_028.jpg 00000050_014.jpg 00000056_013.jpg 00000059_002.jpg 00000108_005.jpg 00000122_023.jpg 00000126_005.jpg 00000132_018.jpg 00000142_024.jpg 00000142_029.jpg 00000143_003.jpg 00000145_021.jpg 00000166_021.jpg 00000169_021.jpg 00000186_002.jpg 00000202_022.jpg 00000208_023.jpg 00000210_003.jpg 00000229_005.jpg 00000236_025.jpg 00000249_016.jpg 00000254_013.jpg 00000260_019.jpg 00000261_029.jpg 00000265_029.jpg 00000271_020.jpg 00000282_026.jpg 00000316_004.jpg 00000352_014.jpg 00000400_026.jpg 00000406_006.jpg 00000431_024.jpg 00000443_027.jpg 00000502_015.jpg 00000504_012.jpg 00000510_019.jpg 00000514_016.jpg 00000514_008.jpg 00000515_021.jpg 00000519_015.jpg 00000522_016.jpg 00000523_021.jpg 00000529_005.jpg 00000556_022.jpg 00000574_011.jpg 00000581_018.jpg 00000582_011.jpg 00000588_016.jpg 00000588_019.jpg 00000590_006.jpg 00000592_018.jpg 00000593_027.jpg 00000617_013.jpg 00000618_016.jpg 00000619_025.jpg 00000622_019.jpg 00000622_021.jpg 00000630_007.jpg 00000645_016.jpg 00000656_017.jpg 00000659_000.jpg 00000660_022.jpg 00000660_029.jpg 00000661_016.jpg 00000663_005.jpg 00000672_027.jpg 00000673_027.jpg 00000675_023.jpg 00000692_006.jpg 00000800_017.jpg 00000805_004.jpg 00000807_020.jpg 00000823_010.jpg 00000824_010.jpg 00000836_008.jpg 00000843_021.jpg 00000850_025.jpg 00000862_017.jpg 00000864_007.jpg 00000865_015.jpg 00000870_007.jpg 00000877_014.jpg 00000882_013.jpg 00000887_028.jpg 00000893_022.jpg 00000907_013.jpg 00000921_029.jpg 00000929_022.jpg 00000934_006.jpg 00000960_021.jpg 00000976_004.jpg 00000987_000.jpg 00000993_009.jpg 00001006_014.jpg 00001008_013.jpg 00001012_019.jpg 00001014_005.jpg 00001020_017.jpg 00001039_008.jpg 00001039_023.jpg 00001048_029.jpg 00001057_003.jpg 00001068_005.jpg 00001113_015.jpg 00001140_007.jpg 00001157_029.jpg 00001158_000.jpg 00001167_007.jpg 00001184_007.jpg 00001188_019.jpg 00001204_027.jpg 00001205_022.jpg 00001219_005.jpg 00001243_010.jpg 00001261_005.jpg 00001270_028.jpg 00001274_006.jpg 00001293_015.jpg 00001312_021.jpg 00001365_026.jpg 00001372_006.jpg 00001379_018.jpg 00001388_024.jpg 00001389_026.jpg 00001418_028.jpg 00001425_012.jpg 00001431_001.jpg 00001456_018.jpg 00001458_003.jpg 00001468_019.jpg 00001475_009.jpg 00001487_020.jpg
rm 00000004_007.jpg.cat 00000007_002.jpg.cat 00000045_028.jpg.cat 00000050_014.jpg.cat 00000056_013.jpg.cat 00000059_002.jpg.cat 00000108_005.jpg.cat 00000122_023.jpg.cat 00000126_005.jpg.cat 00000132_018.jpg.cat 00000142_024.jpg.cat 00000142_029.jpg.cat 00000143_003.jpg.cat 00000145_021.jpg.cat 00000166_021.jpg.cat 00000169_021.jpg.cat 00000186_002.jpg.cat 00000202_022.jpg.cat 00000208_023.jpg.cat 00000210_003.jpg.cat 00000229_005.jpg.cat 00000236_025.jpg.cat 00000249_016.jpg.cat 00000254_013.jpg.cat 00000260_019.jpg.cat 00000261_029.jpg.cat 00000265_029.jpg.cat 00000271_020.jpg.cat 00000282_026.jpg.cat 00000316_004.jpg.cat 00000352_014.jpg.cat 00000400_026.jpg.cat 00000406_006.jpg.cat 00000431_024.jpg.cat 00000443_027.jpg.cat 00000502_015.jpg.cat 00000504_012.jpg.cat 00000510_019.jpg.cat 00000514_016.jpg.cat 00000514_008.jpg.cat 00000515_021.jpg.cat 00000519_015.jpg.cat 00000522_016.jpg.cat 00000523_021.jpg.cat 00000529_005.jpg.cat 00000556_022.jpg.cat 00000574_011.jpg.cat 00000581_018.jpg.cat 00000582_011.jpg.cat 00000588_016.jpg.cat 00000588_019.jpg.cat 00000590_006.jpg.cat 00000592_018.jpg.cat 00000593_027.jpg.cat 00000617_013.jpg.cat 00000618_016.jpg.cat 00000619_025.jpg.cat 00000622_019.jpg.cat 00000622_021.jpg.cat 00000630_007.jpg.cat 00000645_016.jpg.cat 00000656_017.jpg.cat 00000659_000.jpg.cat 00000660_022.jpg.cat 00000660_029.jpg.cat 00000661_016.jpg.cat 00000663_005.jpg.cat 00000672_027.jpg.cat 00000673_027.jpg.cat 00000675_023.jpg.cat 00000692_006.jpg.cat 00000800_017.jpg.cat 00000805_004.jpg.cat 00000807_020.jpg.cat 00000823_010.jpg.cat 00000824_010.jpg.cat 00000836_008.jpg.cat 00000843_021.jpg.cat 00000850_025.jpg.cat 00000862_017.jpg.cat 00000864_007.jpg.cat 00000865_015.jpg.cat 00000870_007.jpg.cat 00000877_014.jpg.cat 00000882_013.jpg.cat 00000887_028.jpg.cat 00000893_022.jpg.cat 00000907_013.jpg.cat 00000921_029.jpg.cat 00000929_022.jpg.cat 00000934_006.jpg.cat 00000960_021.jpg.cat 00000976_004.jpg.cat 00000987_000.jpg.cat 00000993_009.jpg.cat 00001006_014.jpg.cat 00001008_013.jpg.cat 00001012_019.jpg.cat 00001014_005.jpg.cat 00001020_017.jpg.cat 00001039_008.jpg.cat 00001039_023.jpg.cat 00001048_029.jpg.cat 00001057_003.jpg.cat 00001068_005.jpg.cat 00001113_015.jpg.cat 00001140_007.jpg.cat 00001157_029.jpg.cat 00001158_000.jpg.cat 00001167_007.jpg.cat 00001184_007.jpg.cat 00001188_019.jpg.cat 00001204_027.jpg.cat 00001205_022.jpg.cat 00001219_005.jpg.cat 00001243_010.jpg.cat 00001261_005.jpg.cat 00001270_028.jpg.cat 00001274_006.jpg.cat 00001293_015.jpg.cat 00001312_021.jpg.cat 00001365_026.jpg.cat 00001372_006.jpg.cat 00001379_018.jpg.cat 00001388_024.jpg.cat 00001389_026.jpg.cat 00001418_028.jpg.cat 00001425_012.jpg.cat 00001431_001.jpg.cat 00001456_018.jpg.cat 00001458_003.jpg.cat 00001468_019.jpg.cat 00001475_009.jpg.cat 00001487_020.jpg.cat
cd ..
## Preprocessing and putting in folders for different image sizes
mkdir cats_bigger_than_64x64
mkdir cats_bigger_than_128x128
python preprocess_cat_dataset.py
## Removing cat_dataset
rm -r cat_dataset

For the code to run, you also need to create two empty folders in ./; I named them from_checkpoint_IMG and images. Such naming is admittedly far from standard, but I kept it to match what the code expresses. A minimal way to create them is sketched below.
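Here is that sketch (the folder names are simply my choices from above):

import os

# Create the two output folders if they do not already exist
for folder in ("./from_checkpoint_IMG", "./images"):
    os.makedirs(folder, exist_ok=True)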

Code Modifications

1. First, this project uses tf1, while I installed tf2. The simple way to deal with this is to change

import tensorflow as tf

to

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

As an aside, how do you tell whether code targets tf1 or tf2? If statements like X = tf.placeholder("float") appear, it is tf1 code; under tf2 it fails with the error "module 'tensorflow' has no attribute 'placeholder'".
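A minimal demonstration of the shim (the placeholder itself is purely illustrative):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Works under the compat shim; plain tf2 has no tf.placeholder
X = tf.placeholder("float")
print(X)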

2. Of the two variables below, the first says whether the images need to be resized to a uniform size. That step only has to run once: set it to True on the first run and to False from then on. The second says whether to use the already-trained model; retraining from scratch takes at least 20 hours, so if you don't need to retrain, set it to True.

do_preprocess = False
from_checkpoint = False

Models are saved under ./models/. Extract the model archive downloaded during the preparation step into that path.

3. Fix the errors in the code.

if from_checkpoint == True:
    saver.restore(sess, "./models/model.ckpt")
    show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)

image_path is undefined here. Add a line image_path = "./from_checkpoint_IMG/FCI.jpg"; this path is the empty folder from the previous step. The patched branch is shown below.
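Here is how the patched branch reads (FCI.jpg is just an arbitrary file name I picked):

if from_checkpoint == True:
    saver.restore(sess, "./models/model.ckpt")
    image_path = "./from_checkpoint_IMG/FCI.jpg"
    show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)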

return losses, samples

Both names here are undefined. samples is never used at all, so I simply deleted it. losses was evidently meant to be a list that records how the loss values change over time so they can be plotted at the end, but that bookkeeping is missing from the code. So, inside the train function, define:

losses = []
...
if i % 10 == 0:
    train_loss_d = d_loss.eval({input_z: batch_z, input_images: batch_images})
    train_loss_g = g_loss.eval({input_z: batch_z})
    losses.append((train_loss_d, train_loss_g))
...

There is still one problem left: if you generate directly from the existing model, this list stays empty and the script throws an error at the end. But when you use the model directly you don't need to analyze the loss curves anyway, so whether it errors out at the very end doesn't really matter. A possible guard is sketched below.
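If the error bothers you anyway, the following guard (my own tweak, not part of the original project) can replace the plotting block at the very end of the script:

losses = np.array(losses)
if losses.size > 0:
    plt.plot(losses.T[0], label='Discriminator', alpha=0.5)
    plt.plot(losses.T[1], label='Generator', alpha=0.5)
    plt.title("Training Losses")
    plt.legend()
    plt.show()
else:
    # Nothing was recorded when generating from a checkpoint only
    print("No losses recorded; skipping the loss plot.")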

The generated images end up in ./from_checkpoint_IMG/.

Full Code

import os
# import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import helper
from glob import glob
import pickle as pkl
import scipy.misc
import time
import cv2
import matplotlib.pyplot as plt
from PIL import Image  # used by get_image below
#%matplotlib inline

#do_preprocess = False
#from_checkpoint = False
do_preprocess = False
from_checkpoint = True

# Data
data_dir = './cats_bigger_than_128x128'
# Resized data
data_resized_dir = "./resized_data"

if do_preprocess == True:
    os.mkdir(data_resized_dir)
    for each in os.listdir(data_dir):
        image = cv2.imread(os.path.join(data_dir, each))
        image = cv2.resize(image, (128, 128))
        cv2.imwrite(os.path.join(data_resized_dir, each), image)

# This part was taken from Udacity Face generator project
def get_image(image_path, width, height, mode):
    """
    Read image from image_path
    :param image_path: Path of image
    :param width: Width of image
    :param height: Height of image
    :param mode: Mode of image
    :return: Image data
    """
    image = Image.open(image_path)
    return np.array(image.convert(mode))

def get_batch(image_files, width, height, mode):
    data_batch = np.array(
        [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32)
    # Make sure the images are in 4 dimensions
    if len(data_batch.shape) < 4:
        data_batch = data_batch.reshape(data_batch.shape + (1,))
    return data_batch

show_n_images = 25
mnist_images = helper.get_batch(glob(os.path.join(data_resized_dir, '*.jpg'))[:show_n_images], 64, 64, 'RGB')
plt.imshow(helper.images_square_grid(mnist_images, 'RGB'))

# Taken from Udacity face generator project
from distutils.version import LooseVersion
import warnings
# import tensorflow as tf

# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer. You are using {}'.format(tf.__version__)
print('TensorFlow Version: {}'.format(tf.__version__))

# Check for a GPU
if not tf.test.gpu_device_name():
    warnings.warn('No GPU found. Please use a GPU to train your neural network.')
else:
    print('Default GPU Device: {}'.format(tf.test.gpu_device_name()))

def model_inputs(real_dim, z_dim):
    """
    Create the model inputs
    :param real_dim: tuple containing width, height and channels
    :param z_dim: The dimension of Z
    :return: Tuple of (tensor of real input images, tensor of z data, learning rate G, learning rate D)
    """
    inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='inputs_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name="input_z")
    learning_rate_G = tf.placeholder(tf.float32, name="learning_rate_G")
    learning_rate_D = tf.placeholder(tf.float32, name="learning_rate_D")

    # inputs_real = tf.compat.v1.placeholder(tf.float32, (None, *real_dim), name='inputs_real')
    # inputs_z = tf.compat.v1.placeholder(tf.float32, (None, z_dim), name="input_z")
    # learning_rate_G = tf.compat.v1.placeholder(tf.float32, name="learning_rate_G")
    # learning_rate_D = tf.compat.v1.placeholder(tf.float32, name="learning_rate_D")

    return inputs_real, inputs_z, learning_rate_G, learning_rate_D

def generator(z, output_channel_dim, is_train=True):
    ''' Build the generator network.

        Arguments
        ---------
        z : Input tensor for the generator
        output_channel_dim : Shape of the generator output
        n_units : Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU

        Returns
        -------
        out:
    '''
    with tf.variable_scope("generator", reuse=not is_train):
        # First FC layer --> 8x8x1024
        fc1 = tf.layers.dense(z, 8*8*1024)

        # Reshape it
        fc1 = tf.reshape(fc1, (-1, 8, 8, 1024))

        # Leaky ReLU
        fc1 = tf.nn.leaky_relu(fc1, alpha=alpha)

        # Transposed conv 1 --> BatchNorm --> LeakyReLU
        # 8x8x1024 --> 16x16x512
        trans_conv1 = tf.layers.conv2d_transpose(inputs=fc1,
                                                 filters=512,
                                                 kernel_size=[5,5],
                                                 strides=[2,2],
                                                 padding="SAME",
                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                                 name="trans_conv1")
        batch_trans_conv1 = tf.layers.batch_normalization(inputs=trans_conv1, training=is_train, epsilon=1e-5, name="batch_trans_conv1")
        trans_conv1_out = tf.nn.leaky_relu(batch_trans_conv1, alpha=alpha, name="trans_conv1_out")

        # Transposed conv 2 --> BatchNorm --> LeakyReLU
        # 16x16x512 --> 32x32x256
        trans_conv2 = tf.layers.conv2d_transpose(inputs=trans_conv1_out,
                                                 filters=256,
                                                 kernel_size=[5,5],
                                                 strides=[2,2],
                                                 padding="SAME",
                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                                 name="trans_conv2")
        batch_trans_conv2 = tf.layers.batch_normalization(inputs=trans_conv2, training=is_train, epsilon=1e-5, name="batch_trans_conv2")
        trans_conv2_out = tf.nn.leaky_relu(batch_trans_conv2, alpha=alpha, name="trans_conv2_out")

        # Transposed conv 3 --> BatchNorm --> LeakyReLU
        # 32x32x256 --> 64x64x128
        trans_conv3 = tf.layers.conv2d_transpose(inputs=trans_conv2_out,
                                                 filters=128,
                                                 kernel_size=[5,5],
                                                 strides=[2,2],
                                                 padding="SAME",
                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                                 name="trans_conv3")
        batch_trans_conv3 = tf.layers.batch_normalization(inputs=trans_conv3, training=is_train, epsilon=1e-5, name="batch_trans_conv3")
        trans_conv3_out = tf.nn.leaky_relu(batch_trans_conv3, alpha=alpha, name="trans_conv3_out")

        # Transposed conv 4 --> BatchNorm --> LeakyReLU
        # 64x64x128 --> 128x128x64
        trans_conv4 = tf.layers.conv2d_transpose(inputs=trans_conv3_out,
                                                 filters=64,
                                                 kernel_size=[5,5],
                                                 strides=[2,2],
                                                 padding="SAME",
                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                                 name="trans_conv4")
        batch_trans_conv4 = tf.layers.batch_normalization(inputs=trans_conv4, training=is_train, epsilon=1e-5, name="batch_trans_conv4")
        trans_conv4_out = tf.nn.leaky_relu(batch_trans_conv4, alpha=alpha, name="trans_conv4_out")

        # Transposed conv 5 --> tanh
        # 128x128x64 --> 128x128x3
        logits = tf.layers.conv2d_transpose(inputs=trans_conv4_out,
                                            filters=3,
                                            kernel_size=[5,5],
                                            strides=[1,1],
                                            padding="SAME",
                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                            name="logits")
        out = tf.tanh(logits, name="out")

        return out

def discriminator(x, is_reuse=False, alpha=0.2):
    ''' Build the discriminator network.

        Arguments
        ---------
        x : Input tensor for the discriminator
        n_units: Number of units in hidden layer
        reuse : Reuse the variables with tf.variable_scope
        alpha : leak parameter for leaky ReLU

        Returns
        -------
        out, logits:
    '''
    with tf.variable_scope("discriminator", reuse=is_reuse):
        # Input layer 128*128*3 --> 64x64x64
        # Conv --> BatchNorm --> LeakyReLU
        conv1 = tf.layers.conv2d(inputs=x,
                                 filters=64,
                                 kernel_size=[5,5],
                                 strides=[2,2],
                                 padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv1')
        batch_norm1 = tf.layers.batch_normalization(conv1,
                                                    training=True,
                                                    epsilon=1e-5,
                                                    name='batch_norm1')
        conv1_out = tf.nn.leaky_relu(batch_norm1, alpha=alpha, name="conv1_out")

        # 64x64x64 --> 32x32x128
        # Conv --> BatchNorm --> LeakyReLU
        conv2 = tf.layers.conv2d(inputs=conv1_out,
                                 filters=128,
                                 kernel_size=[5, 5],
                                 strides=[2, 2],
                                 padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv2')
        batch_norm2 = tf.layers.batch_normalization(conv2,
                                                    training=True,
                                                    epsilon=1e-5,
                                                    name='batch_norm2')
        conv2_out = tf.nn.leaky_relu(batch_norm2, alpha=alpha, name="conv2_out")

        # 32x32x128 --> 16x16x256
        # Conv --> BatchNorm --> LeakyReLU
        conv3 = tf.layers.conv2d(inputs=conv2_out,
                                 filters=256,
                                 kernel_size=[5, 5],
                                 strides=[2, 2],
                                 padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv3')
        batch_norm3 = tf.layers.batch_normalization(conv3,
                                                    training=True,
                                                    epsilon=1e-5,
                                                    name='batch_norm3')
        conv3_out = tf.nn.leaky_relu(batch_norm3, alpha=alpha, name="conv3_out")

        # 16x16x256 --> 16x16x512
        # Conv --> BatchNorm --> LeakyReLU
        conv4 = tf.layers.conv2d(inputs=conv3_out,
                                 filters=512,
                                 kernel_size=[5, 5],
                                 strides=[1, 1],
                                 padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv4')
        batch_norm4 = tf.layers.batch_normalization(conv4,
                                                    training=True,
                                                    epsilon=1e-5,
                                                    name='batch_norm4')
        conv4_out = tf.nn.leaky_relu(batch_norm4, alpha=alpha, name="conv4_out")

        # 16x16x512 --> 8x8x1024
        # Conv --> BatchNorm --> LeakyReLU
        conv5 = tf.layers.conv2d(inputs=conv4_out,
                                 filters=1024,
                                 kernel_size=[5, 5],
                                 strides=[2, 2],
                                 padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv5')
        batch_norm5 = tf.layers.batch_normalization(conv5,
                                                    training=True,
                                                    epsilon=1e-5,
                                                    name='batch_norm5')
        conv5_out = tf.nn.leaky_relu(batch_norm5, alpha=alpha, name="conv5_out")

        # Flatten it
        flatten = tf.reshape(conv5_out, (-1, 8*8*1024))

        # Logits
        logits = tf.layers.dense(inputs=flatten,
                                 units=1,
                                 activation=None)
        out = tf.sigmoid(logits)

        return out, logits

def model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param out_channel_dim: The number of channels in the output image
    :return: A tuple of (discriminator loss, generator loss)
    """
    # Generator network here
    g_model = generator(input_z, output_channel_dim)
    # g_model is the generator output

    # Discriminator network here
    d_model_real, d_logits_real = discriminator(input_real, alpha=alpha)
    d_model_fake, d_logits_fake = discriminator(g_model, is_reuse=True, alpha=alpha)

    # Calculate losses
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                labels=tf.ones_like(d_model_real)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.zeros_like(d_model_fake)))
    d_loss = d_loss_real + d_loss_fake

    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.ones_like(d_model_fake)))

    return d_loss, g_loss

def model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1):
    """
    Get optimization operations
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param learning_rate: Learning Rate Placeholder
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable_variables, split into G and D parts
    t_vars = tf.trainable_variables()
    g_vars = [var for var in t_vars if var.name.startswith("generator")]
    d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Generator update
    gen_updates = [op for op in update_ops if op.name.startswith('generator')]

    # Optimizers
    with tf.control_dependencies(gen_updates):
        d_train_opt = tf.train.AdamOptimizer(learning_rate=lr_D, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate=lr_G, beta1=beta1).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt

def show_generator_output(sess, n_images, input_z, out_channel_dim, image_mode, image_path, save, show):
    """
    Show example output for the generator
    :param sess: TensorFlow session
    :param n_images: Number of Images to display
    :param input_z: Input Z Tensor
    :param out_channel_dim: The number of channels in the output image
    :param image_mode: The mode to use for images ("RGB" or "L")
    :param image_path: Path to save the image
    """
    cmap = None if image_mode == 'RGB' else 'gray'
    z_dim = input_z.get_shape().as_list()[-1]
    example_z = np.random.uniform(-1, 1, size=[n_images, z_dim])

    samples = sess.run(
        generator(input_z, out_channel_dim, False),
        feed_dict={input_z: example_z})

    images_grid = helper.images_square_grid(samples, image_mode)

    if save == True:
        # Save image
        images_grid.save(image_path, 'JPEG')

    if show == True:
        plt.imshow(images_grid, cmap=cmap)
        plt.show()

def train(epoch_count, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, get_batches, data_shape, data_image_mode, alpha):
    """
    Train the GAN
    :param epoch_count: Number of epochs
    :param batch_size: Batch Size
    :param z_dim: Z dimension
    :param learning_rate: Learning Rate
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches
    :param data_shape: Shape of the data
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    """
    samples, losses = [], []

    # Create our input placeholders
    input_images, input_z, lr_G, lr_D = model_inputs(data_shape[1:], z_dim)

    # Losses
    d_loss, g_loss = model_loss(input_images, input_z, data_shape[3], alpha)

    # Optimizers
    d_opt, g_opt = model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1)

    i = 0
    version = "firstTrain"

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Saver
        saver = tf.train.Saver()

        num_epoch = 0

        if from_checkpoint == True:
            saver.restore(sess, "./models/model.ckpt")
            image_path = "./from_checkpoint_IMG/FCI.jpg"
            show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)
        else:
            for epoch_i in range(epoch_count):
                num_epoch += 1

                if num_epoch % 5 == 0:
                    # Save model every 5 epochs
                    #if not os.path.exists("models/" + version):
                    #    os.makedirs("models/" + version)
                    save_path = saver.save(sess, "./models/model.ckpt")
                    print("Model saved")

                for batch_images in get_batches(batch_size):
                    # Random noise
                    batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                    i += 1

                    # Run optimizers
                    _ = sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: learning_rate_D})
                    _ = sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: learning_rate_G})

                    if i % 10 == 0:
                        train_loss_d = d_loss.eval({input_z: batch_z, input_images: batch_images})
                        train_loss_g = g_loss.eval({input_z: batch_z})

                        # Save losses to view after training
                        losses.append((train_loss_d, train_loss_g))

                        # Save it
                        image_name = str(i) + ".jpg"
                        image_path = "./images/" + image_name
                        show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)

                    # Print every 1500 steps (for stability, otherwise the jupyter notebook will bug)
                    if i % 1500 == 0:
                        image_name = str(i) + ".jpg"
                        image_path = "./images/" + image_name
                        print("Epoch {}/{}...".format(epoch_i+1, epochs),
                              "Discriminator Loss: {:.4f}...".format(train_loss_d),
                              "Generator Loss: {:.4f}".format(train_loss_g))
                        show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, False, True)

    # return losses, samples
    return losses

# Size input image for discriminator
real_size = (128,128,3)

# Size of latent vector to generator
z_dim = 100
learning_rate_D = .00005  # Thanks to Alexia Jolicoeur Martineau https://ajolicoeur.wordpress.com/cats/
learning_rate_G = 2e-4  # Thanks to Alexia Jolicoeur Martineau https://ajolicoeur.wordpress.com/cats/
batch_size = 64
epochs = 215
alpha = 0.2
beta1 = 0.5

# Create the network
#model = DGAN(real_size, z_size, learning_rate, alpha, beta1)

# Load the data and train the network here
dataset = helper.Dataset(glob(os.path.join(data_resized_dir, '*.jpg')))

# with tf.Graph().as_default():
#     losses, samples = train(epochs, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, dataset.get_batches,
#                             dataset.shape, dataset.image_mode, alpha)

with tf.Graph().as_default():
    losses = train(epochs, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1, dataset.get_batches,
                   dataset.shape, dataset.image_mode, alpha)

fig, ax = plt.subplots()
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator', alpha=0.5)
plt.plot(losses.T[1], label='Generator', alpha=0.5)
plt.title("Training Losses")
plt.legend()
plt.show()

Afterword

The project itself has no particular highlights, but it was good fun. Since my undergraduate thesis will involve this kind of technique, I will next spend more time on PyTorch applications and on the underlying theory. I also need to think carefully about how to build something more novel and more practically useful on top of this; otherwise it all stays rather plain.
