首先是 Data 类(负责读取训练数据):

import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.io.UncheckedIOException;
import java.util.LinkedHashMap;
import java.util.Scanner;

/**
 * Loads a labelled training set from a text file.
 *
 * <p>Each non-blank line holds whitespace-separated numeric features followed by an integer label
 * in the last column, e.g. {@code 3.542485<TAB>1.977398<TAB>-1}.
 */
public class Data {

    /** Training file read by {@link #getTrainData()}. */
    private static final String DEFAULT_PATH = "G://download//testSet.txt";

    /**
     * Reads the training set from the default path.
     *
     * @return feature vector mapped to its label, in file order
     * @throws UncheckedIOException if the file cannot be opened
     * @throws NumberFormatException if a field is not numeric
     */
    public Map<List<Double>, Integer> getTrainData() {
        return getTrainData(DEFAULT_PATH);
    }

    /**
     * Reads the training set from {@code path}.
     *
     * <p>The result is keyed by the feature vector, so rows with identical features collapse into
     * a single entry (the label read last wins).
     *
     * @param path file to read
     * @return feature vector mapped to its label, in file order; blank lines are skipped
     * @throws UncheckedIOException if the file cannot be opened
     * @throws NumberFormatException if a field is not numeric
     */
    public Map<List<Double>, Integer> getTrainData(String path) {
        // LinkedHashMap keeps the file order, so sample indices are deterministic across runs.
        Map<List<Double>, Integer> data = new LinkedHashMap<>();
        try (Scanner in = new Scanner(new File(path))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    continue; // e.g. the trailing newline at the end of the file
                }
                String[] fields = line.split("\\s+");
                List<Double> point = new ArrayList<>(fields.length - 1);
                for (int i = 0; i < fields.length - 1; i++) {
                    point.add(Double.parseDouble(fields[i]));
                }
                data.put(point, Integer.parseInt(fields[fields.length - 1]));
            }
        } catch (FileNotFoundException e) {
            // Fail fast: an empty result only surfaces later as an unrelated index error.
            throw new UncheckedIOException("Training data file not found: " + path, e);
        }
        return data;
    }

    /** Loads the default training file once (smoke check). */
    public static void main(String[] args) {
        Data data = new Data();
        data.getTrainData();
    }
}

然后是 SVM 类:

import java.awt.print.Printable;
import java.io.FileNotFoundException;
import java.io.ObjectInputStream.GetField;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;

/**
 * Soft-margin support vector machine trained with SMO (sequential minimal optimization).
 *
 * <p>Training samples are loaded through {@code Data.getTrainData()}. The Gaussian (RBF) kernel
 * is used by default; call {@link #setUseLinearKernel(boolean)} before {@link #SMO()} to train
 * with the linear kernel instead (only then does {@link #getLinearW()} describe the model).
 */
public class SVM {
    private List<ArrayList<Double>> trainData; // feature vectors x_i
    private List<Integer> labelTrainData; // labels y_i; the update formulas expect -1 or +1
    private double sigma; // width of the RBF kernel
    private double C; // box constraint 0 <= alpha_i <= C
    private List<Double> alpha; // Lagrange multipliers alpha_i
    private double b; // bias term of g(x)
    private List<Double> E; // cached errors E_i = g(x_i) - y_i
    private int N; // number of training samples
    private int dim; // number of features per sample
    private double tol; // tolerance used by the KKT check
    private double eta; // K11 + K22 - 2*K12 of the most recently evaluated pair
    private double eps; // alphas within eps of 0 or C are treated as lying on that bound
    private double eps2; // smallest alpha change that counts as progress in update()
    private boolean useLinearKernel; // false (default): RBF kernel; true: linear kernel

    /**
     * Loads the training set via {@code new Data().getTrainData()} and initializes the model.
     * All alphas and b start at 0, so g(x) = 0 and every cached error starts at E_i = -y_i.
     *
     * @throws IllegalStateException if no training samples were loaded
     */
    public SVM() {
        Data data = new Data();
        Map<List<Double>, Integer> datas = data.getTrainData();
        if (datas.isEmpty()) {
            throw new IllegalStateException("No training data was loaded");
        }
        this.trainData = new ArrayList<ArrayList<Double>>(datas.size());
        this.labelTrainData = new ArrayList<Integer>(datas.size());
        this.alpha = new ArrayList<Double>(datas.size());
        this.E = new ArrayList<Double>(datas.size());
        for (Entry<List<Double>, Integer> entry : datas.entrySet()) {
            int label = entry.getValue();
            // Copy rather than cast, so any List implementation coming from Data is accepted.
            this.trainData.add(new ArrayList<Double>(entry.getKey()));
            this.labelTrainData.add(label);
            this.alpha.add(0.0);
            this.E.add(0.0 - label);
        }
        this.N = this.labelTrainData.size();
        this.dim = this.trainData.get(0).size();
        this.sigma = 12; // value also tried: sigma = 1
        this.C = 0.5; // value also tried: C = 6
        this.b = 0.0;
        this.tol = 0.001;
        this.eta = 0;
        this.eps = 0.0000001;
        this.eps2 = 0.00001;
        this.useLinearKernel = false;
    }

    /**
     * Selects the kernel used by training and prediction. Call before {@link #SMO()}.
     *
     * @param useLinearKernel {@code true} for the linear kernel, {@code false} for RBF (default)
     */
    public void setUseLinearKernel(boolean useLinearKernel) {
        this.useLinearKernel = useLinearKernel;
    }

    /**
     * Checks the KKT conditions of sample {@code id} within {@code tol}: alpha = 0 needs
     * y*g(x) >= 1, alpha = C needs y*g(x) <= 1, and 0 < alpha < C needs y*g(x) = 1.
     *
     * @param id sample index
     * @return {@code true} if the conditions hold
     */
    public boolean satisfyKkt(int id) {
        double ypgx = this.labelTrainData.get(id) * getGx(this.trainData.get(id)); // y*g(x)
        double a = this.alpha.get(id);
        if (Math.abs(a) <= this.eps) {
            if (ypgx - 1 < -this.tol) {
                return false;
            }
        } else if (Math.abs(a - this.C) <= this.eps) {
            if (ypgx - 1 > this.tol) {
                return false;
            }
        } else if (Math.abs(ypgx - 1) > this.tol) {
            return false;
        }
        return true;
    }

    /** Recomputes every cached error E_i = g(x_i) - y_i from the current alphas and b. */
    public void updateE() {
        for (int i = 0; i < this.N; i++) {
            this.E.set(i, getGx(this.trainData.get(i)) - this.labelTrainData.get(i));
        }
    }

    /** Linear kernel: dot product of X and Y over the first {@code Y.size()} coordinates. */
    public double kernelLinear(List<Double> X, List<Double> Y) {
        int len = Y.size();
        double s = 0;
        for (int i = 0; i < len; i++) {
            s += X.get(i) * Y.get(i);
        }
        return s;
    }

    /** Gaussian kernel exp(-||X - Y||^2 / (2 * sigma^2)) over the first {@code Y.size()} coordinates. */
    public double kernelRBF(List<Double> X, List<Double> Y) {
        int len = Y.size();
        double s = 0;
        for (int i = 0; i < len; i++) {
            double d = X.get(i) - Y.get(i);
            s += d * d;
        }
        return Math.exp(-s / (2 * Math.pow(this.sigma, 2)));
    }

    /**
     * Computes the decision value g(x) = sum_i alpha_i * y_i * K(x_i, X) + b.
     *
     * @param X point to evaluate
     * @return the raw decision value; its sign is the predicted class
     */
    public double getGx(List<Double> X) {
        double s = 0;
        for (int i = 0; i < this.N; i++) {
            double a = this.alpha.get(i);
            if (a == 0) {
                continue; // zero multipliers contribute nothing; skip the kernel evaluation
            }
            s += a * this.labelTrainData.get(i) * kernel(X, this.trainData.get(i));
        }
        return s + this.b;
    }

    /**
     * Jointly optimizes alpha[x1] and alpha[x2], then updates b and the cached errors.
     *
     * @param x1 index of the first multiplier
     * @param x2 index of the second multiplier
     * @return 1 if both multipliers changed by more than {@code eps2}, otherwise 0 (no change made)
     */
    public int update(int x1, int x2) {
        double a1 = this.alpha.get(x1);
        double a2 = this.alpha.get(x2);
        // Unbox to int: == on two Integer objects compares references, not values.
        int y1 = this.labelTrainData.get(x1);
        int y2 = this.labelTrainData.get(x2);

        // Feasible range [low, high] for the new alpha2 under the box and equality constraints.
        double low;
        double high;
        if (y1 == y2) {
            low = Math.max(0, a1 + a2 - this.C);
            high = Math.min(this.C, a2 + a1);
        } else {
            low = Math.max(0, a2 - a1);
            high = Math.min(this.C, a2 - a1 + this.C);
        }
        if (high - low <= this.eps) {
            return 0; // the feasible segment is a single point: no progress possible
        }

        List<Double> p1 = this.trainData.get(x1);
        List<Double> p2 = this.trainData.get(x2);
        double k11 = kernel(p1, p1);
        double k22 = kernel(p2, p2);
        double k12 = kernel(p1, p2);
        // Recomputed here so update() does not depend on selectAlpha2() having set eta for this pair.
        this.eta = k11 + k22 - 2 * k12;
        if (this.eta <= this.eps) {
            return 0; // the step below divides by eta, so it must be positive
        }

        double e1 = this.E.get(x1);
        double e2 = this.E.get(x2);
        double newAlpha2 = a2 + y2 * (e1 - e2) / this.eta;
        if (newAlpha2 > high) {
            newAlpha2 = high;
        } else if (newAlpha2 < low) {
            newAlpha2 = low;
        }
        double newAlpha1 = a1 + y1 * y2 * (a2 - newAlpha2);

        newAlpha1 = snapToBounds(newAlpha1);
        newAlpha2 = snapToBounds(newAlpha2);
        if (Math.abs(newAlpha1 - a1) <= this.eps2 || Math.abs(newAlpha2 - a2) <= this.eps2) {
            return 0; // change too small to count as progress
        }

        double b1 = -e1 - y1 * k11 * (newAlpha1 - a1) - y2 * k12 * (newAlpha2 - a2) + this.b;
        double b2 = -e2 - y1 * k12 * (newAlpha1 - a1) - y2 * k22 * (newAlpha2 - a2) + this.b;
        if (newAlpha1 > 0 && newAlpha1 < this.C) {
            this.b = b1;
        } else if (newAlpha2 > 0 && newAlpha2 < this.C) {
            this.b = b2;
        } else {
            this.b = (b1 + b2) / 2;
        }

        this.alpha.set(x1, newAlpha1);
        this.alpha.set(x2, newAlpha2);
        updateE();
        return 1;
    }

    /**
     * Chooses the partner multiplier for {@code x1} and stores the pair's eta in {@code this.eta}.
     *
     * <p>First tries the non-bound sample maximizing |E1 - E2|. If that pair has eta <= eps, it
     * falls back to the first other sample whose eta exceeds eps.
     *
     * @param x1 index of the first multiplier
     * @return index of the second multiplier, or -1 if no usable partner exists
     */
    public int selectAlpha2(int x1) {
        int x2 = -1;
        double maxDiff = -1;
        for (int i = 0; i < this.N; ++i) {
            if (i == x1 || isAtBound(i)) {
                continue;
            }
            double diff = Math.abs(this.E.get(x1) - this.E.get(i));
            if (diff > maxDiff) {
                maxDiff = diff;
                x2 = i;
            }
        }
        if (x2 != -1) {
            this.eta = etaOf(x1, x2);
            if (this.eta > this.eps) {
                return x2;
            }
        }
        for (int i = 0; i < this.N; i++) {
            if (i == x1) {
                continue;
            }
            this.eta = etaOf(x1, i);
            if (this.eta > this.eps) {
                return i;
            }
        }
        return -1;
    }

    /** Runs SMO until a full pass over all samples changes no multiplier pair. */
    public void SMO() {
        SMO(Integer.MAX_VALUE);
    }

    /**
     * Runs SMO for at most {@code maxIterations} passes, stopping earlier once a full pass over all
     * samples changes nothing. Each pass first visits only non-bound samples; when that changes
     * nothing, it visits every sample. Prints the pass number at the start of each pass.
     *
     * @param maxIterations upper bound on the number of passes (values <= 0 run no pass)
     */
    public void SMO(int maxIterations) {
        int cnt = 0;
        while (cnt < maxIterations) {
            cnt++;
            System.out.println(cnt);
            int numChanged = 0;
            for (int x1 = 0; x1 < this.N; ++x1) {
                if (!isAtBound(x1)) {
                    numChanged += optimizeIfViolating(x1);
                }
            }
            if (numChanged == 0) {
                for (int x1 = 0; x1 < this.N; ++x1) {
                    // Count only real updates; counting attempts could loop forever.
                    numChanged += optimizeIfViolating(x1);
                }
            }
            if (numChanged == 0) {
                break;
            }
        }
    }

    /** Returns the bias term b. */
    public double getB() {
        return this.b;
    }

    /**
     * Computes w = sum_i alpha_i * y_i * x_i, which describes the model only when it was trained
     * with the linear kernel (then w*x + b = 0 is the separating hyperplane).
     *
     * @return weight vector of length {@code dim}
     */
    public double[] getLinearW() {
        double[] w = new double[this.dim];
        for (int i = 0; i < this.N; i++) {
            double coef = this.alpha.get(i) * this.labelTrainData.get(i);
            List<Double> xi = this.trainData.get(i);
            for (int j = 0; j < this.dim; j++) {
                w[j] += coef * xi.get(j);
            }
        }
        return w;
    }

    /**
     * Classifies {@code x} by the sign of g(x).
     *
     * @return 1 if g(x) > 0, otherwise -1
     */
    public int predict(List<Double> x) {
        return getGx(x) > 0 ? 1 : -1;
    }

    /** Evaluates the kernel selected by {@link #setUseLinearKernel(boolean)}. */
    private double kernel(List<Double> X, List<Double> Y) {
        return this.useLinearKernel ? kernelLinear(X, Y) : kernelRBF(X, Y);
    }

    /** Returns true when alpha_i lies (within eps) on the lower bound 0 or the upper bound C. */
    private boolean isAtBound(int i) {
        double a = this.alpha.get(i);
        return Math.abs(a) <= this.eps || Math.abs(a - this.C) <= this.eps;
    }

    /** Returns eta = K(i,i) + K(j,j) - 2*K(i,j) for the sample pair (i, j). */
    private double etaOf(int i, int j) {
        List<Double> xi = this.trainData.get(i);
        List<Double> xj = this.trainData.get(j);
        return kernel(xi, xi) + kernel(xj, xj) - 2 * kernel(xi, xj);
    }

    /** Snaps a value within eps of 0 to 0, then a value within eps of C to C. */
    private double snapToBounds(double a) {
        if (Math.abs(a) <= this.eps) {
            a = 0;
        }
        if (Math.abs(a - this.C) <= this.eps) {
            a = this.C;
        }
        return a;
    }

    /** If sample x1 violates KKT, tries one pair update; returns 1 if it changed something, else 0. */
    private int optimizeIfViolating(int x1) {
        if (satisfyKkt(x1)) {
            return 0;
        }
        int x2 = selectAlpha2(x1);
        return x2 == -1 ? 0 : update(x1, x2);
    }

    /**
     * Trains on the default data set, writes "x1 TAB x2 TAB predictedLabel" lines for every
     * training sample, and prints w[0], w[1] and b (meaningful for the linear kernel only).
     */
    public static void main(String[] args) throws FileNotFoundException {
        SVM s = new SVM();
        s.SMO();
        try (PrintWriter out = new PrintWriter("g://download//resultpoints.txt")) {
            for (int i = 0; i < s.N; i++) {
                List<Double> point = s.trainData.get(i);
                out.write(point.get(0).toString());
                out.write("\t");
                out.write(point.get(1).toString());
                out.write("\t");
                out.write(Integer.toString(s.predict(point)));
                out.write("\n");
            }
        }
        // With the linear kernel, w*x + b = 0 gives the separating line directly.
        double[] w = s.getLinearW();
        System.out.println(w[0] + " " + w[1] + " " + s.b + "======");
    }
}

  

用线性核函数实现的 SVM 得到的分类结果

画图使用的是下面的 Python 代码:

from numpy import *
import numpy as np

# Decision-line coefficients printed by SVM.main ("w[0] w[1] b======"), presumably from a
# linear-kernel run.
W1 = 0.8148005405344305
W2 = -0.27263471796762484
B = -3.8392586254518437


def load_points(path):
    """Read "x1<TAB>x2<TAB>label" rows (the format SVM.main writes).

    Blank or short lines are skipped. Returns ``((xs, ys), (xs, ys))``: the first pair holds the
    points whose label is 1, the second pair holds all other points.
    """
    type1_x, type1_y = [], []
    type2_x, type2_y = [], []
    with open(path) as f1:
        for line in f1:
            x = line.strip().split('\t')
            if len(x) < 3:
                continue  # e.g. trailing blank line
            x1 = float(x[0])
            x2 = float(x[1])
            x3 = int(x[2])
            if x3 == 1:
                type1_x.append(x1)
                type1_y.append(x2)
            else:
                type2_x.append(x1)
                type2_y.append(x2)
    return (type1_x, type1_y), (type2_x, type2_y)


def decision_boundary_y(w1, w2, b, x):
    """Solve w1*x + w2*y + b = 0 for y (requires w2 != 0)."""
    return (-w1 / w2) * x + (-b / w2)


def plot(path, w1=W1, w2=W2, b=B):
    """Scatter the labelled points in ``path`` and draw the line w1*x + w2*y + b = 0."""
    # Imported lazily so the helpers above work without matplotlib or a display.
    import matplotlib.pyplot as plt

    (type1_x, type1_y), (type2_x, type2_y) = load_points(path)
    plt.figure(figsize=(8, 5), dpi=80)
    axes = plt.subplot(111)
    type1 = axes.scatter(type1_x, type1_y, s=40, c='red')
    type2 = axes.scatter(type2_x, type2_y, s=40, c='green')
    x = np.linspace(-4, 10, 200)
    axes.plot(x, decision_boundary_y(w1, w2, b, x), 'b', lw=3)
    plt.xlabel('x1')
    plt.ylabel('x2')
    # type1 holds label 1; type2 holds every other label (-1 from SVM.predict).
    axes.legend((type1, type2), ('1', '-1'), loc=1)
    plt.show()


if __name__ == '__main__':
    # NOTE(review): SVM.main writes g://download//resultpoints.txt -- confirm which file is meant.
    plot("g://download/myresult.txt")

使用高斯核,C=6、sigma=1 时的结果

使用高斯核,C=0.5、sigma=1 时的结果

使用高斯核,C=0.5、sigma=12 时的结果

由此可见,惩罚参数 C 和核宽度 sigma 的取值对高斯核 SVM 的分类结果影响很大。

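其中 sigma 是高斯核函数 $K(x, y) = \exp\left(-\frac{\lVert x - y \rVert^2}{2\sigma^2}\right)$ 的宽度参数;C 是软间隔的惩罚参数,它把每个拉格朗日乘子约束在 $0 \le \alpha_i \le C$ 之内。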

自己实现的SVM源码的更多相关文章

  1. EasyPR源码剖析(1):概述

    EasyPR(Easy to do Plate Recognition)是本人在opencv学习过程中接触的一个开源的中文车牌识别系统,项目Git地址为https://github.com/liuru ...

  2. Mahout源码目录说明&&算法集

    Mahout源码目录说明 mahout项目是由多个子项目组成的,各子项目分别位于源码的不同目录下,下面对mahout的组成进行介绍: 1.mahout-core:核心程序模块,位于/core目录下: ...

  3. 近200篇机器学习&深度学习资料分享(含各种文档,视频,源码等)(1)

    原文:http://developer.51cto.com/art/201501/464174.htm 编者按:本文收集了百来篇关于机器学习和深度学习的资料,含各种文档,视频,源码等.而且原文也会不定 ...

  4. Ubuntu编译Android源码(AOSP)

    前言: 一直想要编译一下Android 源码,之前去google 看,下载要下载repo. 当时很懵逼,repo 是个什么?(repo 是一个python 脚本,因为Android 源码git 仓库太 ...

  5. Android FrameWork 学习之Android 系统源码调试

    这是很久以前访问掘金的时候 无意间看到的一个关于Android的文章,作者更细心,分阶段的将学习步骤记录在自己博客中,我觉得很有用,想作为分享同时也是留下自己知识的一些欠缺收藏起来,今后做项目的时候会 ...

  6. Python的开源人脸识别库:离线识别率高达99.38%(附源码)

    Python的开源人脸识别库:离线识别率高达99.38%(附源码) 转https://cloud.tencent.com/developer/article/1359073   11.11 智慧上云 ...

  7. GWO(灰狼优化)算法MATLAB源码逐行中文注解(转载)

    以优化SVM算法的参数c和g为例,对GWO算法MATLAB源码进行了逐行中文注解. tic % 计时器 %% 清空环境变量 close all clear clc format compact %% ...

  8. FaceNet pre-trained模型以及FaceNet源码使用方法和讲解

    Pre-trained models Model name LFW accuracy Training dataset Architecture 20180408-102900 0.9905 CASI ...

  9. KNN算法介绍及源码实现

    一.KNN算法介绍 邻近算法,或者说K最邻近(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一.所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它 ...

随机推荐

  1. 【刷题】洛谷 P4716 【模板】最小树形图

    题目背景 这是一道模板题. 题目描述 给定包含 \(n\) 个结点, \(m\) 条有向边的一个图.试求一棵以结点 \(r\) 为根的最小树形图,并输出最小树形图每条边的权值之和,如果没有以 \(r\ ...

  2. 【BZOJ4129】Haruna’s Breakfast(树上莫队)

    [BZOJ4129]Haruna's Breakfast(树上莫队) 题面 BZOJ Description Haruna每天都会给提督做早餐! 这天她发现早饭的食材被调皮的 Shimakaze放到了 ...

  3. Codeforces Round #405 (rated, Div. 2, based on VK Cup 2017 Round 1)

    A 模拟 B 发现对于每个连通块,只有为完全图才成立,然后就dfs C 构造 想了20分钟才会,一开始想偏了,以为要利用相邻NO YES的关系再枚举,其实不难.. 考虑对于顺序枚举每一个NO/YES, ...

  4. 扔几道sb题

    1.给定一个长度为N的数列,A1, A2, ... AN,如果其中一段连续的子序列Ai, Ai+1, ... Aj(i <= j)之和是K的倍数,我们就称这个区间[i, j]是K倍区间. 你能求 ...

  5. 【省选水题集Day1】一起来AK水题吧! 题解(更新到B)

    题目:http://www.cnblogs.com/ljc20020730/p/6937936.html 水题A:[AHOI2001]质数和分解 安徽省选OI原题!简单Dp. 一看就是完全背包求方案数 ...

  6. 《剑指offer》— JavaScript(14)链表中倒数第k个结点

    链表中倒数第k个结点 题目描述 输入一个链表,输出该链表中倒数第k个结点. 思路 两个指针,先让第一个指针和第二个指针都指向头结点,然后再让第一个指针走(k-1)步,到达第k个节点: 然后两个指针同时 ...

  7. NYOJ--520

    最大素因子 原题链接:http://acm.nyist.net/JudgeOnline/problem.php?pid=520 分析:先筛素数,同时记录下素数的序号,然后质因数分解. #include ...

  8. pandas导出Excel并将数据保存到不同的Sheet表中

    数据存在mongodb中,按照类别导出到Excel文件,问题是想把同一类的数据放到一个sheet表中,最后只导出到一个excel文件中# coding=utf-8import pandas as pd ...

  9. (转)linux下vi命令修改文件及保存的使用方法

    进入vi的命令         vi filename :打开或新建文件,并将光标置于第一行首    vi n filename :打开文件,并将光标置于第n行首    vi filename :打开 ...

  10. JS笔记-强化版2

    1.DOM:   DOM : Document Object Model 文档对象模型 文档:html页面 文档对象:页面中元素 文档对象模型:定义 为了能够让程序(js)去操作页面中的元素   DO ...