import os
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile # 原始输入数据的目录,这个目录下有5个子目录,每个子目录底下保存这属于该
# 类别的所有图片。
INPUT_DATA = 'F:\\TensorFlowGoogle\\201806-github\\datasets\\flower_photos'
# 输出文件地址。我们将整理后的图片数据通过numpy的格式保存。
OUTPUT_FILE = 'F:\\flower_processed_data.npy' # 测试数据和验证数据比例。
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10 # 读取数据并将数据分割成训练数据、验证数据和测试数据。
def create_image_lists(sess, testing_percentage, validation_percentage):
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
is_root_dir = True # 初始化各个数据集。
training_images = []
training_labels = []
testing_images = []
testing_labels = []
validation_images = []
validation_labels = []
current_label = 0 # 读取所有的子目录。
for sub_dir in sub_dirs:
if is_root_dir:
is_root_dir = False
continue
# 获取一个子目录中所有的图片文件。
extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
file_list = []
dir_name = os.path.basename(sub_dir)
for extension in extensions:
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
file_list.extend(glob.glob(file_glob))
if not file_list: continue
print("processing:", dir_name) i = 0
# 处理图片数据。
for file_name in file_list:
i += 1
# 读取并解析图片,将图片转化为299*299以方便inception-v3模型来处理。
image_raw_data = gfile.FastGFile(file_name, 'rb').read()
image = tf.image.decode_jpeg(image_raw_data)
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.resize_images(image, [299, 299])
image_value = sess.run(image) # 随机划分数据聚。
chance = np.random.randint(100)
if chance < validation_percentage:
validation_images.append(image_value)
validation_labels.append(current_label)
elif chance < (testing_percentage + validation_percentage):
testing_images.append(image_value)
testing_labels.append(current_label)
else:
training_images.append(image_value)
training_labels.append(current_label)
if i % 200 == 0:
print(i, "images processed.")
current_label += 1 # 将训练数据随机打乱以获得更好的训练效果。
state = np.random.get_state()
np.random.shuffle(training_images)
np.random.set_state(state)
np.random.shuffle(training_labels)
return np.asarray([training_images, training_labels,validation_images, validation_labels,testing_images, testing_labels]) with tf.Session() as sess:
processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
# 通过numpy格式保存处理后的数据。
np.save(OUTPUT_FILE, processed_data)

吴裕雄 python 神经网络——TensorFlow训练神经网络:花瓣识别的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用滑动平均

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  2. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用隐藏层

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  3. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用激活函数

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  4. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用指数衰减的学习率

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  5. 吴裕雄 python 神经网络——TensorFlow训练神经网络:不使用正则化

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  6. 吴裕雄 python 神经网络——TensorFlow训练神经网络:全模型

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  7. 吴裕雄 python 神经网络——TensorFlow训练神经网络:MNIST最佳实践

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...

  8. 吴裕雄 python 神经网络——TensorFlow训练神经网络:卷积层、池化层样例

    import numpy as np import tensorflow as tf M = np.array([ [[1],[-1],[0]], [[-1],[2],[1]], [[0],[2],[ ...

  9. 吴裕雄--天生自然 Tensorflow卷积神经网络:花朵图片识别

    import os import numpy as np import matplotlib.pyplot as plt from PIL import Image, ImageChops from ...

随机推荐

  1. D. Lunar New Year and a Wander bfs+优先队列

    D. Lunar New Year and a Wander bfs+优先队列 题意 给出一个图,从1点开始走,每个点至少要经过一次(可以很多次),每次经过一个没有走过的点就把他加到走过点序列中,问最 ...

  2. 重启nginx:端口被占用问题

    1.重启nginx出现端口占用问题: nginx: [emerg] bind() to 0.0.0.0:80 failed (98: Address already in use) 2.解决方法: 第 ...

  3. hadoop学习笔记(九):mapReduce1.x和2.x

    一.MapReduce1.0的数据分割到数据计算的过程 MapReduce是我们再进行离线大数据处理的时候经常要使用的计算模型,MapReduce的计算过程被封装的很好,我们只用使用Map和Reduc ...

  4. python-excel读取-pyodbc

    https://github.com/mkleehammer/pyodbc/wiki/Cursor 利用pyodbc读取数据库,流程基本一样,就是配置connect对象时有所不同,下面是excel的: ...

  5. eclipse中引入聚合工程

    一般我们在导入项目的时候都是直接import project, 这对普通java 项目,还是 web 项目,或者是单体的项目都是没有问题的,但是在导入聚合项目的时候这样倒入会使maven的子模块没法被 ...

  6. nomon+ pyNmonAnalyzer实现基于python的nmon监控性能数据可视化

    pip install pyNmonAnalyzer nnmon  for linux from sourceforge:https://sourceforge.net/projects/nmon/ ...

  7. ubuntu 切换用户

    app切换root ubuntu: sudo su - app sudo su - root centos : sudo su ############ root 切换app sudo su - ap ...

  8. php截取富文本框中的固定长度的字符

    ai,哎怎么赶脚自己写东西越来越小儿科了呢,现在连这个问题都找了好半天 因为后台是的内容是富文本编辑器编辑的,前台我傻逼的直接截取了字符串,然后样式啥的都乱了,找了半天是因为富文本的问题 其实解决办法 ...

  9. cookie、session以及中间件

    cookie cookie是保存客户端浏览器上的键值对,是服务端设置在客户端浏览器上的键值对,也就意味着浏览器其实可以拒绝服务端的'命令',默认情况下浏览器都是直接让服务端设置键值对 设置cookie ...

  10. 吴裕雄 python 机器学习——KNN回归KNeighborsRegressor模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import neighbors, datasets from skle ...