(一):进入GitHub下载模型--》下载地址

因为我们需要slim模块,所以将包中的slim文件夹复制出来使用。

(1):在slim中新建images文件夹存放图片集

(2):新建model文件夹用来放模型

(3):在datasets文件夹中新建myimages.py文件

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset. The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
""" from __future__ import absolute_import
from __future__ import division
from __future__ import print_function import os
import tensorflow as tf from datasets import dataset_utils slim = tf.contrib.slim _FILE_PATTERN = 'image_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 3500, 'test': 500} # 这里根据自己的训练集内容进行修改 _NUM_CLASSES = 5 _ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
} def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""Gets a dataset tuple with instructions for reading flowers. Args:
split_name: A train/validation split name.
dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type. Returns:
A `Dataset` namedtuple. Raises:
ValueError: if `split_name` is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name) if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
} items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
} decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers) labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir) return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)

myimages.py

(4):修改dataset_factory.py

from datasets import myimages

datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'myimages':myimages, # 这一句为添加的内容
}

添加的内容

(二):对图片进行处理,生成tfrecord格式的文件。

import tensorflow as tf
import os
import random
import math
import sys #验证集数量
_NUM_TEST = 500
#随机种子
_RANDOM_SEED = 0
#数据块数目
_NUM_SHARDS = 5
#数据集路径
DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
#标签文件名字
LABELS_FILENAME = ''.join([DATASET_DIR,'labels.txt']) #定义tfrecord文件的路径+名字
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename) #判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train', 'test']:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True #获取所有文件以及分类
def _get_filenames_and_classes(dataset_dir):
#数据目录
directories = []
#分类名称
class_names = []
for filename in os.listdir(dataset_dir):
#合并文件路径
path = os.path.join(dataset_dir, filename)
#判断该路径是否为目录
if os.path.isdir(path):
#加入数据目录
directories.append(path)
#加入类别名称
class_names.append(filename) photo_filenames = []
#循环每个分类的文件夹
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory, filename)
#把图片加入图片列表
photo_filenames.append(path) return photo_filenames, class_names def int64_feature(values):
if not isinstance(values, (tuple, list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) def image_to_tfexample(image_data, image_format, class_id):
#Abstract base class for protocol messages.
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id),
})) def write_label_file(labels_to_class_names, dataset_dir,filename=LABELS_FILENAME):
labels_filename = os.path.join(dataset_dir, filename)
with tf.gfile.Open(labels_filename, 'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
f.write('%d:%s\n' % (label, class_name)) #把数据转为TFRecord格式
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
assert split_name in ['train', 'test']
#计算每个数据块有多少数据
num_per_shard = int(len(filenames) / _NUM_SHARDS)
with tf.Graph().as_default():
with tf.Session() as sess:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord文件的路径+名字
output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
#每一个数据块开始的位置
start_ndx = shard_id * num_per_shard
#每一个数据块最后的位置
end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
for i in range(start_ndx, end_ndx):
try:
sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))
sys.stdout.flush()
#读取图片
image_data = tf.gfile.FastGFile(filenames[i], 'rb').read() # 这里一定要rb否则会出现编码错误
#获得图片的类别名称
class_name = os.path.basename(os.path.dirname(filenames[i]))
#找到类别名称对应的id
class_id = class_names_to_ids[class_name]
#生成tfrecord文件
example = image_to_tfexample(image_data, b'jpg', class_id)
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print("Could not read:",filenames[i])
print("Error:",e)
print("Skip it\n") sys.stdout.write('\n')
sys.stdout.flush() if __name__ == '__main__':
#判断tfrecord文件是否存在
if _dataset_exists(DATASET_DIR):
print('tfcecord文件已存在')
else:
#获得所有图片以及分类
photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
#把分类转为字典格式,类似于{'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}
class_names_to_ids = dict(zip(class_names, range(len(class_names)))) #把数据切分为训练集和测试集
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[_NUM_TEST:]
testing_filenames = photo_filenames[:_NUM_TEST] #数据转换
_convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
_convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR) #输出labels文件
labels_to_class_names = dict(zip(range(len(class_names)), class_names))
write_label_file(labels_to_class_names, DATASET_DIR)

生成tfrecord

(三):新建批处理文件,开始训练模型

python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
--train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause 注释:
第一行表示运行训练文件,路径为全路径
第二行表示模型存放位置
第三行为创建的myimages文件名
第四行为使用的训练集
第五行为数据集所在的位置
第六行为批次大小,默认为32,看个人GPU,我用10
第七行为训练次数,默认无限次
第八行为使用模型名称

批处理文件

TensorFlow(十八):从零开始训练图片分类模型的更多相关文章

  1. TensorFlow(十七):训练自己的图片分类模型

    (一)下载inception-v3--见TensorFlow(十四) (二)准备训练用的图片集,因为我没有图片集,所以写了个自动抓取百度图片的脚本-见抓取百度图片 (三)创建retrain.py文件, ...

  2. 使用tensorflow的retrain.py训练图片分类器

    参考 https://hackernoon.com/creating-insanely-fast-image-classifiers-with-mobilenet-in-tensorflow-f030 ...

  3. 用C++调用tensorflow在python下训练好的模型(centos7)

    本文主要参考博客https://blog.csdn.net/luoyexuge/article/details/80399265 [1] bazel安装参考:https://blog.csdn.net ...

  4. NLP(十八)利用ALBERT提升模型预测速度的一次尝试

    前沿   在文章NLP(十七)利用tensorflow-serving部署kashgari模型中,笔者介绍了如何利用tensorflow-serving部署来部署深度模型模型,在那篇文章中,笔者利用k ...

  5. PyTorch ImageNet 基于预训练六大常用图片分类模型的实战

    微调 Torchvision 模型 在本教程中,我们将深入探讨如何对 torchvision 模型进行微调和特征提取,所有这些模型都已经预先在1000类的Imagenet数据集上训练完成.本教程将深入 ...

  6. 用Pytorch训练MNIST分类模型

    本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...

  7. Tensorflow 使用slim框架下的分类模型进行分类

    Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object ...

  8. 【emWin】例程十八:jpeg图片显示

    说明:1.将文件拷入SD卡内即可在指定位置绘制jpeg图片文件,不必加载到储存器.     由于jpeg格式文件显示时需要进行解压缩,耗用动态内存,iCore3所有模块受emwin缓存的限制,jpeg ...

  9. 源码分析——迁移学习Inception V3网络重训练实现图片分类

    1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...

随机推荐

  1. 细说浏览器输入URL后发生了什么

    本文摘要: 1.DNS域名解析: 2.建立TCP连接: 3.发送HTTP请求: 4.服务器处理请求: 5.返回响应结果: 6.关闭TCP连接: 7.浏览器解析HTML: 8.浏览器布局渲染: 总结   ...

  2. CentOS 7忘记了root密码解决方案

    1.启动系统,在选择进入系统的界面按“e”进入编辑页面 2.按向下键,找到以“Linux16”开头的行,在该行的最后面输入“init=/bin/sh”  3.按“ctrl+X”组合键进入单用户模式 4 ...

  3. 8. Java的运算符

    计算机的最基本用途之一就是执行数学运算,作为一门计算机语言,Java也提供了一套丰富的运算符来操纵变量. 我们把运算符具体分为:算数运算符,比较运算符,逻辑运算符,位运算符,赋值运算符,条件运算符,i ...

  4. 数据结构之链表(LinkedList)(一)

    链表(Linked List)介绍 链表是有序的列表,但是它在内存中是存储如下 1)链表是以节点方式存储的,是链式存储 2)每个节点包含data域(value),next域,指向下一个节点 3)各个节 ...

  5. 让image居中对齐,网页自适应

    <div class="page4_content"> <div class="page4_box"> <div class=&q ...

  6. nginx日志文件的配置

    文章来源 运维公会: nginx日志文件的配置 1.日志介绍 nginx有两种日志,一种是访问日志,一种是错误日志. 访问日志中记录的是客户端对服务器的所有请求. 错误日志中记录的是在访问过程中,因为 ...

  7. Python面向对象之多态、封装

    一.多态 超过一个子类继承父类,出现了多种的形态. 例如,动物种类出现了多种形态,比如猫.狗.猪 class Animal:pass class Cat(Animal):pass class Dog( ...

  8. 算法学习:我终于明白二分查找的时间复杂度为什么是O(logn)了

    最近发现了个好东西,就是一个学算法的好东西,是网易公开课的一个视频. 直通车 这是麻省理工学院的公开课,有中英字幕,感谢网易.. 也可以在App把视频缓存下来之后再放到电脑上面看,因为我这样可以倍速, ...

  9. 【DevOps】在Rancher2中启动Docker-Registry仓库服务

    准备 拥有Rancher2环境,已经在Rancher2配置Kubernetes集群 拥有域名,拥有SSL证书,可以自行在阿里云申请 启动Docker-Registry仓库服务 第一步:进入集群应用 第 ...

  10. 转载: Redis面试常问的问题

    https://www.cnblogs.com/javazhiyin/p/9842571.html 近,阿音在为接下来的一场面试做准备,其中的内容包括redis,而且redis是重点内容. Redis ...