这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. PAT A1144 The Missing Number (20 分)——set

    Given N integers, you are supposed to find the smallest positive integer that is NOT in the given li ...

  2. Luogu P2661 [NOIP2015] 信息传递

    qwq 今天做完并查集突然想起来这道以前做的好(shui)题, 虽然是黄题,但是是并查集一个比较特别的用法 这道题大概可以用求最小环的方式来做,但是从直觉上果然还是并查集w 乍一看只要求出“父→子”即 ...

  3. day83

    今日内容 rest_framework序列化 首先序列化需要对写一个类继承serializers.Serializer 方式一:在models的publish写一个__str__方法返回出版社名字 p ...

  4. Hive 创建表

    创建表的三种方式: 方式一:新建表结构 CREATE TABLE emp( empno int, ename string ) ROW FORMAT DELIMITED FIELDS TERMINAT ...

  5. 解决 在Android开发上使用KSOAP2上传大图片到服务器经常报错的问题

    原文首发我的主力博客 http://anforen.com/wp/2017/04/android_ksoap2_unexpected_type_position_end_document_null_j ...

  6. Redis对象占用内存分析

    当你往Redis中插入了一系统对象,如何分析这些对象的占用情况? 1.我们可以在Redis的控制台使用info命令来查看各项指标,其中有一项是Memory,可以通过存储前后的used_memory差异 ...

  7. Python从菜鸟到高手(5):数字

    1 基础知识   Python语言与其他编程语言一样,也支持四则运算(加.减.乘.除),以及圆括号运算符.在Python语言中,数字分为整数和浮点数.整数就是无小数部分的数,浮点数就是有小数部分的数. ...

  8. 【JVM.3】虚拟机性能监控与故障处理工具

    一.概述 经过前面两章对于虚拟机内存分配与回收技术各方面的介绍,相信读者已经建立了一套比较完整的理论基础.理论总是作为指导实践的工具,能把这些执行应用到实际工作中才是我们的最终目的.接下来我们会从实践 ...

  9. 最近新明白的SQL小知识

    1.partition by和order by 先看三个小需求: ①查询出各个类编号的书本的数量. select count (类编号) as 数量, 类编号 from Books group by ...

  10. 访谈:BugPhobia’s Brief Communication

    0x01 :采访的学长简介 If you weeped for the missing sunset, you would miss all the shining stars 梁野,北京航空航天大学 ...