# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training. import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph() # Load the iris data
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data]) # Declare batch size
batch_size = 20 # Create graph
sess = tf.Session() # Declare placeholders
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32) # Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1])) # Add model to graph:
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add) # Add classification loss (cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target) # Create Optimizer
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy) # Initialize variables
init = tf.global_variables_initializer()
sess.run(init) # Run Loop
for i in range(1000):
rand_index = np.random.choice(len(iris_2d), size=batch_size)
#rand_x = np.transpose([iris_2d[rand_index]])
rand_x = iris_2d[rand_index]
rand_x1 = np.array([[x[0]] for x in rand_x])
rand_x2 = np.array([[x[1]] for x in rand_x])
#rand_y = np.transpose([binary_target[rand_index]])
rand_y = np.array([[y] for y in binary_target[rand_index]])
sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
if (i+1)%200==0:
print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b))) # Visualize Results
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b) # Create fitted line
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
ablineValues.append(slope*i+intercept) # Plot the fitted line over the data
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

tensorflow 线性回归解决 iris 2分类的更多相关文章

  1. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  2. Chinese-Text-Classification,用卷积神经网络基于 Tensorflow 实现的中文文本分类。

    用卷积神经网络基于 Tensorflow 实现的中文文本分类 项目地址: https://github.com/fendouai/Chinese-Text-Classification 欢迎提问:ht ...

  3. 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    https://zhuanlan.zhihu.com/p/25928551 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文题目便是文本分类问题,趁此机会总结下文本分类 ...

  4. TensorFlow基础笔记(3) cifar10 分类学习

    TensorFlow基础笔记(3) cifar10 分类学习 CIFAR-10 is a common benchmark in machine learning for image recognit ...

  5. [转] 用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践

    转自知乎上看到的一篇很棒的文章:用深度学习(CNN RNN Attention)解决大规模文本分类问题 - 综述和实践 近来在同时做一个应用深度学习解决淘宝商品的类目预测问题的项目,恰好硕士毕业时论文 ...

  6. 在 Flutter 中使用 TensorFlow Lite 插件实现文字分类

    如果您希望能有一种简单.高效且灵活的方式把 TensorFlow 模型集成到 Flutter 应用里,那请您一定不要错过我们今天介绍的这个全新插件 tflite_flutter.这个插件的开发者是 G ...

  7. TensorFlow实现多层感知机MINIST分类

    TensorFlow实现多层感知机MINIST分类 TensorFlow 支持自动求导,可以使用 TensorFlow 优化器来计算和使用梯度.使用梯度自动更新用变量定义的张量.本文将使用 Tenso ...

  8. tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    # Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...

  9. 用决策树(CART)解决iris分类问题

    首先先看Iris数据集 Sepal.Length--花萼长度 Sepal.Width--花萼宽度 Petal.Length--花瓣长度 Petal.Width--花瓣宽度 通过上述4中属性可以预测花卉 ...

随机推荐

  1. php 源码编译

    https://cyberpersons.com/2016/08/28/install-nginx-php-php-fpm-mysql-source-run-wordpress-site-ubuntu ...

  2. G 全然背包

    <span style="color:#3333ff;">/* /* _________________________________________________ ...

  3. 文件系统之-JAVA Sftp远程操作:

    转载:http://blog.csdn.net/lee272616/article/details/52789018 java远程操作文件服务器(linux),使用sftp协议版本会持续更新,当前版本 ...

  4. C 标准库 - <errno.h>

    C 标准库 - <errno.h> 简介 C 标准库的 errno.h 头文件定义了整数变量 errno,它是通过系统调用设置的,在错误事件中的某些库函数表明了什么发生了错误.该宏扩展为类 ...

  5. JAVA Timer定时器使用方法

    JAVA  Timer 定时器测试 MyTask.java:package com.timer; import java.text.SimpleDateFormat;import java.util. ...

  6. 【java读书笔记】——java的异常处理

    程序在实际环境的执行过程中.安全成为须要首先考虑的重要因素之中的一个.这也是用户和程序猿最关心的问题.同一时候,Java语言健壮性也体如今了可以及时有效地处理程序中的错误.准确的说是Java的异常处理 ...

  7. 网络协议之rtp---h264的rtp网络协议实现

    完整的C/S架构的基于RTP/RTCP的H.264视频传输方案.此方案中,在服务器端和客户端分别进行了功能模块设计.服务器端:RTP封装模块主要是对H.264码流进行打包封装:RTCP分析模块负责产牛 ...

  8. ios 常见错误整理 持续更新

    本文转载至 http://blog.csdn.net/yesjava/article/details/8086185  1. mutating method sent to immutable obj ...

  9. ZOJ 3551 Bloodsucker <概率DP>

    题目:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3551 题意:开始有N-1个人和一个吸血鬼, 每天有两个生物见面,当人 ...

  10. 九度OJ 1087:约数的个数 (数字特性)

    时间限制:1 秒 内存限制:32 兆 特殊判题:否 提交:7349 解决:2306 题目描述: 输入n个整数,依次输出每个数的约数的个数 输入: 输入的第一行为N,即数组的个数(N<=1000) ...