# coding:utf8
import numpy as np
import cPickle
import os
import tensorflow as tf class SoftMax:
def __init__(self,MAXT=30,step=0.0025):
self.MAXT = MAXT
self.step = step def load_theta(self,datapath="data/softmax.pkl"):
self.theta = cPickle.load(open(datapath,'rb')) def process_train(self,data,label,typenum=10,batch_size=500):
batches = data.shape[0] / batch_size
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = tf.Variable(tf.zeros([valuenum,typenum]))
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #交叉熵
train_step = tf.train.GradientDescentOptimizer(self.step).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for epoch in range(self.MAXT):
cost_=[]
for index in xrange(batches):
c_,_=sess.run([cross_entropy,train_step], feed_dict={ x: data[index * batch_size: (index + 1) * batch_size],
y_: label[index * batch_size: (index + 1) * batch_size]})
cost_.append(c_)
if epoch % 5 == 0:
print(( 'epoch %i, minibatch %i/%i,averange cost is %f') %
(epoch,index + 1,batches,np.mean(cost_)))
self.theta=sess.run(theta)
if not os.path.exists('data/softmax.pkl'):
f= open("data/softmax.pkl",'wb')
cPickle.dump(self.theta,f)
f.close()
return self.theta def process_test(self,data,label,typenum=10):
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = self.theta
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Accuracy: ",sess.run(accuracy, feed_dict={x: data,y_: label}) def h(self,x):
m = np.exp(np.dot(x,self.theta))
sump = np.sum(m,axis=1)
return m/sump def predict(self,x):
return np.argmax(self.h(x),axis=1) def reshape_data(self,label,typenum):
label_=[]
for yl_ in label:
tl_=np.zeros(typenum)
tl_[yl_]=1.0
label_.append(tl_)
return np.mat(label_) if __name__ == '__main__':
f = open('mnist.pkl', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
training_inputs = [np.reshape(x, 784) for x in training_data[0]]
data = np.array(training_inputs)
training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
vdata = np.array(training_inputs)
f.close() softmax = SoftMax()
softmax.process_train(data,training_data[1])
softmax.process_test(vdata,validation_data[1]) #Accuracy: 0.9269
softmax.process_test(data,training_data[1]) #Accuracy: 0.92718

Softmax回归(使用tensorflow)的更多相关文章

  1. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  2. TensorFlow实现Softmax回归(模型存储与加载)

    # -*- coding: utf-8 -*- """ Created on Thu Oct 18 18:02:26 2018 @author: zhen "& ...

  3. 利用TensorFlow识别手写的数字---基于Softmax回归

    1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...

  4. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  5. 使用Softmax回归将神经网络输出转成概率分布

    神经网络解决多分类问题最常用的方法是设置n个输出节点,其中n为类别的个数.对于每一个样例,神经网络可以得到一个n维数组作为输出结果.数组中的每一个维度(也就是每一个输出节点)对应一个类别,通过前向传播 ...

  6. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  7. Softmax回归

    Reference: http://ufldl.stanford.edu/wiki/index.php/Softmax_regression http://deeplearning.net/tutor ...

  8. Softmax回归(Softmax Regression)

    转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...

  9. DeepLearning之路(二)SoftMax回归

    Softmax回归   1. softmax回归模型 softmax回归模型是logistic回归模型在多分类问题上的扩展(logistic回归解决的是二分类问题). 对于训练集,有. 对于给定的测试 ...

  10. Machine Learning 学习笔记 (3) —— 泊松回归与Softmax回归

    本系列文章允许转载,转载请保留全文! [请先阅读][说明&总目录]http://www.cnblogs.com/tbcaaa8/p/4415055.html 1. 泊松回归 (Poisson ...

随机推荐

  1. MVP架构。。。。

    Model-View-Presenter(MVP)概述    MVC模式已经出现了几十年了,在GUI领域已经得到了广泛的应用,由于微软ASP.NET MVC Framework的出现,致使MVC一度成 ...

  2. stm32 dac 配置过程

    DAC模块的通道1来输出模拟电压,其详细设置步骤如下: 1)开启PA口时钟,设置PA4为模拟输入. STM32F103ZET6的DAC通道1是接在PA4上的,所以,我们先要使能PORTA的时钟,然后设 ...

  3. solr学习之入门篇

    一,简介 Solr是一个独立的企业级搜索应用服务器,它对外提供类似于Web-service的API接口.用户可以通过http请求,向搜索引擎服务器提交一定格式的XML文件,生成索引:也可以通过Http ...

  4. Java容器类概述

    1.简介 容器是一种在一个单元里处理一组复杂元素的对象.使用集合框架理论上能够减少编程工作量,提高程序的速度和质量,毕竟类库帮我们实现的集合在一定程度上时最优的.在Java中通过java.util为用 ...

  5. Git的环境搭建

    Git时当下流行的分布式版本控制系统. 集中式版本控制系统的版本库是集中存放在中央处理器的,所以开发者要先从中央服务器获取最新的版本,编码后再将自己的代码发送给中央处理器.集中式版本控制系统最大的缺点 ...

  6. iOS:Size Classes的使用

    iOS 8在应用界面的可视化设计上添加了一个新的特性-Size Classes,对于任何设备来说,界面的宽度和高度都只分为两种描述:正常和紧凑.这样开发者便可以无视设备具体的尺寸,而是对这两类和它们的 ...

  7. 16、SQL基础整理(触发器.方便备份)

    触发器(方便备份) 本质上还是一个存储过程,只不过不是通过exec来调用执行,而是通过增删改数据库的操作来执行(可以操作视图) 全部禁用触发器 alter table teacher disable ...

  8. BZOJ3028 食物 (生成函数)

    首先 1+x+x^2+x^3+...+x^∞=1/(1-x) 对于题目中的几种食物写出生成函数 (对于a*x^b , a表示方案数 x表示食物,b表示该种食物的个数) f(1)=1+x^2+x^4+. ...

  9. 清除浮动2-父元素设置overflow:hidden

    <!doctype html><html> <head> <meta charset="UTF-8"> <meta name= ...

  10. eclipse workspace出错而导致启动不了

    遇到这种问题也是第一次让我觉得eclipse没有vs强大. 说正题,关于怎么解决的问题,也是上网搜了一大堆.主要有两种吧: 1.就是删除这个workspace的整个metadata,这样就可以打开这个 ...