这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. LBS

  2. MySQL数据备份之mysqldump使用(转)

    文章转自 :https://www.cnblogs.com/jpfss/p/7867668.html mysqldump常用于MySQL数据库逻辑备份. 1.各种用法说明 A. 最简单的用法: mys ...

  3. js == 与 === 的区别[转]

    we文章转自http://blog.sina.com.cn/s/blog_4b32835b01014iv9.html 1.对于string,number等基础类型,==和===是有区别的 1)不同类型 ...

  4. nginx做负载均衡和tomcat简单集群

    Nginx做负载均衡和TOMCAT简单集群                1.下载安装nginx及其依赖包                                               ...

  5. 深入浅出的webpack构建工具--webpack4+vue+router项目架构(十四)

    阅读目录 一:vue-router是什么? 二:vue-router的实现原理 三:vue-router使用及代码配置 四:理解vue设置路由导航的两种方法. 五:理解动态路由和命名视图 六:理解嵌套 ...

  6. Android程序的反破解技术

    Android 程序的破解一般步骤如下:反编译.静态分析.动态调试.重编译.我们可以从这几个步骤着手反破解 反编译 我们可以查找反编译器的漏洞,从而使反编译器无法正确解析APK文件 静态分析 对jav ...

  7. 奇怪的组数length属性

    Java中的数组其实也是一个对象,但是确实是一个特殊的对象,实在是太特殊了,继承自Object, 多出一个属性length,改写了clone方法.   我debug了数组对象的运行时的Class对象, ...

  8. 【JVM.4】调优案例分析与实战

    之前已经介绍过处理Java虚拟机内存问题的知识与工具,在处理实际项目的问题时,除了知识与工具外,经验同样是一个很重要的因素.本章会介绍一些具有代表性的案例. 本章的内容推荐还是原文全篇看完的好,实在不 ...

  9. 老生常谈,函数柯里化(curring)

    柯里化这个概念确实晦涩难懂,没有深入思考过的人其实真的很难明白这是一个什么东西.看起来简单.简单到或许只需要一行代码: const curry = fn => (…args) => fn. ...

  10. 深入理解USB流量数据包的抓取与分析

    0x01 问题提出 在一次演练中,我们通过wireshark抓取了一个如下的数据包,我们如何对其进行分析? 0x02 问题分析 流量包是如何捕获的? 首先我们从上面的数据包分析可以知道,这是个USB的 ...