从昨天晚上,到今天上午12点半左右吧,一直在调这个代码。最开始训练的时候,老是说loss:nan

查了资料,因为是如果损失函数使用交叉熵,如果预测值为0或负数,求log的时候会出错。需要对预测结果进行约束。

有一种约束方法是:y_predict=max(y,(0,1e-18])。也就是将小于0的数值随机分配为(0,1e-18]中的某个数。这样做好像不太合适。

还有一种方法是使用sigmoid作为激活函数。我这样改正了之后仍然没有效果。

后来我把数据集中的图片打开看了一下才发现,它跟mnist不一样,是彩色的,matplot告诉我说是24位彩色,但是颜色值最大是16

把颜色规范化之后,在训练,准确度在0.93左右。

首先是线性回归类:

import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape,1.,1.)
#weiInit=tf.constant(10.,shape=shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.001).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm]
self.lables=self.lables[perm]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end],self.lables[start:end]
def one_hot(self,labels,class_num):
b=tf.one_hot(labels,class_num,1,0)
with tf.Session() as sess:
return sess.run(b) def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
#self.lables=self.one_hot(y_train.reshape(1,-1).tolist()[0],10)
self.lables = y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(3000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={self.x: batch[0],self.y: batch[1]}))
#print("y is", sess.run(self.y, feed_dict={self.x: batch[0],self.y: batch[1]}))
#print("correct_mat",sess.run(self.correct_mat,feed_dict={self.x: batch[0],self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))

这里做的就是普通的线性回归:y_predict = w*x+bias ,用交叉熵做损失函数

然后是我的运行类:

from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow as tf
from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
# import tensorflow.examples.tutorials.mnist.input_data as input_data
# mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) def one_hot(labels, class_num):
b = tf.one_hot(labels, class_num, 1, 0)
with tf.Session() as sess:
return sess.run(b)
def normal(x):
return (x-8)/16
if __name__=='__main__': # x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
# linear = mlm(len(x_train[1]))
# linear.train(x_train,y_train,x_test,y_test)
digit=load_digits()
x_train,x_test,y_train,y_test=train_test_split(digit.data,digit.target,test_size=0.5)
y_lrm_train=one_hot(y_train.reshape(1,-1).tolist()[0],10)
y_lrm_test=one_hot(y_test.reshape(1,-1).tolist()[0],10)
x_train=normal(x_train)
x_test=normal(x_test)
linear=mlm(x_train.shape[1])
linear.train(x_train,y_lrm_train,x_test,y_lrm_test)

注释掉的是在mnist训练集上的训练。可以看到我对digit数据集做了两个预处理:

将target转换成one-hot:

    b = tf.one_hot(labels, class_num, 1, 0)
with tf.Session() as sess:
return sess.run(b)

调用tensorflow的one_hot方法。传入的参数是(原标签,分类数,active值,negtive值)

以这个数据集为例,每张图片是8*8的小图片,也就是有64个特征值,标签是图片对应的数字,即0 1 2 3.。。。

所以one_hot(y,10,1,0)之后:

1===》[0 1 0 0 0 0 0 0 0 0]

2===》[0 0 1 0 0 0 0 0 0 0]

。。。。。

这样我们就可以计算交叉熵了。对应分类问题,一般都是将结果表示为one-hot的向量。

第二个预处理是去掉颜色信息:

return (x-8)/16

直接简单粗暴: [x-(max-min)]/max

由于我知道最小值是0,最大值是16,所以直接套数上去了。

最后是部分运行结果:

setp 0,test_loss 567.181641
setp 1,test_loss 514.701660
setp 2,test_loss 476.468811
setp 3,test_loss 465.000031
setp 4,test_loss 495.855927
setp 5,test_loss 439.806610
setp 6,test_loss 381.180908

