import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile # 原始输入数据的目录,这个目录下有5个子目录,每个子目录底下保存这属于该
# 类别的所有图片。
INPUT_DATA = 'F:\\TensorFlowGoogle\\201806-github\\datasets\\flower_photos\\'
# 输出文件地址。我们将整理后的图片数据通过numpy的格式保存。
OUTPUT_FILE = 'F:\\shuju\\flower_processed_data.npy' # 测试数据和验证数据比例。
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10 # 读取数据并将数据分割成训练数据、验证数据和测试数据。
def create_image_lists(sess, testing_percentage, validation_percentage):
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
is_root_dir = True
# 初始化各个数据集。
training_images = []
training_labels = []
testing_images = []
testing_labels = []
validation_images = []
validation_labels = []
current_label = 0 # 读取所有的子目录。
for sub_dir in sub_dirs:
if is_root_dir:
is_root_dir = False
continue
# 获取一个子目录中所有的图片文件。
extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
file_list = []
dir_name = os.path.basename(sub_dir)
for extension in extensions:
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
file_list.extend(glob.glob(file_glob))
if not file_list:
continue
print("processing:", dir_name)
i = 0
# 处理图片数据。
for file_name in file_list:
i += 1
# 读取并解析图片,将图片转化为299*299以方便inception-v3模型来处理。
image_raw_data = gfile.FastGFile(file_name, 'rb').read()
image = tf.image.decode_jpeg(image_raw_data)
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.resize_images(image, [299, 299])
image_value = sess.run(image)
# 随机划分数据聚。
chance = np.random.randint(100)
if chance < validation_percentage:
validation_images.append(image_value)
validation_labels.append(current_label)
elif chance < (testing_percentage + validation_percentage):
testing_images.append(image_value)
testing_labels.append(current_label)
else:
training_images.append(image_value)
training_labels.append(current_label)
if i % 200 == 0:
print(i, "images processed.")
current_label += 1
# 将训练数据随机打乱以获得更好的训练效果。
state = np.random.get_state()
np.random.shuffle(training_images)
np.random.set_state(state)
np.random.shuffle(training_labels) return np.asarray([training_images, training_labels,validation_images, validation_labels,testing_images, testing_labels]) with tf.Session() as sess:
processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
# 通过numpy格式保存处理后的数据。
np.save(OUTPUT_FILE, processed_data)

吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(1)的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  2. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  3. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  4. 吴裕雄 python 神经网络——TensorFlow 花瓣识别2

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  5. 吴裕雄 python 神经网络——TensorFlow训练神经网络:花瓣识别

    import os import glob import os.path import numpy as np import tensorflow as tf from tensorflow.pyth ...

  6. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  7. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  8. 吴裕雄 PYTHON 神经网络——TENSORFLOW 无监督学习处理MNIST手写数字数据集

    # 导入模块 import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 加载数据 from tensor ...

  9. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

随机推荐

  1. win10无法运行Vmware,怎么办

    在win10上安装了Vmware12.15都无法运行,怎么办,找了一晚上 1.点击“检查更新”或去官网下载最新版本 Vmware15.5.0,或者百度云下载 链接:https://pan.baidu. ...

  2. 题解【洛谷P1995】口袋的天空

    题面 题解 从图中删边,直到图中只剩\(k\)条边,计算权值之和即可. 代码 #include <iostream> #include <cstdio> #include &l ...

  3. C++ - cpprestsdk

    Windows 安装方法: CMake 1.32+,生成过程会将 vcpkg 下载好,配置到系统环境变量,然后用 vcpkg 安装依赖库(github 上有列出需要的依赖库). Github 上的示例 ...

  4. Codeforces 1204D2. Kirk and a Binary String (hard version) (dp思路)

    题目链接:http://codeforces.com/contest/1204/problem/D2 题目是给定一个01字符串,让你尽可能多地改变1变为0,但是要保证新的字符串,对任意的L,R使得Sl ...

  5. ES5中, map 和 forEach的区别

    forEach和map区别在哪里知道吗? // forEach Array.prototype.forEach(callback(item, index, thisArr), thisArg) // ...

  6. springboot08(springmvc自动配置原理)

    MVC WebMvcAutoConfiguration.java @ConditionalOnMissingBean(name = "viewResolver", value = ...

  7. Linux实现树莓派3B的国密SM9算法交叉编译——(一)环境部署、简单测试与eclipse工程项目测试

    这篇文章主要介绍了交叉编译的实现,包括环境部署,并简单测试交叉编译环境是否安装成功. 一.交叉编译 在一个平台上生成另一个平台上的可执行代码.为什么要大费周折的进行交叉编译呢?一句话:不得已而为之.有 ...

  8. Jmeter注册100个账户的三个方法

    Jmeter注册账户比如注册成千上万个账户,如何快速实现呢? 三种方法分别举例注册5个账户 1)添加CSV data config_txt 2)添加CSV data config_csv 3)函数助手 ...

  9. go基础_切片

    切片创建方式 1.通过数组创建 2.通过内置函数make创建 切片允许的操作 1.追加元素 2.通过内置函数make创建 package main import "fmt" fun ...

  10. Python变量理解

    变量进阶(理解) 01. 变量的引用 变量 和 数据 都是保存在 内存 中的 在 Python 中 函数 的 参数传递 以及 返回值 都是靠 引用 传递的 1.1 引用的概念 在 Python 中 变 ...