import tensorflow as tf
from sklearn.datasets import load_digits
#from sklearn.cross_validation import train_test_split
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer # load data
digits = load_digits() X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) def add_layer(inputs, in_size, out_size, layer_name, activation_function=None, ):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, )
Wx_plus_b = tf.matmul(inputs, Weights) + biases
# here to dropout
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b, )
tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs # define placeholder for inputs to network
keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32, [None, 64])
ys = tf.placeholder(tf.float32, [None, 10]) # add output layer
l1 = add_layer(xs, 64, 50, 'l1', activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function=tf.nn.softmax) # the loss between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1]))
tf.summary.scalar ('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all()
# summary writer goes in here
train_writer = tf.summary.FileWriter("logs4/train", sess.graph)
test_writer = tf.summary.FileWriter("logs4/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
# here to determine the keeping probability
sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})
if i % 50 == 0:
# record loss
train_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})
test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_prob: 1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

TF:利用sklearn自带数据集使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线—Jason niu的更多相关文章

  1. TF之NN:matplotlib动态演示深度学习之tensorflow将神经网络系统自动学习并优化修正并且将输出结果可视化—Jason niu

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

  2. TF:TF分类问题之MNIST手写50000数据集实现87.4%准确率识别:SGD法+softmax法+cross_entropy法—Jason niu

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # number 1 to 10 ...

  3. 利用Linux自带的logrotate管理日志

    日常运维中,经常要对各类日志进行管理,清理,监控,尤其是因为应用bug,在1小时内就能写几十个G日志,导致磁盘爆满,系统挂掉. nohup.out,access.log,catalina.out 本文 ...

  4. GA:利用GA对一元函数进行优化过程,求x∈(0,10)中y的最大值——Jason niu

    x = 0:0.01:10; y = x + 10*sin(5*x)+7*cos(4*x); figure plot(x, y) xlabel('independent variable') ylab ...

  5. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  6. 『Sklearn』框架自带数据集接口

    自带数据集类型如下: # 自带小型数据集# sklearn.datasets.load_<name># 在线下载数据集# sklearn.datasets.fetch_<name&g ...

  7. 利用Sklearn实现加州房产价格预测,学习运用机器学习的整个流程(包含很多细节注解)

    Chapter1_housing_price_predict .caret, .dropup > .btn > .caret { border-top-color: #000 !impor ...

  8. 利用jdk自带的运行监控工具JConsole观察分析Java程序的运行

    利用jdk自带的运行监控工具JConsole观察分析Java程序的运行 原文链接 一.JConsole是什么 从Java 5开始 引入了 JConsole.JConsole 是一个内置 Java 性能 ...

  9. 利用sklearn计算文本相似性

    利用sklearn计算文本相似性,并将文本之间的相似度矩阵保存到文件当中.这里提取文本TF-IDF特征值进行文本的相似性计算. #!/usr/bin/python # -*- coding: utf- ...

随机推荐

  1. django 中自带的加密方法

    导入django 自带的加密算法 和flask中的哈希加密有一曲同工之妙.        from django.contrib.auth.hashers import make_password, ...

  2. css样式之补充。。。

    css常用的一些属性: 1.去掉下划线 :text-decoration:none ;2.加上下划线: text-decoration: underline; 3.调整文本和图片的位置(也就是设置元素 ...

  3. Laravel 中通过自定义分页器分页方法实现伪静态分页链接以利于 SEO

    我们知道,Laravel 自带的分页器方法包含 simplePaginate 和 paginate 方法,一个返回不带页码的分页链接,另一个返回带页码的分页链接,但是这两种分页链接页码都是以带问号的动 ...

  4. Unity3D料槽设备旋转(一)

    1.使用C#创建控制游戏对象的的脚本语言, 第一步: 在project师徒中create 一个C#脚本,将其按照自己的设备名称进行命名,这里我将其简单的命名成zhuaquanzhou.cs 使用编辑器 ...

  5. SPY

    问题 : SPY 时间限制: 1 Sec  内存限制: 128 MB 题目描述 The National Intelligence Council of X Nation receives a pie ...

  6. vue @click 使用三目运算(实现动态更换绑定的函数)

    转载:https://www.jianshu.com/p/ea4471c9f333 @click 错误写法 @click="dialogStatus=='create'?createData ...

  7. Spring Boot如何使用Runner实现启动时调用?用法和原理都在这里

    在日常的项目开发中经常会遇到这样的需求:项目启动的时候进行一些一次性的初始化工作,如读取加载资源文件.或者执行其它外部程序. 这个时候我们就可以用到spring-boot为我们提供的一种扩展机制--R ...

  8. Python操作MySQL案例

    最近都在学习Python代码,希望学会Python后,能给我带来更高的工作效率,所以每天坚持学习和拷代码,下面是一个Python操作MySQL的一个实例,该实例可以让更多的人更好了解MySQLdb模块 ...

  9. views.py视图函

    views.py视图函数来自 urls 的映射关系 常用所需模块 from django.shortcuts import render # ****** 渲染 render 跳转到指定的 url.h ...

  10. SQLServer索引及统计信息

    索引除了提高性能,还能维护数据库. 索引是一种存储结构,主要以B-Tree形式存储信息. B-Tree的定义: 1.每个节点最多只有m个节点(m>=2) 2.除了根节点和叶子节点外的每个节点上最 ...