从昨天晚上,到今天上午12点半左右吧,一直在调这个代码。最开始训练的时候,老是说loss:nan

查了资料,因为是如果损失函数使用交叉熵,如果预测值为0或负数,求log的时候会出错。需要对预测结果进行约束。

有一种约束方法是:y_predict=max(y,(0,1e-18])。也就是将小于0的数值随机分配为(0,1e-18]中的某个数。这样做好像不太合适。

还有一种方法是使用sigmoid作为激活函数。我这样改正了之后仍然没有效果。

后来我把数据集中的图片打开看了一下才发现,它跟mnist不一样,是彩色的,matplot告诉我说是24位彩色,但是颜色值最大是16

把颜色规范化之后,在训练,准确度在0.93左右。

首先是线性回归类:

import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape,1.,1.)
#weiInit=tf.constant(10.,shape=shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.001).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm]
self.lables=self.lables[perm]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end],self.lables[start:end]
def one_hot(self,labels,class_num):
b=tf.one_hot(labels,class_num,1,0)
with tf.Session() as sess:
return sess.run(b) def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
#self.lables=self.one_hot(y_train.reshape(1,-1).tolist()[0],10)
self.lables = y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(3000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={self.x: batch[0],self.y: batch[1]}))
#print("y is", sess.run(self.y, feed_dict={self.x: batch[0],self.y: batch[1]}))
#print("correct_mat",sess.run(self.correct_mat,feed_dict={self.x: batch[0],self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))

这里做的就是普通的线性回归:y_predict = w*x+bias ,用交叉熵做损失函数

然后是我的运行类:

from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow as tf
from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
# import tensorflow.examples.tutorials.mnist.input_data as input_data
# mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) def one_hot(labels, class_num):
b = tf.one_hot(labels, class_num, 1, 0)
with tf.Session() as sess:
return sess.run(b)
def normal(x):
return (x-8)/16
if __name__=='__main__': # x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
# linear = mlm(len(x_train[1]))
# linear.train(x_train,y_train,x_test,y_test)
digit=load_digits()
x_train,x_test,y_train,y_test=train_test_split(digit.data,digit.target,test_size=0.5)
y_lrm_train=one_hot(y_train.reshape(1,-1).tolist()[0],10)
y_lrm_test=one_hot(y_test.reshape(1,-1).tolist()[0],10)
x_train=normal(x_train)
x_test=normal(x_test)
linear=mlm(x_train.shape[1])
linear.train(x_train,y_lrm_train,x_test,y_lrm_test)

注释掉的是在mnist训练集上的训练。可以看到我对digit数据集做了两个预处理:

将target转换成one-hot:

    b = tf.one_hot(labels, class_num, 1, 0)
with tf.Session() as sess:
return sess.run(b)

调用tensorflow的one_hot方法。传入的参数是(原标签,分类数,active值,negtive值)

以这个数据集为例,每张图片是8*8的小图片,也就是有64个特征值,标签是图片对应的数字,即0 1 2 3.。。。

所以one_hot(y,10,1,0)之后:

1===》[0 1 0 0 0 0 0 0 0 0]

2===》[0 0 1 0 0 0 0 0 0 0]

。。。。。

这样我们就可以计算交叉熵了。对应分类问题,一般都是将结果表示为one-hot的向量。

第二个预处理是去掉颜色信息:

return (x-8)/16

直接简单粗暴: [x-(max-min)]/max

由于我知道最小值是0,最大值是16,所以直接套数上去了。

最后是部分运行结果:

setp 0,test_loss 567.181641
setp 1,test_loss 514.701660
setp 2,test_loss 476.468811
setp 3,test_loss 465.000031
setp 4,test_loss 495.855927
setp 5,test_loss 439.806610
setp 6,test_loss 381.180908

setp 2995,test_loss 16.674290
setp 2996,test_loss 17.192608
setp 2997,test_loss 14.386111
setp 2998,test_loss 12.961364
setp 2999,test_loss 14.622895
0.93770856

可以看到后面的训练损失值一值在抖动。不知道是为什么。

