利用卷积神经网络训练图像数据分为以下几个步骤

  1. 读取图片文件
  2. 产生用于训练的批次
  3. 定义训练的模型(包括初始化参数,卷积、池化层等参数、网络)
  4. 训练

1 读取图片文件

 def get_files(filename):
class_train = []
label_train = []
for train_class in os.listdir(filename):
for pic in os.listdir(filename+train_class):
class_train.append(filename+train_class+'/'+pic)
label_train.append(train_class)
temp = np.array([class_train,label_train])
temp = temp.transpose()
#shuffle the samples
np.random.shuffle(temp)
#after transpose, images is in dimension 0 and label in dimension 1
image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(i) for i in label_list]
#print(label_list)
return image_list,label_list

  这里文件名作为标签,即类别(其数据类型要确定,后面要转为tensor类型数据)。

  然后将image和label转为list格式数据,因为后边用到的的一些tensorflow函数接收的是list格式数据。

2 产生用于训练的批次

 def get_batches(image,label,resize_w,resize_h,batch_size,capacity):
#convert the list of images and labels to tensor
image = tf.cast(image,tf.string)
label = tf.cast(label,tf.int64)
queue = tf.train.slice_input_producer([image,label])
label = queue[1]
image_c = tf.read_file(queue[0])
image = tf.image.decode_jpeg(image_c,channels = 3)
#resize
image = tf.image.resize_image_with_crop_or_pad(image,resize_w,resize_h)
#(x - mean) / adjusted_stddev
image = tf.image.per_image_standardization(image) image_batch,label_batch = tf.train.batch([image,label],
batch_size = batch_size,
num_threads = 64,
capacity = capacity)
images_batch = tf.cast(image_batch,tf.float32)
labels_batch = tf.reshape(label_batch,[batch_size])
return images_batch,labels_batch

  首先使用tf.cast转化为tensorflow数据格式,使用tf.train.slice_input_producer实现一个输入的队列。

  label不需要处理,image存储的是路径,需要读取为图片,接下来的几步就是读取路径转为图片,用于训练。

  CNN对图像大小是敏感的,第10行图片resize处理为大小一致,12行将其标准化,即减去所有图片的均值,方便训练。

  接下来使用tf.train.batch函数产生训练的批次。

  最后将产生的批次做数据类型的转换和shape的处理即可产生用于训练的批次。

3 定义训练的模型

(1)训练参数的定义及初始化

 def init_weights(shape):
return tf.Variable(tf.random_normal(shape,stddev = 0.01))
#init weights
weights = {
"w1":init_weights([3,3,3,16]),
"w2":init_weights([3,3,16,128]),
"w3":init_weights([3,3,128,256]),
"w4":init_weights([4096,4096]),
"wo":init_weights([4096,2])
} #init biases
biases = {
"b1":init_weights([16]),
"b2":init_weights([128]),
"b3":init_weights([256]),
"b4":init_weights([4096]),
"bo":init_weights([2])
}

  CNN的每层是y=wx+b的决策模型,卷积层产生特征向量,根据这些特征向量带入x进行计算,因此,需要定义卷积层的初始化参数,包括权重和偏置。其中第8行的参数形状后边再解释。

(2)定义不同层的操作

 def conv2d(x,w,b):
x = tf.nn.conv2d(x,w,strides = [1,1,1,1],padding = "SAME")
x = tf.nn.bias_add(x,b)
return tf.nn.relu(x) def pooling(x):
return tf.nn.max_pool(x,ksize = [1,2,2,1],strides = [1,2,2,1],padding = "SAME") def norm(x,lsize = 4):
return tf.nn.lrn(x,depth_radius = lsize,bias = 1,alpha = 0.001/9.0,beta = 0.75)

  这里只定义了三种层,即卷积层、池化层和正则化层

(3)定义训练模型

 def mmodel(images):
l1 = conv2d(images,weights["w1"],biases["b1"])
l2 = pooling(l1)
l2 = norm(l2)
l3 = conv2d(l2,weights["w2"],biases["b2"])
l4 = pooling(l3)
l4 = norm(l4)
l5 = conv2d(l4,weights["w3"],biases["b3"])
#same as the batch size
l6 = pooling(l5)
l6 = tf.reshape(l6,[-1,weights["w4"].get_shape().as_list()[0]])
l7 = tf.nn.relu(tf.matmul(l6,weights["w4"])+biases["b4"])
soft_max = tf.add(tf.matmul(l7,weights["wo"]),biases["bo"])
return soft_max

  模型比较简单,使用三层卷积,第11行使用全连接,需要对特征向量进行reshape,其中l6的形状为[-1,w4的第1维的参数],因此,将其按照“w4”reshape的时候,要使得-1位置的大小为batch_size,这样,最终再乘以“wo”时,最终的输出大小为[batch_size,class_num]

(4)定义评估量

 def loss(logits,label_batches):
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=label_batches)
cost = tf.reduce_mean(cross_entropy)
return cost

  首先定义损失函数,这是用于训练最小化损失的必需量

 def get_accuracy(logits,labels):
acc = tf.nn.in_top_k(logits,labels,1)
acc = tf.cast(acc,tf.float32)
acc = tf.reduce_mean(acc)
return acc

  评价分类准确率的量,训练时,需要loss值减小,准确率增加,这样的训练才是收敛的。

(5)定义训练方式

 def training(loss,lr):
