周老师的书,对神经网络写了一个小的Demo

是最简单的神经网络,只有一层的隐藏层。

这次练习依旧是对西瓜的好坏进行预测。

主要分了以下几个步骤

1、数据预处理

对西瓜的不同特性进行数学编码表示(0~1),我是直接编了对应数字。含糖量已经是一个0~1之间的数,所以就没有进行处理

青绿  1

乌黑 0.5

浅白  0

蜷缩  1

稍蜷 0.5

硬挺  0

浊响  1

沉闷 0.5

清脆  0

清晰  1

稍糊 0.5

模糊  0

凹陷  1

稍凹 0.5

平坦  0

硬滑  1

软黏  0

2、训练集和检测集

  1. package BP;
  2. public class TrainData {
  3. double[][] traindata;
  4. double[][] traindataoutput;
  5. double[][] testdata;
  6. double[][] testdataoutput;
  7. public TrainData(){
  8. traindata = new double[][]{
  9. new double[]{1,1,1,1,1,1,0.697,0.460},
  10. new double[]{0.5,1,0.5,1,1,1,0.774,0.376},
  11. new double[]{0.5,1,1,1,1,1,0.634,0.264},
  12. //new double[]{1,1,0.5,1,1,1,0.608,0.318,1},
  13. //new double[]{0,1,1,1,1,1,0.556,0.215,1},
  14. new double[]{1,0.5,1,1,0.5,0,0.403,0.237},
  15. new double[]{0.5,0.5,1,0.5,0.5,0,0.481,0.149},
  16. //new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211,1},
  17. //new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091,0},
  18. //new double[]{1,0,0,1,0,0,0.243,0.267,0},
  19. //new double[]{0,0,0,0,0,1,0.245,0.057,0},
  20. //new double[]{0,1,1,0,0,0,0.343,0.099,0},
  21. new double[]{1,0.5,1,0.5,1,1,0.639,0.161},
  22. new double[]{0,0.5,0,0.5,1,1,0.657,0.198},
  23. new double[]{0.5,0.5,1,1,0.5,0,0.360,0.370},
  24. new double[]{0,1,1,0,0,1,0.593,0.042},
  25. new double[]{1,1,0.5,0.5,0.5,1,0.719,0.103}
  26. };
  27. traindataoutput = new double[][]{
  28. new double[]{1},
  29. new double[]{1},
  30. new double[]{1},
  31. new double[]{1},
  32. new double[]{1},
  33. new double[]{0},
  34. new double[]{0},
  35. new double[]{0},
  36. new double[]{0},
  37. new double[]{0},
  38. };
  39. testdata = new double[][]{
  40. new double[]{1,1,0.5,1,1,1,0.608,0.318},
  41. new double[]{0,1,1,1,1,1,0.556,0.215},
  42. new double[]{0.5,0.5,1,1,0.5,1,0.437,0.211},
  43. new double[]{0.5,0.5,0.5,0.5,0.5,1,0.666,0.091},
  44. new double[]{1,0,0,1,0,0,0.243,0.267},
  45. new double[]{0,0,0,0,0,1,0.245,0.057},
  46. new double[]{0,1,1,0,0,0,0.343,0.099},
  47. };
  48. testdataoutput = new double[][]{
  49. new double[]{1},
  50. new double[]{1},
  51. new double[]{1},
  52. new double[]{0},
  53. new double[]{0},
  54. new double[]{0},
  55. new double[]{0},
  56. };
  57. }
  58. public static void main(String[] args){
  59. TrainData t = new TrainData();
  60. for(int i=0;i<t.traindata.length;i++){
  61. for(int j=0;j<9;j++)
  62. System.out.print(t.traindata[i][j]+ " ");
  63. System.out.println();
  64. }
  65. }
  66. }

