这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. .Net修改网站项目调试时的虚拟目录(未验证)

    有些项目需要在IIS发布的时候,将网站发布到虚拟目录,为了保持调试和发布的路径同一,一般会修改VS调试的虚拟目录 一.Web应用程序 Web应用程序的修改方式非常简单,在解决方案资源管理器->项 ...

  2. go标准库的学习-net/http

    参考:https://studygolang.com/pkgdoc 概念解释: request:用户请求的信息,用来解析用户的请求信息,包括post.get.cookie.url等信息 respons ...

  3. spring+springmvc+hibernate整合实例

    最近要弄一个自动化生成表及其实体对应的增删改查的框架,于是我想到了hibernate,hibernate就有根据实体自动建表,而且增删改查,都不需要想mybatis那样在xml文件中配置. 不过怎样让 ...

  4. 【转】Kaggle注册问题-验证码和手机短信

    注册和登录Kaggle时验证码无法显示问题 参考:https://blog.csdn.net/zhuisaozhang1292/article/details/81529981 应用FQ软件需要时时关 ...

  5. 项目Alpha冲刺 3

    作业描述 课程: 软件工程1916|W(福州大学) 作业要求: 项目Alpha冲刺(团队) 团队名称: 火鸡堂 作业目标: 介绍第三天冲刺的项目进展.问题困难和心得体会 1.团队信息 队名:火鸡堂 队 ...

  6. AI 逻辑回归

    逻辑回归 参考链接 https://zhuanlan.zhihu.com/p/44591359

  7. 云主机被拿去挖矿,cpu暴涨,tcp连接突增

    1.云主机被拿去挖矿,cpu暴涨,tcp连接突增 2.现象:top -c 3.然后我再查看pstree进程树 4.查找文件来源 ind  / -name '*suppoie*' 5. 然后删除 sup ...

  8. oracle 相除后保留指定位数小数round()

    ) xxx from dual; XXX----------    3.8871

  9. 画线函数Glib_Line算法的研究

      在这里首先先简单把我对函数的功能的理解阐述一下,方便后面的分析:Glib_Line函数实现的功能是通过参数给定(x1,y1,x2,y2,color),来确定起点(x1,y1)和终点(x2,y2)两 ...

  10. openssl生成签名与验证签名

    继上一篇RSA对传输信息进行加密解密,再写个生成签名和验证签名. 一般,安全考虑,比如接入支付平台时,请求方和接收方要互相验证是否是你,就用签名来看. 签名方式一般两种,对称加密和非对称加密.对称加密 ...