train_op = tf.train.RMSPropOptimizer(lr,0.9).minimize(loss)
return train_op

  有很多种训练方式,可以自行去官网查看,但是不同的训练方式可能对应前面的参数定义不一样,需要另行处理,否则可能报错。

4 训练

 def run_training():
data_dir = 'C:/Users/wk/Desktop/bky/dataSet/'
image,label = inputData.get_files(data_dir)
image_batches,label_batches = inputData.get_batches(image,label,32,32,16,20)
p = model.mmodel(image_batches)
cost = model.loss(p,label_batches)
train_op = model.training(cost,0.001)
acc = model.get_accuracy(p,label_batches) sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess,coord = coord) try:
for step in np.arange(1000):
print(step)
if coord.should_stop():
break
_,train_acc,train_loss = sess.run([train_op,acc,cost])
print("loss:{} accuracy:{}".format(train_loss,train_acc))
except tf.errors.OutOfRangeError:
print("Done!!!")
finally:
coord.request_stop()
coord.join(threads)
sess.close()

  

tensorflow训练自己的数据集实现CNN图像分类1的更多相关文章

  1. tensorflow训练自己的数据集实现CNN图像分类2(保存模型&测试单张图片)

    神经网络训练的时候,我们需要将模型保存下来,方便后面继续训练或者用训练好的模型进行测试.因此,我们需要创建一个saver保存模型. def run_training(): data_dir = 'C: ...

  2. 在C#下使用TensorFlow.NET训练自己的数据集

    在C#下使用TensorFlow.NET训练自己的数据集 今天,我结合代码来详细介绍如何使用 SciSharp STACK 的 TensorFlow.NET 来训练CNN模型,该模型主要实现 图像的分 ...

  3. 【Tensorflow系列】使用Inception_resnet_v2训练自己的数据集并用Tensorboard监控

    [写在前面] 用Tensorflow(TF)已实现好的卷积神经网络(CNN)模型来训练自己的数据集,验证目前较成熟模型在不同数据集上的准确度,如Inception_V3, VGG16,Inceptio ...

  4. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  5. tensorflow中使用mnist数据集训练全连接神经网络-学习笔记

    tensorflow中使用mnist数据集训练全连接神经网络 ——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师 前期准备:mnist数据集下载,并存入data目录: ...

  6. 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)

    如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...

  7. Pytorch和CNN图像分类

    Pytorch和CNN图像分类 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强大的GPU加速 ...

  8. 使用TensorFlow训练自己的语音识别AI

    这次来训练一个基于CNN的语音识别模型.训练完成后,我们将尝试将此模型用于Hotword detection. 人类是怎样听懂一句话的呢?以汉语为例,当听到"wo shi"的录音时 ...

  9. 使用Tensorflow训练自己的数据

    训练自己的数据集(以bottle为例):   1.准备数据 文件夹结构: models ├── images ├── annotations │ ├── xmls │ └── trainval.txt ...

随机推荐

  1. linux逻辑卷管理 (LVM)(转)

    1.什么是 LVM LVM 是逻辑盘卷管理(Logical Volume Manager)的简称,它是 Linux 环境下对磁盘分区进行管理的一种机制,LVM 是建立在硬盘和分区之上的一个逻辑层,来为 ...

  2. php intval的取值范围:与操作系统相关

    php intval的取值范围:与操作系统相关,32位系统上为-2147483648到2147483647,64位系统上为-9223372036854775808到922337203685477580 ...

  3. java后台设计简单的json数据接口,设置可跨域访问,前端ajax获取json数据

    在开发的过程中,有时候我们需要设计一个数据接口.有时候呢,数据接口和Web服务器又不在一起,所以就有跨域访问的问题. 第一步:简单的设计一个数据接口. 数据接口,听起来高大上,其实呢就是一个简单的Se ...

  4. HDU 1281 - 棋盘游戏 - [二分图最大匹配]

    题目链接:http://acm.split.hdu.edu.cn/showproblem.php?pid=1281 Time Limit: 2000/1000 MS (Java/Others) Mem ...

  5. NEFU 118 - n!后面有多少个0 & NEFU 119 - 组合素数 - [n!的素因子分解]

    首先给出一个性质: n!的素因子分解中的素数p的幂为:[ n / p ] + [ n / p² ] + [ n / p³ ] + …… 举例证明: 例如我们有10!,我们要求它的素因子分解中2的幂: ...

  6. PHP漏洞

    http://os.51cto.com/art/201204/328766.htm 针对PHP的网站主要存在下面几种攻击方式: 1.命令注入(Command Injection) 2.eval注入(E ...

  7. Map集合遍历

    Map<String,String> map = new HashMap<String, String>(); map.put("1","java ...

  8. AT2043 AND Grid 构造

    正解:构造 解题报告: 传送门传送门! 这题psj讲了俩做法,一个是最常见的解法,还一种还不知道484对的QAQ 然后先把psj讲的不知正确性的做法港下QwQ 大概就是说,第一个图,先把底给染完 然后 ...

  9. CentOS7 firewall防火墙配置笔记

    开启端口 # firewall-cmd --zone=public --add-port=/tcp --permanent 命令含义:         --zone #作用域         --ad ...

  10. 【Pyton】【小甲鱼】永久存储:腌制一缸美味的泡菜

    pickle(泡菜): picking:将对象转换为二进制 unpicking:将二进制转换为对象 1 >>> import pickle 2 #picking:对象导入到文件中(二 ...