Softmax回归(使用tensorflow)
# coding:utf8
import numpy as np
import cPickle
import os
import tensorflow as tf class SoftMax:
def __init__(self,MAXT=30,step=0.0025):
self.MAXT = MAXT
self.step = step def load_theta(self,datapath="data/softmax.pkl"):
self.theta = cPickle.load(open(datapath,'rb')) def process_train(self,data,label,typenum=10,batch_size=500):
batches = data.shape[0] / batch_size
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = tf.Variable(tf.zeros([valuenum,typenum]))
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #交叉熵
train_step = tf.train.GradientDescentOptimizer(self.step).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for epoch in range(self.MAXT):
cost_=[]
for index in xrange(batches):
c_,_=sess.run([cross_entropy,train_step], feed_dict={ x: data[index * batch_size: (index + 1) * batch_size],
y_: label[index * batch_size: (index + 1) * batch_size]})
cost_.append(c_)
if epoch % 5 == 0:
print(( 'epoch %i, minibatch %i/%i,averange cost is %f') %
(epoch,index + 1,batches,np.mean(cost_)))
self.theta=sess.run(theta)
if not os.path.exists('data/softmax.pkl'):
f= open("data/softmax.pkl",'wb')
cPickle.dump(self.theta,f)
f.close()
return self.theta def process_test(self,data,label,typenum=10):
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = self.theta
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Accuracy: ",sess.run(accuracy, feed_dict={x: data,y_: label}) def h(self,x):
m = np.exp(np.dot(x,self.theta))
sump = np.sum(m,axis=1)
return m/sump def predict(self,x):
return np.argmax(self.h(x),axis=1) def reshape_data(self,label,typenum):
label_=[]
for yl_ in label:
tl_=np.zeros(typenum)
tl_[yl_]=1.0
label_.append(tl_)
return np.mat(label_) if __name__ == '__main__':
f = open('mnist.pkl', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
training_inputs = [np.reshape(x, 784) for x in training_data[0]]
data = np.array(training_inputs)
training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
vdata = np.array(training_inputs)
f.close() softmax = SoftMax()
softmax.process_train(data,training_data[1])
softmax.process_test(vdata,validation_data[1]) #Accuracy: 0.9269
softmax.process_test(data,training_data[1]) #Accuracy: 0.92718
Softmax回归(使用tensorflow)的更多相关文章
- 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...
- TensorFlow实现Softmax回归(模型存储与加载)
# -*- coding: utf-8 -*- """ Created on Thu Oct 18 18:02:26 2018 @author: zhen "& ...
- 利用TensorFlow识别手写的数字---基于Softmax回归
1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...
- 统计学习方法:罗杰斯特回归及Tensorflow入门
作者:桂. 时间:2017-04-21 21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...
- 使用Softmax回归将神经网络输出转成概率分布
神经网络解决多分类问题最常用的方法是设置n个输出节点,其中n为类别的个数.对于每一个样例,神经网络可以得到一个n维数组作为输出结果.数组中的每一个维度(也就是每一个输出节点)对应一个类别,通过前向传播 ...
- Haskell手撸Softmax回归实现MNIST手写识别
Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...
- Softmax回归
Reference: http://ufldl.stanford.edu/wiki/index.php/Softmax_regression http://deeplearning.net/tutor ...
- Softmax回归(Softmax Regression)
转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...
- DeepLearning之路(二)SoftMax回归
Softmax回归 1. softmax回归模型 softmax回归模型是logistic回归模型在多分类问题上的扩展(logistic回归解决的是二分类问题). 对于训练集,有. 对于给定的测试 ...
- Machine Learning 学习笔记 (3) —— 泊松回归与Softmax回归
本系列文章允许转载,转载请保留全文! [请先阅读][说明&总目录]http://www.cnblogs.com/tbcaaa8/p/4415055.html 1. 泊松回归 (Poisson ...
随机推荐
- MVP架构。。。。
Model-View-Presenter(MVP)概述 MVC模式已经出现了几十年了,在GUI领域已经得到了广泛的应用,由于微软ASP.NET MVC Framework的出现,致使MVC一度成 ...
- stm32 dac 配置过程
DAC模块的通道1来输出模拟电压,其详细设置步骤如下: 1)开启PA口时钟,设置PA4为模拟输入. STM32F103ZET6的DAC通道1是接在PA4上的,所以,我们先要使能PORTA的时钟,然后设 ...
- solr学习之入门篇
一,简介 Solr是一个独立的企业级搜索应用服务器,它对外提供类似于Web-service的API接口.用户可以通过http请求,向搜索引擎服务器提交一定格式的XML文件,生成索引:也可以通过Http ...
- Java容器类概述
1.简介 容器是一种在一个单元里处理一组复杂元素的对象.使用集合框架理论上能够减少编程工作量,提高程序的速度和质量,毕竟类库帮我们实现的集合在一定程度上时最优的.在Java中通过java.util为用 ...
- Git的环境搭建
Git时当下流行的分布式版本控制系统. 集中式版本控制系统的版本库是集中存放在中央处理器的,所以开发者要先从中央服务器获取最新的版本,编码后再将自己的代码发送给中央处理器.集中式版本控制系统最大的缺点 ...
- iOS:Size Classes的使用
iOS 8在应用界面的可视化设计上添加了一个新的特性-Size Classes,对于任何设备来说,界面的宽度和高度都只分为两种描述:正常和紧凑.这样开发者便可以无视设备具体的尺寸,而是对这两类和它们的 ...
- 16、SQL基础整理(触发器.方便备份)
触发器(方便备份) 本质上还是一个存储过程,只不过不是通过exec来调用执行,而是通过增删改数据库的操作来执行(可以操作视图) 全部禁用触发器 alter table teacher disable ...
- BZOJ3028 食物 (生成函数)
首先 1+x+x^2+x^3+...+x^∞=1/(1-x) 对于题目中的几种食物写出生成函数 (对于a*x^b , a表示方案数 x表示食物,b表示该种食物的个数) f(1)=1+x^2+x^4+. ...
- 清除浮动2-父元素设置overflow:hidden
<!doctype html><html> <head> <meta charset="UTF-8"> <meta name= ...
- eclipse workspace出错而导致启动不了
遇到这种问题也是第一次让我觉得eclipse没有vs强大. 说正题,关于怎么解决的问题,也是上网搜了一大堆.主要有两种吧: 1.就是删除这个workspace的整个metadata,这样就可以打开这个 ...