# coding:utf8
import numpy as np
import cPickle
import os
import tensorflow as tf class SoftMax:
def __init__(self,MAXT=30,step=0.0025):
self.MAXT = MAXT
self.step = step def load_theta(self,datapath="data/softmax.pkl"):
self.theta = cPickle.load(open(datapath,'rb')) def process_train(self,data,label,typenum=10,batch_size=500):
batches = data.shape[0] / batch_size
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = tf.Variable(tf.zeros([valuenum,typenum]))
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #交叉熵
train_step = tf.train.GradientDescentOptimizer(self.step).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for epoch in range(self.MAXT):
cost_=[]
for index in xrange(batches):
c_,_=sess.run([cross_entropy,train_step], feed_dict={ x: data[index * batch_size: (index + 1) * batch_size],
y_: label[index * batch_size: (index + 1) * batch_size]})
cost_.append(c_)
if epoch % 5 == 0:
print(( 'epoch %i, minibatch %i/%i,averange cost is %f') %
(epoch,index + 1,batches,np.mean(cost_)))
self.theta=sess.run(theta)
if not os.path.exists('data/softmax.pkl'):
f= open("data/softmax.pkl",'wb')
cPickle.dump(self.theta,f)
f.close()
return self.theta def process_test(self,data,label,typenum=10):
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = self.theta
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Accuracy: ",sess.run(accuracy, feed_dict={x: data,y_: label}) def h(self,x):
m = np.exp(np.dot(x,self.theta))
sump = np.sum(m,axis=1)
return m/sump def predict(self,x):
return np.argmax(self.h(x),axis=1) def reshape_data(self,label,typenum):
label_=[]
for yl_ in label:
tl_=np.zeros(typenum)
tl_[yl_]=1.0
label_.append(tl_)
return np.mat(label_) if __name__ == '__main__':
f = open('mnist.pkl', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
training_inputs = [np.reshape(x, 784) for x in training_data[0]]
data = np.array(training_inputs)
training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
vdata = np.array(training_inputs)
f.close() softmax = SoftMax()
softmax.process_train(data,training_data[1])
softmax.process_test(vdata,validation_data[1]) #Accuracy: 0.9269
softmax.process_test(data,training_data[1]) #Accuracy: 0.92718

Softmax回归(使用tensorflow)的更多相关文章

  1. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  2. TensorFlow实现Softmax回归(模型存储与加载)

    # -*- coding: utf-8 -*- """ Created on Thu Oct 18 18:02:26 2018 @author: zhen "& ...

  3. 利用TensorFlow识别手写的数字---基于Softmax回归

    1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...

  4. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  5. 使用Softmax回归将神经网络输出转成概率分布

    神经网络解决多分类问题最常用的方法是设置n个输出节点,其中n为类别的个数.对于每一个样例,神经网络可以得到一个n维数组作为输出结果.数组中的每一个维度(也就是每一个输出节点)对应一个类别,通过前向传播 ...

  6. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  7. Softmax回归

    Reference: http://ufldl.stanford.edu/wiki/index.php/Softmax_regression http://deeplearning.net/tutor ...

  8. Softmax回归(Softmax Regression)

    转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...

  9. DeepLearning之路(二)SoftMax回归

    Softmax回归   1. softmax回归模型 softmax回归模型是logistic回归模型在多分类问题上的扩展(logistic回归解决的是二分类问题). 对于训练集,有. 对于给定的测试 ...

  10. Machine Learning 学习笔记 (3) —— 泊松回归与Softmax回归

    本系列文章允许转载,转载请保留全文! [请先阅读][说明&总目录]http://www.cnblogs.com/tbcaaa8/p/4415055.html 1. 泊松回归 (Poisson ...

随机推荐

  1. C#获取项目程序及运行路径的方法

    1.asp.net webform用“Request.PhysicalApplicationPath获取站点所在虚拟目录的物理路径,最后包含“\”: 2.c# winform用 A:“Applicat ...

  2. hdu 2077

    PS:汉诺塔问题....找规律...观察发现,先是小的移动到B,然后大的移动到C(两步),然后小的移动到C,完成.刚开始就以为是f(n)=2f(n-1)+2..然而,小的移动一步是需要f(n)=3f( ...

  3. ISO c++11 does not allow conversion from string literal to 'char*'

    http://stackoverflow.com/questions/9650058/deprecated-conversion-from-string-literal-to-char

  4. IOS源码封装成.bundle和.a文件,以及加入xib的具体方法,翻遍网络,仅此一家完美翻译!! IOS7!!(3) 完美结局

    以上翻译有误解之处,现在简单做法如下: 经过深入研究,才感觉明白了内部机制,现在简单介绍于下,主要步骤:xcode5 创建库项目,删掉测试文件和默认创建的类,添加viewController类带xib ...

  5. Android Priority Job Queue (Job Manager):线程任务的容错重启机制(二)

     Android Priority Job Queue (Job Manager):线程任务的容错重启机制(二) 附录文章4简单介绍了如何启动一个后台线程任务,Android Priority J ...

  6. 11、C#基础整理(特殊集合和哈希表)

    特殊集合:队列.栈 一.栈Stack类:先进后出,没有索引 Stack ss = new Stack(); 1.增加数据:push :将元素推入集合 ss.Push(); ss.Push(); ss. ...

  7. HDU 5644 (费用流)

    Problem King's Pilots (HDU 5644) 题目大意 举办一次持续n天的飞行表演,第i天需要Pi个飞行员.共有m种休假计划,每个飞行员表演1次后,需要休假Si天,并提供Ti报酬来 ...

  8. Unity3D 创建动态的立方体图系统

    Unity3D 创建动态的立方体图系统 这一篇主要是利用上一篇的Shader,通过脚本来完成一个动态的立方体图变化系统. 准备工作如下: 创建一个新的场景.一个球体.提供给场景一个平行光,准备2个立方 ...

  9. Unity3d Web Player 与server端联网配置

    针对Unity3d Web Player 的server端联网配置写一随笔咯.  以SmartFoxServer2X官方的Unity3d Example ”tris“为例,部署好服务器之后,在Unit ...

  10. BOOL布尔类型

    1.BOOL数据类型,是一种表示非真即假的数据类型,布尔类型的变量只有YES和NO两个值.YES表⽰示表达式结果为真,NO表示表达式结果为假. 2.在C语言中,认为非0即为真. 3.分⽀支语句中,经常 ...