这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. Qt中 .pro 文件和 .pri 文件简介

    *.pro 这是一个典型的Qt示例程序的.pro文件(propriprfprl.pro): TEMPLATE = app CONFIG += QT QT += core gui TARGET = pr ...

  2. all与any的用法

    all函数:检测矩阵中是否全为非零元素 any函数:检测矩阵中是否有非零元素,如果有,则返回1,否则,返回0.用法和all一样 语法: B = all(A) B = all(A, dim) 复制代码 ...

  3. ES5与ES6对比

    ES5与ES6对比 1. 模块引用 1.在ES5里,引入React包基本通过require进行,代码类似这样: // ES5 var React = require('react'); var { C ...

  4. python基础学习第三天

    #变量存储在内存中的值.这就意味着在创建变量时会在内存中开辟一个空间#基于变量的数据类型,解释器会分配指定内存,并决定什么数据可以被存储在内存中#变量可以指定不同的数据类型,这些变量可以存储整数.小数 ...

  5. Echo团队Alpha冲刺随笔 - 第一天

    项目冲刺情况 进展 每个人开始搭建自己要用的各种框架.库,基本实现了登录功能 问题 除了框架使用问题外,暂未遇到其他疑难杂症 心得 今天有一个还可以的开头,相信后续会挺顺利的 今日会议内容 黄少勇 今 ...

  6. Python import用法

    官方文档说明: Python code in one module gains access to the code in another module by the process of impor ...

  7. 【Codeforces Round 1110】Codeforces Global Round 1

    Codeforces Round 1110 这场比赛只做了\(A\).\(B\).\(C\),排名\(905\),不好. 主要的问题在\(D\)题上,有\(505\)人做出,但我没做出来. 考虑的时候 ...

  8. TCP/IP协议--TCP的超时和重传

    TCP是可靠传输.可靠之一体现在收到数据后,返回去一个确认.但是不能完全避免的是,数据和确认都可能丢失.解决这个办法就是,提供一个发送的重传定时器:如果定时器溢出时还没收到确认,它就重传这个报文段. ...

  9. MFC入门(三)-- MFC图片/文字控件(循环显示文字和图片的小程序)

    惯例附上前几个博客的链接: MFC入门(一)简单配置:http://blog.csdn.net/zmdsjtu/article/details/52311107 MFC入门(二)读取输入字符:http ...

  10. UOJ67 新年的毒瘤 tarjan

    题目传送门 题意:给出一个$N$个点.$M$条边的无向图,找出其中的点,满足去掉该点与和它相连的边之后,这个图会变成一棵树.$N , M \leq 10^5$ 说是毒瘤,真的不毒瘤 思考一下,我们需要 ...