使用线性回归识别sklearn中的手写数字digit的更多相关文章

  1. 使用TensorFlow的卷积神经网络识别自己的单个手写数字,填坑总结

    折腾了几天,爬了大大小小若干的坑,特记录如下.代码在最后面. 环境: Python3.6.4 + TensorFlow 1.5.1 + Win7 64位 + I5 3570 CPU 方法: 先用MNI ...

  2. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  3. PyTorch基础——使用卷积神经网络识别手写数字

    一.介绍 实验内容 内容包括用 PyTorch 来实现一个卷积神经网络,从而实现手写数字识别任务. 除此之外,还对卷积神经网络的卷积核.特征图等进行了分析,引出了过滤器的概念,并简单示了卷积神经网络的 ...

  4. 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)

    @ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...

  5. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  6. TensorFlow实现Softmax Regression识别手写数字中"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败”问题

    出现问题: 在使用TensorFlow实现MNIST手写数字识别时,出现"TimeoutError: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应 ...

  7. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  8. 第二节,TensorFlow 使用前馈神经网络实现手写数字识别

    一 感知器 感知器学习笔记:https://blog.csdn.net/liyuanbhu/article/details/51622695 感知器(Perceptron)是二分类的线性分类模型,其输 ...

  9. KNN识别手写数字

    一.问题描述 手写数字被存储在EXCEL表格中,行表示一个数字的标签和该数字的像素值,有多少行就有多少个样本. 一共42000个样本 二.KNN KNN最邻近规则,主要应用领域是对未知事物的识别,即判 ...

随机推荐

  1. 【Oracle 】tablespace 表空间创建和管理

    1.表空间的概述 1. 表空间是数据库的逻辑组成部分. 2. 从物理上讲,数据库数据存放在数据文件中: 3. 从逻辑上讲,数据库是存放在表空间中,表空间由一个或者多个数据文件组成. 2.oracle的 ...

  2. WPF:“wpf类库项目改为Window应用程序项目”系列问题

    一.wpf类库项目改为Window应用程序项目1.错误 CS5001 Program does not contain a static 'Main' method suitable for an e ...

  3. Android 关于导航栏(虚拟按键)遮挡PopupWindow底部布局的问题

    我们自定义popupWindow的时候,一般会设置这些参数 setContentView(contentView); //设置高度为屏幕高度 setWidth(UIUtils.getScreenHei ...

  4. appium简明教程(5)——appium client方法一览

    appium client扩展了原生的webdriver client方法 下面以java代码为例,简单过一下appium client提供的适合移动端使用的新方法 resetApp() getApp ...

  5. Python 文件 write() 方法

    概述 Python 文件 write() 方法用于向文件中写入指定字符串. 在文件关闭前或缓冲区刷新前,字符串内容存储在缓冲区中,这时你在文件中是看不到写入的内容的. 语法 write() 方法语法如 ...

  6. Python MySQLdb 批量插入 封装

    def insert_data_many(dbName,list_data_dict): try: # 得到列表的第一个字典集合 data_dict = list_data_dict[0] # 得到( ...

  7. Java代码通过API操作HBase的最佳实践

    HBase提供了丰富的API.这使得用Java连接HBase非常方便. 有时候大家会使用HTable table=new HTable(config,tablename);的方式来实例化一个HTabl ...

  8. 批量修改Mysql数据库表Innodb为MyISAN

    mysql -uroot -e "SELECT concat('ALTER TABLE ', TABLE_NAME,' ENGINE=MYISAM;') FROM Information_s ...

  9. 关于java代码提交HTTP POST请求中文乱码的解决方法

    首先说明下这些只是根据我工作常用经验的总结,可能不一定完全对,也不一定全面,但却是最通用的. JAVA里HTTP提交方式 httpurlconnection:jdk里自带的 httpclient:ap ...

  10. 03.反射--01【反射机制】【反射的应用场景】【Tomcat服务器】

    https://blog.csdn.net/benjaminzhang666/article/details/9408611 https://blog.csdn.net/benjaminzhang66 ...