TensorFlow(十八):从零开始训练图片分类模型
(一):进入GitHub下载模型--》下载地址
因为我们需要slim模块,所以将包中的slim文件夹复制出来使用。
(1):在slim中新建images文件夹存放图片集
(2):新建model文件夹用来放模型
(3):在datasets文件夹中新建myimages.py文件
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset. The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
""" from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import os
import tensorflow as tf from datasets import dataset_utils slim = tf.contrib.slim _FILE_PATTERN = 'image_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 3500, 'test': 500} # 这里根据自己的训练集内容进行修改 _NUM_CLASSES = 5 _ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
} def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""Gets a dataset tuple with instructions for reading flowers. Args:
split_name: A train/validation split name.
dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type. Returns:
A `Dataset` namedtuple. Raises:
ValueError: if `split_name` is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name) if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
} items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
} decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers) labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir) return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
myimages.py
(4):修改dataset_factory.py
from datasets import myimages
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'myimages':myimages, # 这一句为添加的内容
}
添加的内容
(二):对图片进行处理,生成tfrecord格式的文件。
import tensorflow as tf
import os
import random
import math
import sys #验证集数量
_NUM_TEST = 500
#随机种子
_RANDOM_SEED = 0
#数据块数目
_NUM_SHARDS = 5
#数据集路径
DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
#标签文件名字
LABELS_FILENAME = ''.join([DATASET_DIR,'labels.txt']) #定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename) #判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train', 'test']:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True #获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):
#数据目录
directories = []
#分类名称
class_names = []
for filename in os.listdir(dataset_dir):
#合并文件路径
path = os.path.join(dataset_dir, filename)
#判断该路径是否为目录
if os.path.isdir(path):
#加入数据目录
directories.append(path)
#加入类别名称
class_names.append(filename) photo_filenames = []
#循环每个分类的文件夹
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
#把图片加入图片列表
photo_filenames.append(path) return photo_filenames, class_names def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def image_to_tfexample(image_data, image_format, class_id):
#Abstract base class for protocol messages.
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
})) def write_label_file(labels_to_class_names, dataset_dir,filename=LABELS_FILENAME):
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
f.write('%d:%s\n' % (label, class_name)) #把数据转为TFRecord格式
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
assert split_name in ['train', 'test']
#计算每个数据块有多少数据
num_per_shard = int(len(filenames) / _NUM_SHARDS)
with tf.Graph().as_default():
with tf.Session() as sess:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
#每一个数据块开始的位置
start_ndx = shard_id * num_per_shard
#每一个数据块最后的位置
end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
for i in range(start_ndx, end_ndx):
try:
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))
sys.stdout.flush()
#读取图片
image_data = tf.gfile.FastGFile(filenames[i], 'rb').read() # 这里一定要rb否则会出现编码错误
#获得图片的类别名称
class_name = os.path.basename(os.path.dirname(filenames[i]))
#找到类别名称对应的id
class_id = class_names_to_ids[class_name]
#生成tfrecord文件
example = image_to_tfexample(image_data, b'jpg', class_id)
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print("Could not read:",filenames[i])
print("Error:",e)
print("Skip it\n") sys.stdout.write('\n')
sys.stdout.flush() if __name__ == '__main__':
#判断tfrecord文件是否存在
if _dataset_exists(DATASET_DIR):
print('tfcecord文件已存在')
else:
#获得所有图片以及分类
photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
#把分类转为字典格式,类似于{'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}
class_names_to_ids = dict(zip(class_names, range(len(class_names)))) #把数据切分为训练集和测试集
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[_NUM_TEST:]
testing_filenames = photo_filenames[:_NUM_TEST] #数据转换
_convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
_convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR) #输出labels文件
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
write_label_file(labels_to_class_names, DATASET_DIR)
生成tfrecord
(三):新建批处理文件,开始训练模型
python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
--train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause 注释:
第一行表示运行训练文件,路径为全路径
第二行表示模型存放位置
第三行为创建的myimages文件名
第四行为使用的训练集
第五行为数据集所在的位置
第六行为批次大小,默认为32,看个人GPU,我用10
第七行为训练次数,默认无限次
第八行为使用模型名称
批处理文件
TensorFlow(十八):从零开始训练图片分类模型的更多相关文章
- TensorFlow(十七):训练自己的图片分类模型
(一)下载inception-v3--见TensorFlow(十四) (二)准备训练用的图片集,因为我没有图片集,所以写了个自动抓取百度图片的脚本-见抓取百度图片 (三)创建retrain.py文件, ...
- 使用tensorflow的retrain.py训练图片分类器
参考 https://hackernoon.com/creating-insanely-fast-image-classifiers-with-mobilenet-in-tensorflow-f030 ...
- 用C++调用tensorflow在python下训练好的模型(centos7)
本文主要参考博客https://blog.csdn.net/luoyexuge/article/details/80399265 [1] bazel安装参考:https://blog.csdn.net ...
- NLP(十八)利用ALBERT提升模型预测速度的一次尝试
前沿 在文章NLP(十七)利用tensorflow-serving部署kashgari模型中,笔者介绍了如何利用tensorflow-serving部署来部署深度模型模型,在那篇文章中,笔者利用k ...
- PyTorch ImageNet 基于预训练六大常用图片分类模型的实战
微调 Torchvision 模型 在本教程中,我们将深入探讨如何对 torchvision 模型进行微调和特征提取,所有这些模型都已经预先在1000类的Imagenet数据集上训练完成.本教程将深入 ...
- 用Pytorch训练MNIST分类模型
本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...
- Tensorflow 使用slim框架下的分类模型进行分类
Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object ...
- 【emWin】例程十八:jpeg图片显示
说明:1.将文件拷入SD卡内即可在指定位置绘制jpeg图片文件,不必加载到储存器. 由于jpeg格式文件显示时需要进行解压缩,耗用动态内存,iCore3所有模块受emwin缓存的限制,jpeg ...
- 源码分析——迁移学习Inception V3网络重训练实现图片分类
1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...
随机推荐
- Manthan, Codefest 19 (open for everyone, rated, Div. 1 + Div. 2) (1208F,1208G,1208H)
1208 F 大意: 给定序列$a$, 求$\text{$a_i$|$a_j$&$a_k$}(i<j<k)$的最大值 枚举$i$, 从高位到低位贪心, 那么问题就转化为给定$x$ ...
- JNI创建共享内存导致JVM terminated的问题解决(segfault,shared memory,内存越界,内存泄漏,共享内存)
此问题研究了将近一个月,最终发现由于JNI不支持C中创建共享内存而导致虚拟机无法识别这块共享内存,造成内存冲突,最终虚拟机崩溃. 注意:JNI的C部分所使用的内存也是由JVM创建并管理的,所以C创建了 ...
- Fortify漏洞之Open Redirect(开放式重定向)
继续对Fortify的漏洞进行总结,本篇主要针对 Open Redirect(开放式重定向) 的漏洞进行总结,如下: 1.1.产生原因: 通过重定向,Web 应用程序能够引导用户访问同一应用程序内 ...
- sql server 2012 分页/dapper/C#拼sql/免储存过程/简易
sql server 2012新特性, 支持 OFFSET/FETCH分页, 就像mysql的limit, 比之前的各种top舒服多了, 看各位大佬们的测评文章说效率也是不相上下的, 有时候写个小工 ...
- 制作win10系统及安装win10系统
制作win10系统 1.登陆msdn,下载win10系统,打开迅雷下载器,复制完该段代码,直接开始下载,网址:https://msdn.itellyou.cn/ 2.下载软碟通,下载网址:https: ...
- 从零开始搭建vue移动端项目到上线
先来看一波效果图 初始化项目 1.在安装了node.js的前提下,使用以下命令 npm install --g vue-cli 2.在将要构建项目的目录下 vue init webpack mypro ...
- Docker Compose编排工具部署lnmp实践及理论(详细)
目录 一.理论概述 编排 部署 Compose原理 二.使用docker compose 部署lnmp 三.测试 四.总结 一.理论概述 Docker Compose是一个定义及运行多个Docker容 ...
- mysql(函数,存储过程,事务,索引)
函数 MySQL中提供了许多内置函数: 内置函数 一.数学函数 ROUND(x,y) 返回参数x的四舍五入的有y位小数的值 RAND() 返回0到1内的随机值,可以通过提供一个参数(种子)使RAND( ...
- TLS1.3 PPT 整理
1.握手协议的目的是什么 建立共享秘钥(通常使用公钥加密).协商算法和模型以及加密使用的参数,验证身份. 2.记录协议 传输独立的信息,在堆成加密算法下保护数据传输 3.RSA Handshake S ...
- Vue 前后端分离系统中遇到跨域问题
https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Access_control_CORS I Your application is running ...