3、BP主函数

  1. package BP;
  2. import java.util.Random;
  3. public class BP {
  4. int innum;
  5. int hiddennum;
  6. int outnum;
  7. //输入、隐藏、输出层
  8. public double[] input;
  9. public double[] hidden;
  10. //output为本神经网络计算出的输出值
  11. public double[] output;
  12. //realoutput为训练网络时,用户提供的真的输出值
  13. public double[] realoutput;
  14. //v[i,j]表示输入层i到隐层j  w[i,j]表示隐层i到输出层j
  15. public double[][] v;
  16. public double[][] w;
  17. //beta为隐层的阈值,afa为输出层阈值
  18. public double[] beta;
  19. public double[] afa;
  20. //学习率
  21. public double eta;
  22. //步长
  23. public double momentum;
  24. public final Random random;
  25. public BP(int inputnum,int hiddennum,int outputnum,double learningrate){
  26. innum = inputnum;
  27. this.hiddennum = hiddennum;
  28. outnum = outputnum;
  29. input = new double[inputnum + 1];
  30. hidden = new double[hiddennum + 1];
  31. output = new double[outputnum + 1];
  32. realoutput = new double[outputnum + 1];
  33. v = new double[inputnum + 1][hiddennum + 1];
  34. w = new double[hiddennum + 1][outputnum + 1];
  35. beta = new double[outputnum + 1];
  36. afa = new double[hiddennum + 1];
  37. for(int i=0;i<outputnum;i++)
  38. beta[i] = 0.0;
  39. for(int i=0;i<hiddennum;i++)
  40. afa[i] = 0.0;
  41. eta = learningrate;
  42. //随机数对结果影响较大
  43. random = new Random(19950326);
  44. randomizeWeights(w);
  45. randomizeWeights(v);
  46. }
  47. public void testData(double[] in){
  48. input = in;
  49. getNetOutput();
  50. }
  51. //只对本题目有用,output>0.5时为好西瓜,output<0.5时为坏西瓜
  52. public int predict(double[] in){
  53. testData(in);
  54. if(output[0]>0.5)
  55. return 1;
  56. else
  57. return 0;
  58. }
  59. //获得在test集上的正确率
  60. public double getAccuracy(double[][] in,double[][] out){
  61. int rightans = 0,wrongans = 0;
  62. for(int i=0;i<in.length;i++){
  63. if(predict(in[i])==(out[i][0])){
  64. //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);
  65. rightans++;
  66. }else{
  67. //System.out.println("预测结果:"+predict(in[i])+" 实际结果为:"+out[i][0]);
  68. wrongans++;
  69. }
  70. }
  71. System.out.println("对:"+rightans+" 错:"+wrongans);
  72. return (double)rightans/(double)(rightans+wrongans);
  73. }
  74. //times为进行几轮训练
  75. public void train(int times){
  76. TrainData t = new TrainData();
  77. double wu = 0.0,acc = 0.0;
  78. int n = t.traindata.length;
  79. for(int i=0;i<times;i++){
  80. wu = 0.0;
  81. for(int j=0;j<n;j++){
  82. traindata(t.traindata[j],t.traindataoutput[j]);
  83. wu += getDeviation();
  84. }
  85. wu = wu/((double)n);
  86. System.out.println("第"+i+"轮训练:"+wu);
  87. acc = getAccuracy(t.testdata,t.testdataoutput);
  88. System.out.println("预测正确率为: "+acc);
  89. }
  90. }
  91. //对一个input输入进行训练
  92. public void traindata(double[] in,double[] out){
  93. input = in;
  94. realoutput = out;
  95. getNetOutput();
  96. adjustParameter();
  97. }
  98. //获得误差E
  99. public double getDeviation(){
  100. double e = 0.0;
  101. for(int i=0;i<outnum;i++)
  102. e += (output[i] - realoutput[i])*(output[i] - realoutput[i]);
  103. e *= 0.5;
  104. return e;
  105. }
  106. //调整权值
  107. public void adjustParameter(){
  108. double g[],e = 0.0;
  109. g = new double[outnum];
  110. int i,j;
  111. for(i=0;i<outnum;i++){
  112. g[i] = output[i]*(1-output[i])*(realoutput[i]-output[i]);
  113. beta[i] -= eta * g[i];
  114. for(j=0;j<hiddennum;j++){
  115. w[j][i] += eta * g[i] * hidden[j];
  116. }
  117. }
  118. for(i=0;i<hiddennum;i++){
  119. e = 0.0;
  120. for(j=0;j<outnum;j++)
  121. e += g[j]*w[i][j];
  122. e = hidden[i]*(1-hidden[i])*e;
  123. afa[i] -= eta * e;
  124. for(j=0;j<innum;j++)
  125. v[j][i] += eta * e * input[j];
  126. }
  127. }
  128. //获得output
  129. public void getNetOutput(){
  130. int i,j;
  131. double tmp=0.0;
  132. for(i=0;i<hiddennum;i++){
  133. tmp = 0.0;
  134. for(j=0;j<innum;j++)
  135. tmp += v[j][i]*input[j];
  136. hidden[i] = sigmoid(tmp-afa[i]);
  137. }
  138. for(i=0;i<outnum;i++){
  139. tmp = 0.0;
  140. for(j=0;j<hiddennum;j++)
  141. tmp += w[j][i]*hidden[j];
  142. output[i] = sigmoid(tmp-beta[i]);
  143. }
  144. }
  145. //对权值矩阵w、v进行初始随机化
  146. private void randomizeWeights(double[][] matrix) {
  147. for (int i = 0, len = matrix.length; i != len; i++)
  148. for (int j = 0, len2 = matrix[i].length; j != len2; j++) {
  149. double real = random.nextDouble();
  150. matrix[i][j] = random.nextDouble() > 0.5 ? real : -real;
  151. }
  152. }
  153. public void debug(){
  154. System.out.println("========begin=======");
  155. for(int i=0;i<innum;i++){
  156. for(int j=0;j<hiddennum;j++)
  157. System.out.print(v[i][j]+" ");
  158. System.out.println();
  159. }
  160. System.out.println();
  161. for(int i=0;i<hiddennum;i++){
  162. for(int j=0;j<outnum;j++)
  163. System.out.print(w[i][j]+" ");
  164. System.out.println();
  165. }
  166. System.out.println("========end=======");
  167. }
  168. public double sigmoid(double z){
  169. double s = 0.0;
  170. s = 1d/(1d + Math.exp(-z));
  171. return s;
  172. }
  173. public static void main(String[] args){
  174. BP bp = new BP(8,10,1,0.1);
  175. bp.train(50);
  176. }
  177. }