setp 2995,test_loss 16.674290
setp 2996,test_loss 17.192608
setp 2997,test_loss 14.386111
setp 2998,test_loss 12.961364
setp 2999,test_loss 14.622895
0.93770856

可以看到后面的训练损失值一值在抖动。不知道是为什么。

使用线性回归识别sklearn中的手写数字digit的更多相关文章

  1. 使用TensorFlow的卷积神经网络识别自己的单个手写数字,填坑总结

    折腾了几天,爬了大大小小若干的坑,特记录如下.代码在最后面. 环境: Python3.6.4 + TensorFlow 1.5.1 + Win7 64位 + I5 3570 CPU 方法: 先用MNI ...

  2. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  3. PyTorch基础——使用卷积神经网络识别手写数字

    一.介绍 实验内容 内容包括用 PyTorch 来实现一个卷积神经网络,从而实现手写数字识别任务. 除此之外,还对卷积神经网络的卷积核.特征图等进行了分析,引出了过滤器的概念,并简单示了卷积神经网络的 ...

  4. 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)

    @ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...

  5. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  6. TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题

    出现问题: 在使用TensorFlow实现MNIST手写数字识别时,出现"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应 ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

  9. KNN识别手写数字

    一.问题描述 手写数字被存储在EXCEL表格中,行表示一个数字的标签和该数字的像素值,有多少行就有多少个样本. 一共42000个样本 二.KNN KNN最邻近规则,主要应用领域是对未知事物的识别,即判 ...

随机推荐

  1. nginx init 官方启动脚本

    #!/bin/sh # # nginx - this script starts and stops the nginx daemon # # chkconfig: - 85 15 # descrip ...

  2. HDU 2227 Find the nondecreasing subsequences (数状数组)

    Find the nondecreasing subsequences Time Limit: 10000/5000 MS (Java/Others)    Memory Limit: 32768/3 ...

  3. hadoop遇到的问题及处理

    1:杀掉hadoop作业 列出作业 ./hadoop job -list 杀掉 ./hadoop job -kill job_id 1:某些节点出现running asprocess XXX. Sto ...

  4. RabbitMQ.Client API (.NET)中文文档

    主要的名称空间,接口和类 核心API中定义接口和类 RabbitMQ.Client 名称空间: 1 using RabbitMQ.Client; 核心API接口和类 IModel :表示一个AMQP ...

  5. 免费的UI素材准备

    UI素材准备 UI也是一个专业性比较强的一个活啊,不过还好我有强大的百度,强大的百度有各种强大的网站,下面介绍一些UI常用的网站1.阿里巴巴矢量图标库 http://www.iconfont.cn/p ...

  6. 整合Solr到Tomcat服务器,并配置IK分词

    好久没有接触新东西了,最新开始熟悉solr,实例展示单机环境solr整合. 整合方案一 1.下载Tomcat与solr并解压 Tomcat解压后磁盘路径为D:\program files\Tomcat ...

  7. Oracle 12C -- truncate的级联操作

    在之前的版本中,存在外键约束时,无法直接truncate父表.在12C中,对truncate操作添加了级联操作特性. 前提是创建外键约束时,使用了"on delete casacde&quo ...

  8. [转]利用Docker构建开发环境

    利用Docker构建开发环境 Posted by  makewonder on 2014 年 4 月 2 日   最近接触PAAS相关的知识,在研发过程中开始使用Docker搭建了自己完整的开发环境, ...

  9. Asp.net2.0之自定义控件ImageButton

    控件模仿winform中的button,可以支持图片和文字.可以选择执行服务器端程序还是客户端程序,还有一些简单的设置. 不足的是不支持样式,下次希望可以写一个工具条. 以下就是代码 以下为引用的内容 ...

  10. springboot 多模块 maven 项目构建jar 文件配置

    最近在写 springboot 项目时,需要使用多模块,遇到了许多问题. 1 如果程序使用了 java8 的一些特性,springboot 默认构建工具不支持.需要修改配置 ... </buil ...