学习了tensorflow的线性回归。

首先是一个sklearn中makeregression数据集,对其进行线性回归训练的例子。来自腾讯云实验室

import tensorflow as tf
import numpy as np
class linearRegressionModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self._index_in_epoch=0
self.constructModel()
self.sess=tf.Session()
self.sess.run(tf.global_variables_initializer())
#权重初始化
def weight_variable(self,shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
#偏置项初始化
def bais_variable(self,shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#获取数据块,每次选100个样本,如果选完,则重新打乱
def next_batch(self,batch_size):
start=self._index_in_epoch
self._index_in_epoch+=batch_size
if self._index_in_epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self._datas=self._datas[perm]
self._labels=self._labels[perm]
start=0
self._index_in_epoch=batch_size
assert batch_size<=self._num_datas
end=self._index_in_epoch
return self._datas[start:end],self._labels[start:end]
def constructModel(self):
self.x=tf.placeholder(tf.float32,[None,self.x_dimen])
self.y=tf.placeholder(tf.float32,[None,1])
self.w=self.weight_variable([self.x_dimen,1])
self.b=self.bais_variable([1])
self.y_prec=tf.nn.bias_add(tf.matmul(self.x,self.w),self.b)
mse=tf.reduce_mean(tf.squared_difference(self.y_prec,self.y))
l2=tf.reduce_mean(tf.square(self.w))
#self.loss=mse+0.15*l2
self.loss=mse
self.train_step=tf.train.AdamOptimizer(0.1).minimize(self.loss)
def train(self,x_train,y_train,x_test,y_test):
self._datas=x_train
self._labels=y_train
self._num_datas=x_train.shape[0]
for i in range(5000):
batch=self.next_batch(100)
self.sess.run(self.train_step,
feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if i%10==0:
train_loss=self.sess.run(self.loss,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
print("setp %d,test_loss %f"%(i,train_loss))
def predict_batch(self,arr,batchsize):
for i in range(0,len(arr),batchsize):
yield arr[i:i+batchsize]
def predict(self,x_predict):
pred_list=[]
for x_test_batch in self.predict_batch(x_predict,100):
pred =self.sess.run(self.y_prec,{self.x:x_test_batch})
pred_list.append(pred)
return np.vstack(pred_list)

仿照这个代码,联系使用线性回归的方法对mnist进行训练。开始选择学习率为0.1,结果训练失败,调节学习率为0.01.正确率在0.91左右

给出训练类:

import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm,:]
self.lables=self.lables[perm,:]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end,:],self.lables[start:end,:]
def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
self.lables=y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={ self.x: batch[0],
# self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))

然后是调用方法,包括了对这个mnist数据集的下载

from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) if __name__=='__main__': x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
linear = mlm(len(x_train[1]))
linear.train(x_train,y_train,x_test,y_test)

下载方法来自tensorflow的官方文档中文版

使用线性回归识别手写阿拉伯数字mnist数据集的更多相关文章

  1. stanford coursera 机器学习编程作业 exercise 3(使用神经网络 识别手写的阿拉伯数字(0-9))

    本作业使用神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于使用逻辑回归实现多分类问题:识别手写的阿拉伯数字(0-9),请参考:http://www.cnblogs.com ...

  2. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  3. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  4. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  5. Tensorflow搭建卷积神经网络识别手写英语字母

    更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...

  6. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  7. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  8. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

  9. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

随机推荐

  1. zabbix主动被动模式说明/区别

    说明:使用zabbix代理有很多好处,一方面可以监控不可达的远程区域:另一方面当监控项目数以万计的时候使用代理可以有效分担zabbix server压力,也简化分布式监控的维护. 具体:主动.被动模式 ...

  2. 神文章1:去年(2011)一年干了些啥? -vivo神人

    评论: 来自豆瓣的vivo神人,之前不知道有着一号牛逼的人物,觉此人博学.有正义感,其中有一片文章述说了中国近代经济演变历史情况,于我感触很深.因时间关系,没通读,有时间一定读完(微博口水杂录简略看了 ...

  3. Android之Activity系列总结(三)--Activity的四种启动模式

    一.返回栈简介 任务是指在执行特定作业时与用户交互的一系列 Activity. 这些 Activity 按照各自的打开顺序排列在堆栈(即返回栈,也叫任务栈)中. 首先介绍一下任务栈: (1)程序打开时 ...

  4. mongoDB 32位 安装包地址

    https://www.mongodb.org/dl/win32/i386 http://downloads.mongodb.org/win32/mongodb-win32-i386-3.2.4-si ...

  5. FaceBook登陆API -- Login with API calls

    Login with API calls Related Topics Understanding sessions FBSession Error handling FBError FBLoginC ...

  6. AndroidStudio编译错误:Error: null value in entry: blameLogFolder=null

    今天写项目的时候,电脑开了个WiFi热点,然后这个热点和window驱动不兼容,有时候会导致电脑重启,重启之后AndroidStudio编译就报错了, Error: null value in ent ...

  7. 转: 使用Hystrix实现自动降级与依赖隔离

    使用Hystrix实现自动降级与依赖隔离 原创 2017年06月25日 17:28:01 标签: 异步 / 降级 869 这篇文章是记录了自己的一次集成Hystrix的经验,原本写在公司内部wiki里 ...

  8. Google Guava中的前置条件

    前置条件:让方法调用的前置条件判断更简单. Guava在Preconditions类中提供了若干前置条件判断的实用方法,我们建议[在Eclipse中静态导入这些方法]每个方法都有三个变种: check ...

  9. Jmeter录制HTTPS

    Jmeter有录制功能,录制HTTPs需要增加一个证书配置,录制步骤如下: 1.打开jmeter,添加线程组.线程组右键,逻辑控制器>录制控制器 工作台 右键 非测试元件 >HTTP代理服 ...

  10. 让MySQL在美国标准下运行

    [美国标准下运行的MySQL会有哪方面的调整] 我不得不说,这里有点标题党了:事实上我想说的就是--ansi模式下启动mysqld进行,但是这个ansi我没有找到更好的译文,就给译成了“美国标准”了. ...