# coding:utf8
import numpy as np
import cPickle
import os
import tensorflow as tf class SoftMax:
def __init__(self,MAXT=30,step=0.0025):
self.MAXT = MAXT
self.step = step def load_theta(self,datapath="data/softmax.pkl"):
self.theta = cPickle.load(open(datapath,'rb')) def process_train(self,data,label,typenum=10,batch_size=500):
batches = data.shape[0] / batch_size
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = tf.Variable(tf.zeros([valuenum,typenum]))
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #交叉熵
train_step = tf.train.GradientDescentOptimizer(self.step).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for epoch in range(self.MAXT):
cost_=[]
for index in xrange(batches):
c_,_=sess.run([cross_entropy,train_step], feed_dict={ x: data[index * batch_size: (index + 1) * batch_size],
y_: label[index * batch_size: (index + 1) * batch_size]})
cost_.append(c_)
if epoch % 5 == 0:
print(( 'epoch %i, minibatch %i/%i,averange cost is %f') %
(epoch,index + 1,batches,np.mean(cost_)))
self.theta=sess.run(theta)
if not os.path.exists('data/softmax.pkl'):
f= open("data/softmax.pkl",'wb')
cPickle.dump(self.theta,f)
f.close()
return self.theta def process_test(self,data,label,typenum=10):
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = self.theta
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Accuracy: ",sess.run(accuracy, feed_dict={x: data,y_: label}) def h(self,x):
m = np.exp(np.dot(x,self.theta))
sump = np.sum(m,axis=1)
return m/sump def predict(self,x):
return np.argmax(self.h(x),axis=1) def reshape_data(self,label,typenum):
label_=[]
for yl_ in label:
tl_=np.zeros(typenum)
tl_[yl_]=1.0
label_.append(tl_)
return np.mat(label_) if __name__ == '__main__':
f = open('mnist.pkl', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
training_inputs = [np.reshape(x, 784) for x in training_data[0]]
data = np.array(training_inputs)
training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
vdata = np.array(training_inputs)
f.close() softmax = SoftMax()
softmax.process_train(data,training_data[1])
softmax.process_test(vdata,validation_data[1]) #Accuracy: 0.9269
softmax.process_test(data,training_data[1]) #Accuracy: 0.92718

Softmax回归(使用tensorflow)的更多相关文章

  1. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  2. TensorFlow实现Softmax回归(模型存储与加载)

    # -*- coding: utf-8 -*- """ Created on Thu Oct 18 18:02:26 2018 @author: zhen "& ...

  3. 利用TensorFlow识别手写的数字---基于Softmax回归

    1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...

  4. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  5. 使用Softmax回归将神经网络输出转成概率分布

    神经网络解决多分类问题最常用的方法是设置n个输出节点,其中n为类别的个数.对于每一个样例,神经网络可以得到一个n维数组作为输出结果.数组中的每一个维度(也就是每一个输出节点)对应一个类别,通过前向传播 ...

  6. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  7. Softmax回归

    Reference: http://ufldl.stanford.edu/wiki/index.php/Softmax_regression http://deeplearning.net/tutor ...

  8. Softmax回归(Softmax Regression)

    转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...

  9. DeepLearning之路(二)SoftMax回归

    Softmax回归   1. softmax回归模型 softmax回归模型是logistic回归模型在多分类问题上的扩展(logistic回归解决的是二分类问题). 对于训练集,有. 对于给定的测试 ...

  10. Machine Learning 学习笔记 (3) —— 泊松回归与Softmax回归

    本系列文章允许转载,转载请保留全文! [请先阅读][说明&总目录]http://www.cnblogs.com/tbcaaa8/p/4415055.html 1. 泊松回归 (Poisson ...

随机推荐

  1. c++父类指针强制转为子类指针后的测试(帮助理解指针访问成员的本质)(反多态)

    看下面例子: #include "stdafx.h" #include <iostream> class A {  //父类 public: void  f()   / ...

  2. UNICODE字符集(20140520)

    1多字节字符集,如"IT学吧",sizeof内存长度为7,因为前面2个字母各占用一个字节,后面两个汉字各占用2个字节,结尾的\0占用一个字节.strlen即字符串长度的结果为6. ...

  3. IIS6的session丢失问题

    解决办法:      a IIS6中相比IIS5增加了一个应用程序池,默认是使用DefaultAppPool.      b   先为站点建立一个应用程序池,打开IIS管理器,右键点击应用程序池-新建 ...

  4. 又见蒙特卡洛——python模拟解决三门问题

    三门问题很有意思,wiki用不同方法将原理讲的很透彻了,我跟喜欢其中这种理解方式:无论参赛者开始的选择如何,在被主持人问到是否更换时都选择更换.如果参赛者先选中山羊,换之后百分之百赢:如果参赛者先选中 ...

  5. 使用httputils上传图片到服务器

    //创建httpUtils对象 HttpUtils mRegHttpUtils = new HttpUtils(); //图片路径 String path = "/sdcard/Downlo ...

  6. Mac OS X中配置Apache

    我使用的Mac OS X版本是10.8.2,Mac自带了Apache环境. 启动Apache 设置虚拟主机 启动Apache 打开“终端(terminal)”,输入 sudo apachectl -v ...

  7. 11、网页制作Dreamweaver(补充:JS零碎知识点&&正则表达式)

    JS知识点 回车符/r和换行符/n的区别:/r 相当于enter,是段落与段落之间的区别, /n 相当于shift+enter,是行与行之间距离,比较小 几种window操作方法: 1.获取当前窗口大 ...

  8. 【转】终于干了点正事。。三天用了三个库opencv、emgu、aforge.net[2011.7.30]

    原文转自: http://blog.csdn.net/tutuguaiguai0427/article/details/6646051 这阵子,确切说这几天,还是看了好多东西的.虽然无用功居多. 上篇 ...

  9. js优化提升访问速度

    一.给JS文件减肥. 有的人为了给网站增加炫目效果,往往会使用一些JS效果代码,这在上个世纪似乎还很流行,对于现在来说,最好在用户体验确实需要的情况下,使用这些东西.至于希望给自己的JS文件减肥的童鞋 ...

  10. magento添加分类属性

    在magento中给产品添加自定义属性是很容易实现在后台就可以很轻易添加,但是给分类就不行了,magento本身没有提供给category添加自定义属性.在实际的运用过程中我们想给cagegory添加 ...