我要说的:

就结果来说,在验证集上的正确率可达到85%,当然很大程度上取决于BP初始化时random函数的种子。运气好的时候甚至能达到100%的正确率,运气不好的时候只有40%多,跟随便乱猜没什么区别。

想问大神。。。只能采用这种随机算法来找到一个最合适的ramdom种子值嘛?能不能用遗传这样的开放式算法进行搜索来找到最合适的随机值(我觉得随机的种子和随机结果并没有什么直接的关联,所以不知道能不能用遗传算法之列。。。)

机器学习 demo分西瓜的更多相关文章

  1. 分西瓜(DFS)

    描述今天是阴历七月初五,acm队员zb的生日.zb正在和C小加.never在武汉集训.他想给这两位兄弟买点什么庆祝生日,经过调查,zb发现C小加和never都很喜欢吃西瓜,而且一吃就是一堆的那种,zb ...

  2. LASSO回归与L1正则化 西瓜书

    LASSO回归与L1正则化 西瓜书 2018年04月23日 19:29:57 BIT_666 阅读数 2968更多 分类专栏: 机器学习 机器学习数学原理 西瓜书   版权声明:本文为博主原创文章,遵 ...

  3. 131.003 数据预处理之Dummy Variable & One-Hot Encoding

    @(131 - Machine Learning | 机器学习) Demo 直观来说就是有多少个状态就有多少比特,而且只有一个比特为1,其他全为0的一种码制 {sex:{male, female}}​ ...

  4. CUDA程序设计(一)

    为什么需要GPU 几年前我启动并主导了一个项目,当时还在谷歌,这个项目叫谷歌大脑.该项目利用谷歌的计算基础设施来构建神经网络. 规模大概比之前的神经网络扩大了一百倍,我们的方法是用约一千台电脑.这确实 ...

  5. ios基础篇(二十五)—— Animation动画(UIView、CoreAnimation)

    Animation主要分为两类: 1.UIView属性动画 2.CoreAnimation动画 一.UIView属性动画 UIKit直接将动画集成到UIView类中,实现简单动画的创建过程.UIVie ...

  6. NY 325 zb的生日

    假设所有西瓜重 Asum,所求的是用 Asum / 2 的背包装,最多装下多少. 刚开始用贪心作的,WA.后来用01背包,结果TLE,数据太大.原来用的是深搜! dfs(int sum, int i) ...

  7. backbone.Router History源码笔记

    Backbone.History和Backbone.Router history和router都是控制路由的,做一个单页应用,要控制前进后退,就可以用到他们了. History类用于监听URL的变化, ...

  8. spring springMVC mybatis 集成

    最近闲来无事,整理了一下spring springMVC mybatis 集成,关于这个话题在园子里已经有很多人写过了,我主要是想提供一个完整的demo,涵盖crud,事物控制等. 整个demo分三个 ...

  9. iOS百度推送的基本使用

    一.iOS证书指导 在 iOS App 中加入消息推送功能时,必须要在 Apple 的开发者中心网站上申请推送证书,每一个 App 需要申请两个证书,一个在开发测试环境下使用,另一个用于上线到 App ...

