在逻辑回归中使用mnist数据集。导入相应的包以及数据集。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('data/', one_hot=True)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print ("MNIST loaded")

使用tensorflow中函数进行逻辑回归的构建。调用softmax函数进行逻辑回归模型的构建;构造损失函数【y*tf.log(actv)】;构造梯度下降训练器;

x = tf.placeholder("float", [None, 784]) #784代表照片像素为28*28 10代表共有十个数字
y = tf.placeholder("float", [None, 10]) # None is for infinite
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# LOGISTIC REGRESSION MODEL
#get the predict number
actv = tf.nn.softmax(tf.matmul(x, W) + b)
# COST FUNCTION
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv), reduction_indices=1))
# OPTIMIZER
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# PREDICTION
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# ACCURACY
accr = tf.reduce_mean(tf.cast(pred, "float"))
# INITIALIZER
init = tf.global_variables_initializer()

循环五十次,每五次打印一次结果,每次训练取100个样本

training_epochs = 50
batch_size = 100
display_step = 5
# SESSION
sess = tf.Session()
sess.run(init)
# MINI-BATCH LEARNING
for epoch in range(training_epochs):
avg_cost = 0.
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys})
feeds = {x: batch_xs, y: batch_ys}
avg_cost += sess.run(cost, feed_dict=feeds)/num_batch
# DISPLAY
if epoch % display_step == 0:
feeds_train = {x: batch_xs, y: batch_ys}
feeds_test = {x: mnist.test.images, y: mnist.test.labels}
train_acc = sess.run(accr, feed_dict=feeds_train)
test_acc = sess.run(accr, feed_dict=feeds_test)
print ("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
% (epoch, training_epochs, avg_cost, train_acc, test_acc))
print ("DONE")

tensorflow学习笔记五----------逻辑回归的更多相关文章

  1. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  2. Python学习笔记之逻辑回归

    # -*- coding: utf-8 -*- """ Created on Wed Apr 22 17:39:19 2015 @author: 90Zeng " ...

  3. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  4. 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识

    深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...

  5. tensorflow学习笔记(4)-学习率

    tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...

  6. tensorflow学习笔记(2)-反向传播

    tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...

  7. TensorFlow学习笔记——LeNet-5(训练自己的数据集)

    在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...

  8. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  9. TensorFlow学习笔记10-卷积网络

    卷积网络 卷积神经网络(Convolutional Neural Network,CNN)专门处理具有类似网格结构的数据的神经网络.如: 时间序列数据(在时间轴上有规律地采样形成的一维网格): 图像数 ...

随机推荐

  1. 为什么要重写hashcode( )和equals( )?

    打个比方,一个名叫张三的人去住酒店,在前台登记完名字就去了99层100号房间,此时警察来前台找叫张三的这个人住在哪间房,经过查询,该酒店住宿的有50个叫张三的,需要遍历查询,查询起来很不方便. 那么就 ...

  2. python3.5-tensorflow-keras 安装

    cpu centos FROM centos:7 MAINTAINER yon RUN yum -y install make wget \ && wget -O /etc/yum.r ...

  3. JPA学习(三、JPA_API)

    框架学习之JPA(三) JPA是Java Persistence API的简称,中文名Java持久层API,是JDK 5.0注解或XML描述对象-关系表的映射关系,并将运行期的实体对象持久化到数据库中 ...

  4. LCA模板 ( 最近公共祖先 )

    LCA 有几种经典的求取方法.这里只给出模板,至于原理我完全不懂. 1.RMQ转LCA.复杂度O(n+nlog2n+m) 大致就是 DFS求出欧拉序 => 对欧拉序做ST表 => LCA( ...

  5. 【技术分享:python 应用之一】如何使用 Python 对 Excel 做一份数据透视表

    客户这边,其中有一张如同上图所示的数据汇总表,然而需求是,需要将这张表数据做一个数据透视表,最后通过数据透视表中的数据,填写至系统数据库.拿到需求,首先就想到肯定不能直接用设计器去操作 Excel,通 ...

  6. Python3学习笔记(六):字符串

    一.基本字符串操作 所有标准的序列操作(索引.分片.乘法.判断成员资格.求长度.取最小值和最大值)对字符串同样适用.但是字符串是不可改变的. 二.字符串格式化 字符串格式化使用字符串格式化操作符(%) ...

  7. java学习之- 创建线程run和start特点

    标签(空格分隔): run,start 为什么做run方法的覆盖? 1.Thread类用于描述线程,该类就定义一个功能用于存储线程要运行的代码,该存储功能就是run方法: 也就是说Thread种的ru ...

  8. zay大爷的神仙题目 D1T1-大美江湖

    在前几天的时候,千古神犇zay(吊打zhx那个)出了一套神仙题目,所以我得来分析分析QWQ 先补个网易云链接QWQ 毕竟是T1嘛,还算是比较简单的,那道题,读完题目就发现是个中等模拟(猪国杀算大模拟的 ...

  9. leetcode 342. 4的幂(python)

    1. 题目描述 给定一个整数 (32 位有符号整数),请编写一个函数来判断它是否是 4 的幂次方. 示例 1: 输入: 16输出: true示例 2: 输入: 5输出: false 2. 思路 参考: ...

  10. Tomcat中出现"RFC 7230 and RFC 3986"错误的解决方法

    在用axios从前台向后台发请求时,后台报错 Invalid character found in the request target. The valid characters are defin ...