import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim # 因为slim.nets包在 tensorflow 1.3 中有一些问题,所以这里为了方便
# 我们将slim.nets.inception_v3中的代码拷贝到了同一个文件夹下。
# import inception_v3 # 加载通过TensorFlow-Slim定义好的inception_v3模型。
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3 # 处理好之后的数据文件。
INPUT_DATA = 'E\\flower_processed_data\\flower_processed_data.npy'
# 保存训练好的模型的路径。这里我们可以将使用新数据训练得到的完整模型保存
# 下来,如果计算资源充足,我们还可以在训练完最后的全联接层之后再训练所有
# 网络层,这样可以使得新模型更加贴近新数据。
TRAIN_FILE = 'E\\train_dir\\model'
# 谷歌提供的训练好的模型文件地址。
CKPT_FILE = 'E:\\inception_v3\\inception_v3.ckpt' # 定义训练中使用的参数。
LEARNING_RATE = 0.01
STEPS = 5000
BATCH = 128
N_CLASSES = 5 # 不需要从谷歌训练好的模型中加载的参数。这里就是最后的全联接层,因为在
# 新的问题中我们要重新训练这一层中的参数。这里给出的是参数的前缀。
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# 需要训练的网络层参数明层,在fine-tuning的过程中就是最后的全联接层。
# 这里给出的是参数的前缀。
TRAINABLE_SCOPES='InceptionV3/Logits' # 获取所有需要从谷歌训练好的模型中加载的参数。
def get_tuned_variables():
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')] variables_to_restore = []
# 枚举inception-v3模型中所有的参数,然后判断是否需要从加载列表中
# 移除。
for var in slim.get_model_variables():
print var.op.name
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore # 获取所有需要训练的变量列表。
def get_trainable_variables():
scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
variables_to_train = []
# 枚举所有需要训练的参数前缀,并通过这些前缀找到所有的参数。
for scope in scopes:
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
variables_to_train.extend(variables)
return variables_to_train def main():
# 加载预处理好的数据。
processed_data = np.load(INPUT_DATA)
training_images = processed_data[0]
n_training_example = len(training_images)
training_labels = processed_data[1]
validation_images = processed_data[2]
validation_labels = processed_data[3]
testing_images = processed_data[4]
testing_labels = processed_data[5] # 定义inception-v3的输入,images为输入图片,labels为每一张图片
# 对应的标签。
images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input_images')
labels = tf.placeholder(tf.int64, [None], name='labels') # 定义inception-v3模型。因为谷歌给出的只有模型参数取值,所以这里
# 需要在这个代码中定义inception-v3的模型结构。因为模型
# 中使用到了dropout,所以需要定一个训练时使用的模型,一个测试时
# 使用的模型。
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
train_logits, _ = inception_v3.inception_v3(
images, num_classes=N_CLASSES, is_training=True)
# 定义测试使用的模型时需要将reuse设置为True。
test_logits, _ = inception_v3.inception_v3(
images, num_classes=N_CLASSES, is_training=False, reuse=True) trainable_variables = get_trainable_variables()
print(trainable_variables) cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=train_logits, labels=tf.one_hot(labels, N_CLASSES))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean,var_list=trainable_variables) # 计算正确率。
with tf.name_scope('evaluation'):
correct_prediction = tf.equal(tf.argmax(test_logits, 1), labels)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
loader = tf.train.Saver(get_tuned_variables())
saver = tf.train.Saver()
with tf.variable_scope("InceptionV3", reuse = True):
check1 = tf.get_variable("Conv2d_1a_3x3/weights")
check2 = tf.get_variable("Logits/Conv2d_1c_1x1/weights") with tf.Session() as sess:
# 初始化没有加载进来的变量。
init = tf.global_variables_initializer()
sess.run(init)
print sess.run(check1)
print sess.run(check2) # 加载谷歌已经训练好的模型。
print('Loading tuned variables from %s' % CKPT_FILE)
loader.restore(sess, CKPT_FILE)
start = 0
end = BATCH
for i in range(STEPS):
print sess.run(check1)
print sess.run(check2)
_, loss = sess.run([train_step, cross_entropy_mean], feed_dict={images: training_images[start:end], labels: training_labels[start:end]})
if i % 100 == 0 or i + 1 == STEPS:
saver.save(sess, TRAIN_FILE, global_step=i)
validation_accuracy = sess.run(evaluation_step, feed_dict={
images: validation_images, labels: validation_labels})
print('Step %d: Training loss is %.1f%% Validation accuracy = %.1f%%' % (i, loss * 100.0, validation_accuracy * 100.0))
start = end
if start == n_training_example:
start = 0
end = start + BATCH
if end > n_training_example:
end = n_training_example # 在最后的测试数据上测试正确率。
test_accuracy = sess.run(evaluation_step, feed_dict={images: test_images, labels: test_labels})
print('Final test accuracy = %.1f%%' % (test_accuracy * 100)) if __name__ == '__main__':
main()

吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  2. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  3. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(1)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  4. 吴裕雄 python 神经网络——TensorFlow 花瓣识别2

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  5. 吴裕雄 python 神经网络——TensorFlow训练神经网络:花瓣识别

    import os import glob import os.path import numpy as np import tensorflow as tf from tensorflow.pyth ...

  6. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  7. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  8. 吴裕雄 PYTHON 神经网络——TENSORFLOW 无监督学习处理MNIST手写数字数据集

    # 导入模块 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 加载数据 from tensor ...

  9. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

随机推荐

  1. [lua]紫猫lua教程-命令宝典-L1-01-12. 临时补充2

    1.lua的环境变量和函数 (1)_G表  (个人习惯遍历下_G 看看当前环境支持什么库 很多库不是默认就支持的 需要按照流程导入或者加载) 一个全局变量(非函数),内部储存有当前所有的全局函数和全局 ...

  2. C语言 fgets

    C语言 fgets #include <stdio.h> char *fgets(char *s, int size, FILE *stream); 功能:从stream指定的文件内读入字 ...

  3. 【资源分享】Gmod动态方框透视脚本

    *----------------------------------------------[下载区]----------------------------------------------* ...

  4. AC3 mantissa quantization and decoding

    1.overview 所有的mantissa被quantize到固定精确度的level(有相应的bap标识)上,level小于等于15时,使用symmetric quantization.level大 ...

  5. 记录集导出到Excel方法

    记录集导出到Excel方法   Public Function ExportToExcel(RSrecord As ADODB.Recordset, Titles_Name)'============ ...

  6. Linux下制作和使用静态库和动态库

    概述 Linux操作系统支持的函数库分为静态库和动态库,动态库又称共享库.linux系统有几个重要的目录存放相应的函数库,如/lib /usr/lib. 静态函数库: 这类库的名字一般是libxxx. ...

  7. 线索二叉树的详细实现(C++)

    线索二叉树概述 二叉树虽然是非线性结构,但二叉树的遍历却为二又树的结点集导出了一个线性序列.希望很快找到某一结点的前驱或后继,但不希望每次都要对二叉树遍历一遍,这就需要把每个结点的前驱和后继信息记录下 ...

  8. docker里面安装sqlserver2017

    首先要注意,docker一般不做数据持久化容器.如果非要安装可以参考微软官方教程: https://docs.microsoft.com/zh-cn/sql/linux/quickstart-inst ...

  9. Git 把码云上被fork项目源码merge到fork出来的分支项目

    Git 把码云上被fork项目源码merge到fork出来的分支项目 By:授客 QQ:1033553122 需求描述 被fork的项目有更新代码,希望把更新的代码merge到fork分支项目 解决方 ...

  10. VMware 搭建linux虚拟机环境

    1.任务管理器-服务 确认VMware服务是否启动 2.VMware生成网关地址 编辑--虚拟网络编辑器 VMnet8 NAT设置子网IP,子网掩码,网关 3.windows网络--更改适配器设置-- ...