随机推荐

  1. js获取checkbox复选框获取选中的选项

    js获取checkbox复选框获取选中的选项 分享下javascript获取checkbox 复选框获取选中的选项的方法. 有关javascript 获取checkbox复选框的实例数不胜数.js实现 ...

  2. IOC的实现原理—反射与工厂模式的结合

    反射机制概念   我们考虑一个场景,如果我们在程序运行时,一个对象想要检视自己所拥有的成员属性,该如何操作?再考虑另一个场景,如果我们想要在运行期获得某个类的Class信息如它的属性.构造方法.一般方 ...

  3. redis-cli 连接远程服务器

    # redis-cli -h 10.11.09.10 -p 6379 #注意空格

  4. 【Java】Java复习笔记-第二部分

    类和对象 类:主观抽象,是对象的模板,可以实例化对象 习惯上类的定义格式: package xxx; import xxx; public class Xxxx { 属性 ······; 构造器 ·· ...

  5. muduo源码分析:组成结构

    muduo整体由许多类组成: 这些类之间有一些依赖关系,如下:

  6. 安装C/C++交叉编译环境

    转:http://blog.csdn.net/nokiaguy/article/details/8509739 X86架构的CPU采用的是复杂指令集(Complex Instruction Set C ...

  7. Delphi中的三目运算函数有哪些?(XE10.2+WIN764)

    相关资料:https://www.cnblogs.com/rogge7/p/6078903.html 问题现象:在做一个判断时突然想到了C++的三目运算,就在想Delphi中一共有几个? 问题处理: ...

  8. 逻辑回归(LR)和支持向量机(SVM)的区别和联系

    1. 前言 在机器学习的分类问题领域中,有两个平分秋色的算法,就是逻辑回归和支持向量机,这两个算法个有千秋,在不同的问题中有不同的表现效果,下面我们就对它们的区别和联系做一个简单的总结. 2. LR和 ...

  9. python parse xml using DOM

    demo: import xml.dom.minidom dom=xml.dom.minidom.parse('sample.xml')root = dom.documentElementcc=dom ...

  10. python matplotlib 画图

    import numpy as np import matplotlib.pyplot as plt from pylab import * numpy 常用来组织源数据: 使用 plot 函数直接绘 ...