tensorflow学习笔记五----------逻辑回归
在逻辑回归中使用mnist数据集。导入相应的包以及数据集。
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print ("MNIST loaded")
使用tensorflow中函数进行逻辑回归的构建。调用softmax函数进行逻辑回归模型的构建;构造损失函数【y*tf.log(actv)】;构造梯度下降训练器;
x = tf.placeholder("float", [None, 784]) #784代表照片像素为28*28 10代表共有十个数字
y = tf.placeholder("float", [None, 10]) # None is for infinite
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# LOGISTIC REGRESSION MODEL
#get the predict number
actv = tf.nn.softmax(tf.matmul(x, W) + b)
# COST FUNCTION
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv), reduction_indices=1))
# OPTIMIZER
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# PREDICTION
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# ACCURACY
accr = tf.reduce_mean(tf.cast(pred, "float"))
# INITIALIZER
init = tf.global_variables_initializer()
循环五十次,每五次打印一次结果,每次训练取100个样本
training_epochs = 50
batch_size = 100
display_step = 5
# SESSION
sess = tf.Session()
sess.run(init)
# MINI-BATCH LEARNING
for epoch in range(training_epochs):
avg_cost = 0.
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys})
feeds = {x: batch_xs, y: batch_ys}
avg_cost += sess.run(cost, feed_dict=feeds)/num_batch
# DISPLAY
if epoch % display_step == 0:
feeds_train = {x: batch_xs, y: batch_ys}
feeds_test = {x: mnist.test.images, y: mnist.test.labels}
train_acc = sess.run(accr, feed_dict=feeds_train)
test_acc = sess.run(accr, feed_dict=feeds_test)
print ("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
% (epoch, training_epochs, avg_cost, train_acc, test_acc))
print ("DONE")
tensorflow学习笔记五----------逻辑回归的更多相关文章
- tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)
mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...
- Python学习笔记之逻辑回归
# -*- coding: utf-8 -*- """ Created on Wed Apr 22 17:39:19 2015 @author: 90Zeng " ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- tensorflow学习笔记(4)-学习率
tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...
- tensorflow学习笔记(2)-反向传播
tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- TensorFlow学习笔记10-卷积网络
卷积网络 卷积神经网络(Convolutional Neural Network,CNN)专门处理具有类似网格结构的数据的神经网络.如: 时间序列数据(在时间轴上有规律地采样形成的一维网格): 图像数 ...
随机推荐
- [jvm学习笔记]-类加载过程
JVM类加载的过程 加载=>验证=>准备=>解析=>初始化 5个阶段所执行的具体动作 加载 在加载阶段,虚拟机需要完成3个事情1.通过一个类的全限定名获取定义此类的二进制字节流 ...
- [luogu]P1070 道路游戏[DP]
[luogu]P1070 道路游戏 题目描述小新正在玩一个简单的电脑游戏.游戏中有一条环形马路,马路上有 n 个机器人工厂,两个相邻机器人工厂之间由一小段马路连接.小新以某个机器人工厂为起点,按顺时针 ...
- 2-sat问题简记
关于2-sat问题,这里笔者主要是做一些简记,如要详细了解,可以读一读此dalao的文章:https://blog.csdn.net/jarjingx/article/details/8521690 ...
- 说下Java堆空间结构,及常用的jvm内存分析命令和工具
Java堆空间结构图:http://www.cnblogs.com/SaraMoring/p/5713732.html JVM内存状况查看方法和分析工具: http://blog.csdn.net/n ...
- layui select动态添加option
<form class="layui-form" action=""> <div class="layui-form-item pr ...
- SpringMVC - <mvc:default-servlet-handler/> 导致 Controller失效
原文地址:http://blog.csdn.net/j080624/article/details/66969987
- maven使用常见问题
1.我写的是src/main/java/config/mybatis-cofig.xml 但总是报错 Could not find resource src/main/java/config/myba ...
- tensorflow源码分析——CTC
CTC是2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurren ...
- c/c++二级指针动态开辟内存
c版: #include <stdio.h> #include <stdlib.h> #define row 4 #define col 8 int main() { int ...
- 555E Case of Computer Network
分析 一个连通块内的肯定不影响 于是我们先缩点 之后对于每个路径 向上向下分别开一个差分数组 如果两个数组同时有值则不合法 代码 #include<bits/stdc++.h> using ...