TensorFlow(十八):从零开始训练图片分类模型
(一):进入GitHub下载模型--》下载地址
因为我们需要slim模块,所以将包中的slim文件夹复制出来使用。
(1):在slim中新建images文件夹存放图片集
(2):新建model文件夹用来放模型
(3):在datasets文件夹中新建myimages.py文件
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset. The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
""" from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import os
import tensorflow as tf from datasets import dataset_utils slim = tf.contrib.slim _FILE_PATTERN = 'image_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 3500, 'test': 500} # 这里根据自己的训练集内容进行修改 _NUM_CLASSES = 5 _ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
} def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""Gets a dataset tuple with instructions for reading flowers. Args:
split_name: A train/validation split name.
dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type. Returns:
A `Dataset` namedtuple. Raises:
ValueError: if `split_name` is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name) if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
} items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
} decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers) labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir) return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
myimages.py
(4):修改dataset_factory.py
from datasets import myimages
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'myimages':myimages, # 这一句为添加的内容
}
添加的内容
(二):对图片进行处理,生成tfrecord格式的文件。
import tensorflow as tf
import os
import random
import math
import sys #验证集数量
_NUM_TEST = 500
#随机种子
_RANDOM_SEED = 0
#数据块数目
_NUM_SHARDS = 5
#数据集路径
DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
#标签文件名字
LABELS_FILENAME = ''.join([DATASET_DIR,'labels.txt']) #定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename) #判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train', 'test']:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True #获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):
#数据目录
directories = []
#分类名称
class_names = []
for filename in os.listdir(dataset_dir):
#合并文件路径
path = os.path.join(dataset_dir, filename)
#判断该路径是否为目录
if os.path.isdir(path):
#加入数据目录
directories.append(path)
#加入类别名称
class_names.append(filename) photo_filenames = []
#循环每个分类的文件夹
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
#把图片加入图片列表
photo_filenames.append(path) return photo_filenames, class_names def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def image_to_tfexample(image_data, image_format, class_id):
#Abstract base class for protocol messages.
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
})) def write_label_file(labels_to_class_names, dataset_dir,filename=LABELS_FILENAME):
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
f.write('%d:%s\n' % (label, class_name)) #把数据转为TFRecord格式
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
assert split_name in ['train', 'test']
#计算每个数据块有多少数据
num_per_shard = int(len(filenames) / _NUM_SHARDS)
with tf.Graph().as_default():
with tf.Session() as sess:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
#每一个数据块开始的位置
start_ndx = shard_id * num_per_shard
#每一个数据块最后的位置
end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
for i in range(start_ndx, end_ndx):
try:
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))
sys.stdout.flush()
#读取图片
image_data = tf.gfile.FastGFile(filenames[i], 'rb').read() # 这里一定要rb否则会出现编码错误
#获得图片的类别名称
class_name = os.path.basename(os.path.dirname(filenames[i]))
#找到类别名称对应的id
class_id = class_names_to_ids[class_name]
#生成tfrecord文件
example = image_to_tfexample(image_data, b'jpg', class_id)
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print("Could not read:",filenames[i])
print("Error:",e)
print("Skip it\n") sys.stdout.write('\n')
sys.stdout.flush() if __name__ == '__main__':
#判断tfrecord文件是否存在
if _dataset_exists(DATASET_DIR):
print('tfcecord文件已存在')
else:
#获得所有图片以及分类
photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
#把分类转为字典格式,类似于{'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}
class_names_to_ids = dict(zip(class_names, range(len(class_names)))) #把数据切分为训练集和测试集
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[_NUM_TEST:]
testing_filenames = photo_filenames[:_NUM_TEST] #数据转换
_convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
_convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR) #输出labels文件
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
write_label_file(labels_to_class_names, DATASET_DIR)
生成tfrecord
(三):新建批处理文件,开始训练模型
python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
--train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause 注释:
第一行表示运行训练文件,路径为全路径
第二行表示模型存放位置
第三行为创建的myimages文件名
第四行为使用的训练集
第五行为数据集所在的位置
第六行为批次大小,默认为32,看个人GPU,我用10
第七行为训练次数,默认无限次
第八行为使用模型名称
批处理文件
TensorFlow(十八):从零开始训练图片分类模型的更多相关文章
- TensorFlow(十七):训练自己的图片分类模型
(一)下载inception-v3--见TensorFlow(十四) (二)准备训练用的图片集,因为我没有图片集,所以写了个自动抓取百度图片的脚本-见抓取百度图片 (三)创建retrain.py文件, ...
- 使用tensorflow的retrain.py训练图片分类器
参考 https://hackernoon.com/creating-insanely-fast-image-classifiers-with-mobilenet-in-tensorflow-f030 ...
- 用C++调用tensorflow在python下训练好的模型(centos7)
本文主要参考博客https://blog.csdn.net/luoyexuge/article/details/80399265 [1] bazel安装参考:https://blog.csdn.net ...
- NLP(十八)利用ALBERT提升模型预测速度的一次尝试
前沿 在文章NLP(十七)利用tensorflow-serving部署kashgari模型中,笔者介绍了如何利用tensorflow-serving部署来部署深度模型模型,在那篇文章中,笔者利用k ...
- PyTorch ImageNet 基于预训练六大常用图片分类模型的实战
微调 Torchvision 模型 在本教程中,我们将深入探讨如何对 torchvision 模型进行微调和特征提取,所有这些模型都已经预先在1000类的Imagenet数据集上训练完成.本教程将深入 ...
- 用Pytorch训练MNIST分类模型
本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...
- Tensorflow 使用slim框架下的分类模型进行分类
Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object ...
- 【emWin】例程十八:jpeg图片显示
说明:1.将文件拷入SD卡内即可在指定位置绘制jpeg图片文件,不必加载到储存器. 由于jpeg格式文件显示时需要进行解压缩,耗用动态内存,iCore3所有模块受emwin缓存的限制,jpeg ...
- 源码分析——迁移学习Inception V3网络重训练实现图片分类
1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...
随机推荐
- STL中的函数对象实现负数的定义
// // main.cpp // STL中的函数对象 // // Created by mac on 2019/5/2. // Copyright © 2019年 mac. All rights r ...
- zabbix 数据库分表操作
近期zabbix数据库占用的io高,在页面查看图形很慢,而且数据表已经很大,将采用把数据库的数据目录移到新的磁盘,将几个大表进行分表操作 一.数据迁移: 1.数据同步到新的磁盘上,先停止mysql(不 ...
- python练习:面向对象1
面向对象习题: 一:定义一个学生类.有下面的类属性: 1 姓名 2 年龄 3 成绩(语文,数学,英语)[每课成绩的类型为整数] 类方法: 1 获取学生的姓名:get_name() 返回类型:str 2 ...
- PowerShell将运行结果保存为文件
1. Out-File 示例: get-process | Out-File -filepath a.txt 跟“>”是一样的效果 输出 为普通文本 2. Export-Clixml 示例: g ...
- OpenCvSharp 图像旋转
/// <summary> /// 图像旋转 /// </summary> private Mat MatRotate(Mat src, float angle) { Mat ...
- 日常hive遇到的问题
1 hive中的复杂数据类型数据如何导入(array) 创建hive表 create table temp.dws_search_by_program_set_count_his( program_s ...
- 命令行参数 && json 协议 && 自定义 error 类型
命令行参数 在写代码的时候,在运行程序做一些初始化操作的时候,往往会通过命令行传参数到程序中,那么就会用到命令行参数 例如,指定程序运行的模式和级别: go run HTTPServer.go --m ...
- 【转载】 C#中日期类型DateTime的日期加减操作
在C#开发过程中,DateTime数据类型用于表示日期类型,可以通过DateTime.Now获取当前服务器时间,同时日期也可以像数字一样进行加减操作,如AddDay方法可以对日期进行加减几天的操作,A ...
- 【转】SpringCloud学习笔记(一)——基础
什么是微服务架构 简单地说,微服务是系统架构上的一种设计风格,它的主旨是将一个原本独立的系统拆分成多个小型服务,这些小型服务都在各自独立的进程中运行,服务之间通过基于HTTP的RESTful API进 ...
- 解决在web.xml中配置server服务器启动失败问题
一.问题"Server Tomacat v8.5 Server at locallhost failed to start" 二.解决方法:删除注释@webServlet 三.分析 ...