tensorflow 学习笔记 多层感知机
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 9 19:20:51 2017 @author: Jarvis
"""
'''
tnesorflow 做机器学习的几个步骤
1.定义公式
2.定义loss function,选择优化器,并制定优化器优化loss
3.迭代地对数据进行训练
4。在测试集或验证集对准确率进行评测 ''' import tensorflow as tf
import pandas as pd
import random
#自己定义的一个选取batch进行训练的一个取batch函数
def next_batch(mnist, num,ilen = 55):
size = len(mnist)
selected_n = set([]) while(len(selected_n) < num):
t = random.choice(range(size))
selected_n.add(t)
l = list(selected_n) batch_xs = []
batch_ys = [] batch_xs = mnist.iloc[l,range(2,54)] batch_ys = mnist.iloc[l,range(54,62)]
return batch_xs,batch_ys #对数据进行读取
org_mnist = pd.read_csv("NDVI_NDWI.csv",header = None,encoding = 'gbk')
mnist = pd.get_dummies(org_mnist)
#创建session
#input_data.read_data_sets("MNIST_data/",one_hot = True)
sess = tf.InteractiveSession() #定义算法公式,在此处就是神经网络的结构方式
in_units = 52#每一条instance具有52个输入
h1_units = 30
h2_units = 20
h3_units = 10
h4_units = 5 #tf.truncated_normal是正态分布的一个东东,主要用于初始化一些W矩阵
W1 = tf.Variable(tf.truncated_normal([in_units,h1_units],stddev = 0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
W2 = tf.Variable(tf.zeros([h1_units,h2_units]))#[h1_units,8]
b2 = tf.Variable(tf.zeros([h2_units]))#
W3 = tf.Variable(tf.zeros([h2_units,h3_units]))
b3 = tf.Variable(tf.zeros([h3_units]))
W4 = tf.Variable(tf.zeros([h3_units,8]))
b4 = tf.Variable(tf.zeros([8])) '''
W4 = tf.Variable(tf.zeros([h3_units,h4_units]))
b4 = tf.Variable(tf.zeros([h4_units]))
W5 = tf.Variable(tf.zeros([h4_units,8]))
b5 = tf.Variable(tf.zeros([8]))
'''
x = tf.placeholder(tf.float32,[None, in_units])
keep_prob = tf.placeholder(tf.float32)#dropout 的比例 keep_prob hidden1 = tf.nn.sigmoid(tf.matmul(x,W1)+b1)
hidden1_drop = tf.nn.dropout(hidden1,keep_prob)
hidden2 = tf.nn.sigmoid(tf.matmul(hidden1_drop,W2)+b2)
hidden2_drop = tf.nn.dropout(hidden2,keep_prob)
hidden3 = tf.nn.sigmoid(tf.matmul(hidden2_drop,W3)+b3)
hidden3_drop = tf.nn.dropout(hidden3,keep_prob)
#hidden4 = tf.nn.sigmoid(tf.matmul(hidden3_drop,W4)+b4)
#hidden4_drop = tf.nn.dropout(hidden4,keep_prob) y = tf.nn.softmax(tf.matmul(hidden3_drop,W4)+b4)
y_ = tf.placeholder(tf.float32,[None,8])#[None,10]
#设置优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
train_step = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy) tf.global_variables_initializer().run() for i in range(2010):#
batch_xs, batch_ys = next_batch(mnist,1000)#1000 3
train_step.run( {x : batch_xs, y_ : batch_ys,keep_prob: 1}) correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
batch_xs, batch_ys = next_batch(mnist,10000)
print(accuracy.eval({x:batch_xs,y_:batch_ys,keep_prob:1.0}))
tensorflow 学习笔记 多层感知机的更多相关文章
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- TensorFlow学习笔记4-线性代数基础
TensorFlow学习笔记4-线性代数基础 本笔记内容为"AI深度学习".内容主要参考<Deep Learning>中文版. \(X\)表示训练集的设计矩阵,其大小为 ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——VGGNet
2014年,牛津大学计算机视觉组(Visual Geometry Group)和 Google DeepMind 公司的研究员一起研发了新的深度卷积神经网络:VGGNet ,并取得了ILSVRC201 ...
- Tensorflow学习笔记No.4.1
使用CNN卷积神经网络(1) 简单介绍CNN卷积神经网络的概念和原理. 已经了解的小伙伴可以跳转到Tensorflow学习笔记No.4.2学习如和用Tensorflow实现简单的卷积神经网络. 1.C ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
- Tensorflow学习笔记2019.01.03
tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...
随机推荐
- ios自定义数字键盘
因为项目又一个提现的功能,textfiled文本框输入需要弹出数字键盘,首先想到的就是设置textfiled的keyboardType为numberPad,此时你会看到如下的效果: 但是很遗憾这样 ...
- 解决 Win10 UWP 无法使用 ss 连接
一旦使用了 ss, 那么很多应用就无法连接网络. 本文提供一个方法可以简单使用ss提供的代理. 多谢 wtwsgs 提供方法:http://blog.csdn.net/wtwsgs/article/d ...
- daterangepicker 使用方法以及各种小bug修复
双日历时间段选择插件 — daterangepicker是bootstrap框架后期的一个时间控件,可以设定多个时间段选项,也可以自定义时间段,由用户自己选择起始时间和终止时间,时间段的最大跨度可以在 ...
- vux 组件打造手机端项目
其实,我用vux组件的过程是这样的,哇!太方便了!!功能好全!!太简单了!!然后,就各种"跳坑".以下排坑环节. 1.安装vux:cnpm i -S vux; 比较顺利吧. 2 ...
- javascript方法的方法名慎用close
通常我们在定义了与window同名的方法时,会自动覆盖掉window同名的方法.close()方法也不例外.示例: <!DOCTYPE html PUBLIC "-//W3C//DTD ...
- Nodejs.安装.非源码方式安装Node.js (Centos)
已验证的适用环境: Centos6.x 树莓派官方ROM(Raspbian) 先去官网下载已编译好的安装包 https://nodejs.org/en/download/current/ 以Cent ...
- 移动端通过ajax上传图片(文件)并在前台展示——通过H5的FormData对象
前些时候遇到移动端需要上传图片和视频的问题,之前一直通过ajax异步的提交数据,所以在寻找通过ajax上传文件的方法.发现了H5里新增了一个FormData对象,通过这个对象可以直接绑定html中的f ...
- Appium python自动化测试系列之自动化截图(十一)
11.1 截图函数的正常使用 11.1.1 截图方法 无论是在手动测试还是自动化测试中场景复现永远是一个很重要的事情,有时候一些问题可能很难复现,这个都需要测试人员对bug有很高的敏感度,在一般的情况 ...
- angular内置provider之$compileProvider
一.方法概览 directive(name, directiveFactory) component(name, options) aHrefSanitizationWhitelist([regexp ...
- 转载——yum源的超级简单配置
1.先挂载光盘. 使用命令"mount -o loop /dev/sr0 /mnt/cdrom".如果使用命令"mount -o loop /dev/cdrom ...