用 TensorFlow 实现 SVM 分类问题
这篇文章解释了底部链接的代码。
问题描述

如上图所示,有一些点位于单位正方形内,并做好了标记。要求找到一条线,作为分类的标准。这些点的数据在 inearly_separable_data.csv
文件内。
思路
最初的 SVM 可以形式化为如下:
\[\begin{equation}\min_{\boldsymbol{\omega,b}}\frac{1}{2}\|\boldsymbol{\omega}\|^2\\s.t.\ y_i(\boldsymbol{\omega}^T\boldsymbol{x}_i+b)\geqslant 1,\ i = 1,2,\cdots ,m.\end{equation} \]
引入软间隔,可以在一定情况下避免过拟合的问题。
引入软间隔之后,问题转化为
\[\begin{equation}
\min_{\boldsymbol{\omega,b}}\frac{1}{2}\|\boldsymbol{\omega}\|^2 + C \sum_{i=1}^{N}max(0,1-y_i(\boldsymbol{\omega}^T\boldsymbol{x}_i+b))
\end{equation}\]
代码
主要代码在 linear_svm.py
内,plot_boundary_on_data.py
负责画图。
一、引入库和声明
import tensorflow as tf
import numpy as np
import scipy.io as io
from matplotlib import pyplot as plt
import plot_boundary_on_data
二、 定义一些变量
# Global variables.
BATCH_SIZE = 100 # The number of training examples to use per training step.
# Define the flags useable from the command line.
tf.app.flags.DEFINE_string('train', None,
'File containing the training data (labels & features).')
tf.app.flags.DEFINE_integer('num_epochs', 1,
'Number of training epochs.')
tf.app.flags.DEFINE_float('svmC', 1,
'The C parameter of the SVM cost function.')
tf.app.flags.DEFINE_boolean('verbose', False, 'Produce verbose output.')
tf.app.flags.DEFINE_boolean('plot', True, 'Plot the final decision boundary on the data.')
FLAGS = tf.app.flags.FLAGS
包括每次训练使用的数据,称为一个 batch,大小定义为 BATCH_SIZE
。
train
是训练集文件的位置,这里是 inearly_separable_data.csv
。
num_epochs
是把所有训练集的数据使用几遍。把训练集的数据使用一遍称为一个 epoch。
svmC
即\((2)\)式中 \(C\)的大小。
三、读取训练数据
# Extract it into numpy matrices.
train_data,train_labels = extract_data(train_data_filename)
# Convert labels to +1,-1
train_labels[train_labels==0] = -1
# Get the shape of the training data.
train_size,num_features = train_data.shape
读出来的 train_data
是一个 [1000, 2] 的张量,样本的有两个属性,train_labels
是一个 [1000, 1] 的张量。
在读取过程中用到了 numpy
的接口。
标准的 SVM 的标记为 \(\{-1, 1\}\),而文件中标记为 \(\{0, 1\}\)。因此需要做一次转换。
四、构造网络结构
x = tf.placeholder("float", shape=[None, num_features])
y = tf.placeholder("float", shape=[None,1])
W = tf.Variable(tf.zeros([num_features,1]))
b = tf.Variable(tf.zeros([1]))
y_raw = tf.matmul(x,W) + b
线性方程的最终表现形式是 \(\boldsymbol{\omega}^t\boldsymbol{x}+b=0\)。
给定一个样本数据 \(\boldsymbol{x}\),若 \(\boldsymbol{\omega}^t\boldsymbol{x}+b \geqslant 1\),则认为对应的分类为 1,然后和样本的标记对比,若标记为1,则分类正确;否则,分类错误。
若 \(\boldsymbol{\omega}^t\boldsymbol{x}+b \leqslant 1\),则认为对应的分类为 -1,然后和样本的标记对比,若标记为-1,则分类正确;否则,分类错误。
最终要求解的值是一个 shape 为 [2, 1] 的张量 \(W\) 和一个标量 \(b\)。
y_raw
是向量机判定的输出。
五、构造优化目标
regularization_loss = 0.5*tf.reduce_sum(tf.square(W))
hinge_loss = tf.reduce_sum(tf.maximum(tf.zeros([BATCH_SIZE,1]),
1 - y*y_raw));
svm_loss = regularization_loss + svmC*hinge_loss;
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(svm_loss)
即 \( \min_{\boldsymbol{\omega,b}}\frac{1}{2}\|\boldsymbol{\omega}\|^2 + C \sum_{i=1}^{N}max(0,1-y_i(\boldsymbol{\omega}^T\boldsymbol{x}_i+b))\) 的代码表示。
指定用梯度下降法最小化 svm_loss
。
六、用精度来评价模型的好坏
predicted_class = tf.sign(y_raw);
correct_prediction = tf.equal(y,predicted_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
如果 y_raw
和样本的标记 y
同符号,即认为预测正确。用预测正确的比例来评价模型的好坏。
七、用数据训练模型
with tf.Session() as s:
# Run all the initializers to prepare the trainable parameters.
tf.initialize_all_variables().run()
# Iterate and train.
for step in xrange(num_epochs * train_size // BATCH_SIZE):
offset = (step * BATCH_SIZE) % train_size
batch_data = train_data[offset:(offset + BATCH_SIZE), :]
batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
train_step.run(feed_dict={x: batch_data, y: batch_labels})
print 'loss: ', svm_loss.eval(feed_dict={x: batch_data, y: batch_labels})
首先启动一个 session
,每次取 BATCH_SIZE
个数据来训练模型。即用batch_data
和 batch_lables
来训练一次,每次得到一个 svm_loss
的值。
运行结果
python linear_svm.py --train linearly_separable_data.csv --svmC 1 --verbose True --num_epochs 10
运行以上命令,指定把数据使用10轮,一次使用100个数据,因此可以得到100次迭代的结果。最后得到的结果及精度如下:

思考
- 指定
BATCH_SIZE
和num_epochs
是为了减少计算量。
根据数学理论,应该在整个训练数据集上进行梯度下降法的迭代,每一步迭代都应该选取所有训练数据集的样本。但是这样子做计算量太大,于是在每一次迭代时选用训练数据集的一部分作为输入。
这么做要求每一步迭代选取的数据子集的分布和总体分布一致,否则得不到正确的结果。
参考
用 TensorFlow 实现 SVM 分类问题的更多相关文章
- SVM原理以及Tensorflow 实现SVM分类(附代码)
1.1. SVM介绍 1.2. 工作原理 1.2.1. 几何间隔和函数间隔 1.2.2. 最大化间隔 - 1.2.2.0.0.1. \(L( {x}^*)\)对$ {x}^*$求导为0 - 1.2.2 ...
- Relation Extraction中SVM分类样例unbalance data问题解决 -松弛变量与惩罚因子
转载自:http://blog.csdn.net/yangliuy/article/details/8152390 1.问题描述 做关系抽取就是要从产品评论中抽取出描述产品特征项的target短语以及 ...
- SVM-支持向量机(二)非线性SVM分类
非线性SVM分类 尽管SVM分类器非常高效,并且在很多场景下都非常实用.但是很多数据集并不是可以线性可分的.一个处理非线性数据集的方法是增加更多的特征,例如多项式特征.在某些情况下,这样可以让数据集变 ...
- SVM-支持向量机(一)线性SVM分类
SVM-支持向量机 SVM(Support Vector Machine)-支持向量机,是一个功能非常强大的机器学习模型,可以处理线性与非线性的分类.回归,甚至是异常检测.它也是机器学习中非常热门的算 ...
- tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)
iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...
- tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)
# Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...
- 用tensorflow实现SVM
环境配置 win10 Python 3.6 tensorflow1.15 scipy matplotlib (运行时可能会遇到module tkinter的问题) sklearn 一个基于Python ...
- SVM分类与回归
SVM(支撑向量机模型)是二(多)分类问题中经常使用的方法,思想比较简单,但是具体实现与求解细节对工程人员来说比较复杂,如需了解SVM的入门知识和中级进阶可点此下载.本文从应用的角度出发,使用Libs ...
- VQ结合SVM分类方法
今天整理资料时,发现了在学校时做的这个实验,当时整个过程过重偏向依赖分类器方面,而又很难对分类器性能进行一定程度的改良,所以最后没有选用这个方案,估计以后也不会接触这类机器学习的东西了,希望它对刚入门 ...
随机推荐
- 项目 solrcloud / zookeeper 搭建
财经道网站搜索引擎,数据快速检索,数据集群 功能描述:使用solr为项目数据库表p2p,银行理财,基金,贷款,信托,保险等建立数据索引,实现数据的导入,增量索引.实现检索建议和数据的快速查找.使用zo ...
- 硬件GPIO,UART,I2C,SPI电路图
- Silverlight或WPF动态绑定图片路径问题,不用Converter完美解决
关于Silverlight或WPF动态绑定图片路径问题,不用Converter完美解决, 可想,一个固定的字符串MS都能找到,按常理动态绑定也应该没问题的,只需在前面标记它是一个Path类型的值它就能 ...
- vertical-align和text-align
vertical-align只适用于内联元素. 垂直对齐:vertical-align属性(转) 行高与单行纯文字的垂直居中,而如果行内含有图片和文字,在浏览器内浏览时,读者可以发现文字和图片在垂直方 ...
- hi~大家好,特地出来解释下最近为啥都不更新了!
总结一句话就是!因为我有宝宝啦~加上项目赶得不要不要的公司原因加上个人原因只能在家养胎啦,对象也是程序猿哦~不过是后端程序猿哈哈哈. 我打算开公众号(百撕可乐)啦,和博客圆的名字一样,毕竟用了这么多年 ...
- 2018.10.19 NOIP训练 桌子(快速幂优化dp)
传送门 勉强算一道dp好题. 显然第kkk列和第k+nk+nk+n列放的棋子数是相同的. 因此只需要统计出前nnn列的选法数. 对于前mmm%nnn列,一共有(m−1)/n+1(m-1)/n+1(m− ...
- php读取用友u8客户档案
include('../common/conn.php'); $list=[]; $sql="SELECT a.cCusCode,a.cCusName,b.cCCName,a.cCusDep ...
- http://localhost:8080/hello?wsdl
<definitions xmlns:wsu="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-u ...
- Sublime必用快捷键[私人]
最近一年前端开发都是用sublime这款编辑器, 相对于webStorm强大而启动慢.editplus快启动而功能弱, sublime恰好在两者之间:而且其指令行安装.更新.卸载插件比eclipse之 ...
- SoC开发板设置网口IP为固定IP
vi /etc/network/interfaces 编辑这个文件 #iface eth0 inet dhcp 找到修改这个,前面加# iface eth0 inet static 改为静态分配i ...