在逻辑回归中使用mnist数据集。导入相应的包以及数据集。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print ("MNIST loaded")

使用tensorflow中函数进行逻辑回归的构建。调用softmax函数进行逻辑回归模型的构建;构造损失函数【y*tf.log(actv)】;构造梯度下降训练器;

x = tf.placeholder("float", [None, 784]) #784代表照片像素为28*28 10代表共有十个数字
y = tf.placeholder("float", [None, 10]) # None is for infinite
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# LOGISTIC REGRESSION MODEL
#get the predict number
actv = tf.nn.softmax(tf.matmul(x, W) + b)
# COST FUNCTION
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv), reduction_indices=1))
# OPTIMIZER
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# PREDICTION
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# ACCURACY
accr = tf.reduce_mean(tf.cast(pred, "float"))
# INITIALIZER
init = tf.global_variables_initializer()

循环五十次,每五次打印一次结果,每次训练取100个样本

training_epochs = 50
batch_size = 100
display_step = 5
# SESSION
sess = tf.Session()
sess.run(init)
# MINI-BATCH LEARNING
for epoch in range(training_epochs):
avg_cost = 0.
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys})
feeds = {x: batch_xs, y: batch_ys}
avg_cost += sess.run(cost, feed_dict=feeds)/num_batch
# DISPLAY
if epoch % display_step == 0:
feeds_train = {x: batch_xs, y: batch_ys}
feeds_test = {x: mnist.test.images, y: mnist.test.labels}
train_acc = sess.run(accr, feed_dict=feeds_train)
test_acc = sess.run(accr, feed_dict=feeds_test)
print ("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
% (epoch, training_epochs, avg_cost, train_acc, test_acc))
print ("DONE")

tensorflow学习笔记五----------逻辑回归的更多相关文章

  1. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  2. Python学习笔记之逻辑回归

    # -*- coding: utf-8 -*- """ Created on Wed Apr 22 17:39:19 2015 @author: 90Zeng " ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  4. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  5. tensorflow学习笔记(4)-学习率

    tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...

  6. tensorflow学习笔记(2)-反向传播

    tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...

  7. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  8. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  9. TensorFlow学习笔记10-卷积网络

    卷积网络 卷积神经网络(Convolutional Neural Network,CNN)专门处理具有类似网格结构的数据的神经网络.如: 时间序列数据(在时间轴上有规律地采样形成的一维网格): 图像数 ...

随机推荐

  1. 用java imageio调整图片DPI,例如从96调整为300

    因项目需求把图片的DPI值提升到300,否则OCR识别产生错乱:直接上源码:1.图片处理接口: package util.image.dpi; import java.awt.image.Buffer ...

  2. Jenkins-ssh远程执行nohup- java无法退出

    一,初步 #执行方式 ssh 192.168.2.103 " nohup java -jar /home/a/ipf/ight/feedback/ixxxedback-platform-1. ...

  3. create-react-app 构建的项目使用释放配置文件 webpack 等等 运行 npm run eject 报错

    使用 git 提交一次记录即可正常 git add . git commit -m 'init' npm run eject

  4. python实现RGB转换HSV

    def rgb2hsv(r, g, b):     r, g, b = r/255.0, g/255.0, b/255.0     mx = max(r, g, b)     mn = min(r,  ...

  5. Mysql中经常出现的乱码问题

    Mysql中执行SET NAMES utf8这条SQl的作用 1)首先,Mysql服务器的编码和数据库的编码在配置文件my.ini中设置: 用记事本打开配置文件,修改代码:default-charac ...

  6. php 的路由简介 (一个简单的路由模式)

    <?php $_SERVER['REQUEST_URI'] = '/post/edit/1024?foo=bar'; $uri = explode('/', parse_url($_SERVER ...

  7. scss 用法 及 es6 用法讲解

    scss 用法的准备工作,下载 考拉 编译工具 且目录的名字一定不能出现中文,哪里都不能出现中文,否则就会报错 es6 用法 let 和 const  let  声明变量的方式 在 {} 代码块里面才 ...

  8. jmeter正则表达式提取

    使用jmeter正则表达式提取器之前,首先 使用httpwatch 分析一下 我要要测试的系统正则管理的规则: 例如:我这里要关联的是一个ODS数据仓库平台的登录 1./sso/login..单点登录 ...

  9. linux交叉编译Windows版本的ffmpeg

    主要参考http://www.cnblogs.com/haibindev/archive/2011/12/01/2270126.html 在我的机器上编译libfaac的时候 出现问题了 输出如下 . ...

  10. React Native商城项目实战01 - 初始化设置

    1.创建项目 $ react-native init BuyDemo 2.导入图片资源 安卓:把文件夹放到/android/app/src/main/res/目录下,如图: iOS: Xcode打开工 ...