学习了tensorflow的线性回归。

首先是一个sklearn中makeregression数据集,对其进行线性回归训练的例子。来自腾讯云实验室

import tensorflow as tf
import numpy as np
class linearRegressionModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self._index_in_epoch=0
self.constructModel()
self.sess=tf.Session()
self.sess.run(tf.global_variables_initializer())
#权重初始化
def weight_variable(self,shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
#偏置项初始化
def bais_variable(self,shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#获取数据块,每次选100个样本,如果选完,则重新打乱
def next_batch(self,batch_size):
start=self._index_in_epoch
self._index_in_epoch+=batch_size
if self._index_in_epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self._datas=self._datas[perm]
self._labels=self._labels[perm]
start=0
self._index_in_epoch=batch_size
assert batch_size<=self._num_datas
end=self._index_in_epoch
return self._datas[start:end],self._labels[start:end]
def constructModel(self):
self.x=tf.placeholder(tf.float32,[None,self.x_dimen])
self.y=tf.placeholder(tf.float32,[None,1])
self.w=self.weight_variable([self.x_dimen,1])
self.b=self.bais_variable([1])
self.y_prec=tf.nn.bias_add(tf.matmul(self.x,self.w),self.b)
mse=tf.reduce_mean(tf.squared_difference(self.y_prec,self.y))
l2=tf.reduce_mean(tf.square(self.w))
#self.loss=mse+0.15*l2
self.loss=mse
self.train_step=tf.train.AdamOptimizer(0.1).minimize(self.loss)
def train(self,x_train,y_train,x_test,y_test):
self._datas=x_train
self._labels=y_train
self._num_datas=x_train.shape[0]
for i in range(5000):
batch=self.next_batch(100)
self.sess.run(self.train_step,
feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if i%10==0:
train_loss=self.sess.run(self.loss,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
print("setp %d,test_loss %f"%(i,train_loss))
def predict_batch(self,arr,batchsize):
for i in range(0,len(arr),batchsize):
yield arr[i:i+batchsize]
def predict(self,x_predict):
pred_list=[]
for x_test_batch in self.predict_batch(x_predict,100):
pred =self.sess.run(self.y_prec,{self.x:x_test_batch})
pred_list.append(pred)
return np.vstack(pred_list)

仿照这个代码,联系使用线性回归的方法对mnist进行训练。开始选择学习率为0.1,结果训练失败,调节学习率为0.01.正确率在0.91左右

给出训练类:

import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm,:]
self.lables=self.lables[perm,:]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end,:],self.lables[start:end,:]
def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
self.lables=y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={ self.x: batch[0],
# self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))

然后是调用方法,包括了对这个mnist数据集的下载

from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) if __name__=='__main__': x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
linear = mlm(len(x_train[1]))
linear.train(x_train,y_train,x_test,y_test)

下载方法来自tensorflow的官方文档中文版

使用线性回归识别手写阿拉伯数字mnist数据集的更多相关文章

  1. stanford coursera 机器学习编程作业 exercise 3(使用神经网络 识别手写的阿拉伯数字(0-9))

    本作业使用神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于使用逻辑回归实现多分类问题:识别手写的阿拉伯数字(0-9),请参考:http://www.cnblogs.com ...

  2. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  3. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  4. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  5. Tensorflow搭建卷积神经网络识别手写英语字母

    更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...

  6. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  7. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  8. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

  9. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

随机推荐

  1. 73条日常shell命令汇总,总有一条你需要!

    1.检查远程端口是否对bash开放: echo >/dev/tcp/8.8.8.8/53 && echo "open" 2.让进程转入后台: Ctrl + z ...

  2. 如何使两台机器不通过密码连接起来(linux)

    要求服务器10.96.22.40不通过密码直接连接服务器10.96.21.53 1:准备必须的软件 A:服务器40和53同时安装所需软件 yum -y install openssh-server o ...

  3. Android平台的音乐资源管理与播放

    Android平台基于Linux和开放手机联盟(OHA)系统,经过中国移动的创新研发,设计出拥有新颖独特的用户操作界面,增强了浏览器能力和WAP 兼容性,优化了多媒体领域的OpenCORE.浏览器领域 ...

  4. 开源的PaaS平台

    原文地址:https://blog.csdn.net/mypods/article/details/9366465 1.Stackato Stackato 是一个应用平台,用来创建私有.安全和灵活的企 ...

  5. Eclipse中Ant的配置与测试

    在Eclipse中使用Ant Ant是Java平台下非常棒的批处理命令执行程序,能非常方便地自动完成编译,测试,打包,部署等等一系列任务,大大提高开发效率.如果你现在还没有开始使用Ant,那就要赶快开 ...

  6. SimpleAdapter真不简单!

    作为一名编程初学者,我总是认为自己什么都不会,什么都不行,就算实现了文档指定的功能,我永远都是觉得自己写过的代码实在是太烂了,它只是恰巧能够运行而已!它只是在运行的时候恰巧没有发现错误而已!!一直都是 ...

  7. Vue 点击波浪特效指令

  8. 活久见: 原来 Chrome 浏览器支持 Import from 语法

    需要满足以下三个条件: 1.高版本的Chrome ,总而言之越新越好……,其他浏览器请参考:https://caniuse.com/#search=import 2.必须在服务器环境下才能运行,譬如a ...

  9. SQL Server中判断字符串出现的位置及字符串截取

    首先建一张测试表: )); insert into teststring values ('张三,李四,王五,马六,萧十一,皇宫'); 1.判断字符串中某字符(字符串)出现的次数,第一次出现的位置最后 ...

  10. Android自己定义ViewGroup(二)——带悬停标题的ExpandableListView

    项目里要加一个点击可收缩展开的列表,要求带悬停标题,详细效果例如以下图: watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fon ...