Tensorflow-逻辑斯蒂回归
1.交叉熵
逻辑斯蒂回归这个模型采用的是交叉熵,通俗点理解交叉熵
推荐一篇文章讲的很清楚:
https://www.zhihu.com/question/41252833
因此,交叉熵越低,这个策略就越好,最低的交叉熵也就是使用了真实分布所计算出来的信息熵,因为此时 ,交叉熵 = 信息熵。这也是为什么在机器学习中的分类算法中,我们总是最小化交叉熵,因为交叉熵越低,就证明由算法所产生的策略最接近最优策略,也间接证明我们算法所算出的非真实分布越接近真实分布
2.代码解释
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import tensorflow as tf
# 样本集
from tensorflow.examples.tutorials.mnist import input_data # 加载数据,目标值变成概率的形式,ont-hot
mnist = input_data.read_data_sets('./',one_hot=True)
# 训练数据 (55000, 784)
mnist.train.images.shape
# 测试数据 (10000, 784)
mnist.test.images.shape
# 目标值 ont-hot形式
mnist.train.labels[:10] # 构建方程
X = tf.placeholder(dtype=tf.float64,shape = (None,784),name = 'data')
y = tf.placeholder(dtype=tf.float64,shape = (None,10),name = 'target')
W = tf.Variable(initial_value=tf.zeros(shape =(784,10),dtype = tf.float64))
b = tf.Variable(initial_value=tf.zeros(shape = (10),dtype = tf.float64))
y_pred = tf.matmul(X,W) + b # 构建损失函数
# y 和 y_pred对比
# y表示是概率 [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]
# y_pred,矩阵运算求解的目标值
# 要将y_pred转化成概率,softmax
y_ = tf.nn.softmax(y_pred)
# 此时y和y_表示概率
# y和y_越接近,说明预测函数越准确
# 此时分类问题,交叉熵,表示损失函数
# 熵:表示的系统混乱程度
# 损失函数,越小越好
# 平均交叉熵------->可以比较大小的数
loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(y,tf.log(1/y_)),axis = -1)) # 最优化
opt = tf.train.GradientDescentOptimizer(0.01).minimize(loss) # 训练
# 训练次数
epoches = 100
# 保存
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(epoches):
c = 0
for j in range(100):
# 一次取550个,分100次取完数据 next_batch方法
X_train,y_train = mnist.train.next_batch(550)
opt_,cost = sess.run([opt,loss],feed_dict = {X:X_train,y:y_train})
c += cost/100
# 计算准确率
X_test,y_test = mnist.test.next_batch(2000)
y_predict = sess.run(y_,feed_dict={X:X_test})
y_test = np.argmax(y_test,axis = -1)
y_predict = np.argmax(y_predict,axis = 1)
accuracy = (y_test == y_predict).mean()
print('执行次数:%d。损失函数是:%0.4f。准确率是:%0.4f'%(i+1,c,accuracy))
if accuracy > 0.91:
saver.save(sess,'./model/estimator',global_step=i)
print('---------------------------模型保存成功----------------------------')
保存了模型,在上一次的基础上继续进行学习,这样的话可以直接从上次的准确率开始
# 其实代码是一样的,只是加了个saver.restore还原
with tf.Session() as sess:
# 还原到sess会话中
saver.restore(sess,'./model/estimator-99') for i in range(100,200):
c = 0
for j in range(100):
X_train,y_train = mnist.train.next_batch(550)
opt_,cost = sess.run([opt,loss],feed_dict = {X:X_train,y:y_train})
c += cost/100 # 计算准确率
X_test,y_test = mnist.test.next_batch(2000) y_predict = sess.run(y_,feed_dict={X:X_test}) y_test = np.argmax(y_test,axis = -1)
y_predict = np.argmax(y_predict,axis = 1) accuracy = (y_test == y_predict).mean()
print('执行次数:%d。损失函数是:%0.4f。准确率是:%0.4f'%(i+1,c,accuracy)) if accuracy > 0.91:
saver.save(sess,'./model/estimator',global_step=i)
print('---------------------------模型保存成功----------------------------')
Tensorflow-逻辑斯蒂回归的更多相关文章
- [置顶] 局部加权回归、最小二乘的概率解释、逻辑斯蒂回归、感知器算法——斯坦福ML公开课笔记3
转载请注明:http://blog.csdn.net/xinzhangyanxiang/article/details/9113681 最近在看Ng的机器学习公开课,Ng的讲法循循善诱,感觉提高了不少 ...
- 【分类器】感知机+线性回归+逻辑斯蒂回归+softmax回归
一.感知机 详细参考:https://blog.csdn.net/wodeai1235/article/details/54755735 1.模型和图像: 2.数学定义推导和优化: 3.流程 ...
- 【转】机器学习笔记之(3)——Logistic回归(逻辑斯蒂回归)
原文链接:https://blog.csdn.net/gwplovekimi/article/details/80288964 本博文为逻辑斯特回归的学习笔记.由于仅仅是学习笔记,水平有限,还望广大读 ...
- 机器学习之LinearRegression与Logistic Regression逻辑斯蒂回归(三)
一 评价尺度 sklearn包含四种评价尺度 1 均方差(mean-squared-error) 2 平均绝对值误差(mean_absolute_error) 3 可释方差得分(explained_v ...
- spark机器学习从0到1逻辑斯蒂回归之(四)
逻辑斯蒂回归 一.概念 逻辑斯蒂回归(logistic regression)是统计学习中的经典分类方法,属于对数线性模型.logistic回归的因变量可以是二分类的,也可以是多分类的.logis ...
- python机器学习实现逻辑斯蒂回归
逻辑斯蒂回归 关注公众号"轻松学编程"了解更多. [关键词]Logistics函数,最大似然估计,梯度下降法 1.Logistics回归的原理 利用Logistics回归进行分类的 ...
- 【项目实战】pytorch实现逻辑斯蒂回归
视频指导:https://www.bilibili.com/video/BV1Y7411d7Ys?p=6 一些数据集 在pytorch框架下,里面面有配套的数据集,pytorch里面有一个torchv ...
- 【TensorFlow入门完全指南】模型篇·逻辑斯蒂回归模型
import库,加载mnist数据集. 设置学习率,迭代次数,batch并行计算数量,以及log显示. 这里设置了占位符,输入是batch * 784的矩阵,由于是并行计算,所以None实际上代表并行 ...
- 逻辑斯蒂回归(Logistic Regression)
逻辑回归名字比较古怪,看上去是回归,却是一个简单的二分类模型. 逻辑回归的模型是如下形式: 其中x是features,θ是feature的权重,σ是sigmoid函数.将θ0视为θ0*x0(x0取值为 ...
- 逻辑斯蒂回归VS决策树VS随机森林
LR 与SVM 不同 1.logistic regression适合需要得到一个分类概率的场景,SVM则没有分类概率 2.LR其实同样可以使用kernel,但是LR没有support vector在计 ...
随机推荐
- Pytest权威教程21-API参考-02-标记(Marks)
目录 标记(Marks) pytest.mark.filterwarnings pytest.mark.parametrize pytest.mark.skip pytest.mark.skipif ...
- ubuntu之路——day2
一:sougou输入法安装 详情参考:https://blog.csdn.net/xin17863935225/article/details/82285177 注意切换成fcitx架构 因为linu ...
- lol英雄时刻
2019.5.30 翻出了当年人生中第一次LOL五杀截图...我用佐伊拿五杀了! 第一次五杀超激动的
- argmin ,argmax函数
在数学中,ARG MAX(或ARGMAX)代表最大值,即给定参数的点集,给定表达式的值达到其最大值: 换一种说法, 是f(x)具有最大值M的x的值的集合.例如,如果f(x)是1- | x |,那么它在 ...
- uSurvival 1.41多人在线生存逃杀吃鸡类游戏源码
uSurvival - the new Multiplayer Survival Asset from the creator of uMMORPG. Features:* Kill Zombies ...
- Spatial-Temporal Relation Networks for Multi-Object Tracking
Spatial-Temporal Relation Networks for Multi-Object Tracking 2019-05-21 11:07:49 Paper: https://arxi ...
- 【idea】断点调试时查看所有变量和静态变量
转载至博客:https://blog.csdn.net/qq32933432/article/details/86672341 缘起 笔者在进行HashMap原理探索的时候需要在IntelliJ ID ...
- windows中kill端口为8080的进程(或子进程)
1.netstat -aon|findstr "8080" 查出8080端口被9436进程占用 2.查看PID对应的进程 tasklist|findstr "9436&q ...
- 查看Linux系统的USB设备
查看Linux系统的USB设备 lsusb (centos没有该命令) dmesg (内核日志会输出) 执行dmesg
- log4j实现日志自动清理功能
log4j不支持自动清理功能,但是log4j2版本支持,log4j2是log4j的升级版,比logback先进. log4j升级为log4j2(不需要改动代码)https://blog.csdn.ne ...