import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) #
# add layer
#
def add_layer(inputs, in_size, out_size, n_layer, activation_function = None):
layer_name = 'layer%s' % n_layer Weights = tf.Variable(tf.random_normal([in_size, out_size]), name='W') # hang lie
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = 'b') Wx_plus_b = tf.matmul(inputs, Weights) + biases
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob) # if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b) tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs #
# define placeholder for inputs to network
#
keep_prob = tf.placeholder(tf.float32) #
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #
# add output layer
#
l1 = add_layer(xs, 64, 50, 'l1', activation_function = tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function = tf.nn.softmax) #
# the error between prediction and real data
#
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) #loss
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all() #summary writer goes here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
#sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:1.0}) # overfitted
sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:0.5}) # keep 0.5, drop 0.5
if i% 50 == 0:
#record loss
train_result = sess.run(merged, feed_dict={xs:X_train, ys:y_train, keep_prob:1})
test_result = sess.run(merged, feed_dict={xs:X_test, ys:y_test, keep_prob:1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

  

莫烦TensorFlow_10 过拟合的更多相关文章

  1. tensorflow 莫烦教程

    1,感谢莫烦 2,第一个实例:用tf拟合线性函数 import tensorflow as tf import numpy as np # create data x_data = np.random ...

  2. tensorflow学习笔记-bili莫烦

    bilibili莫烦tensorflow视频教程学习笔记 1.初次使用Tensorflow实现一元线性回归 # 屏蔽警告 import os os.environ[' import numpy as ...

  3. 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络

    滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...

  4. 稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记。

    稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记. 还有 google 在 udacity 上的 CNN 教程. CNN(Convolutional Neural Networks) 卷积神经网络简单 ...

  5. 莫烦大大TensorFlow学习笔记(9)----可视化

      一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...

  6. scikit-learn学习笔记-bili莫烦

    bilibili莫烦scikit-learn视频学习笔记 1.使用KNN对iris数据分类 from sklearn import datasets from sklearn.model_select ...

  7. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  8. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  9. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

随机推荐

  1. 不获取元素,直接使用id操作dom元素

    今天无意中发现个让我很吃惊的问题. 不使用getElementById方法,也可以用id直接操作有id的元素. 继续搜索后,发现name也可以直接操作... 这让我大感意外,了解以后,忍不住写点东西记 ...

  2. 剑指Offer-9.变态跳台阶(C++/Java)

    题目: 一只青蛙一次可以跳上1级台阶,也可以跳上2级……它也可以跳上n级.求该青蛙跳上一个n级的台阶总共有多少种跳法. 分析: 假设我们要求跳上第3级的跳法,可以从第0级跳3级台阶到达,也可以从第1级 ...

  3. 【BZOJ3876】[AHOI2014&JSOI2014] 支线剧情(无源汇有上下界网络流)

    点此看题面 大致题意: 有一张\(DAG\),经过每条边有一定时间,从\(1\)号点出发,随时可以返回\(1\)号点,求经过所有边的最短时间. 无源汇有上下界网络流 这是无源汇有上下界网络流的板子题. ...

  4. WordPress更改“固定链接”后 ,页面出现404的解决方法

    一.Web服务器对应的是Nginx 解决方案:修改linux服务器下Nginx的配置文件,目录为:/usr/local/nginx/conf/nginx.conf, 也可以直接使用命令nginx -t ...

  5. ORB-SLAM2初步(Tracking.cpp)

    今天主要是分析一下Tracking.cpp这个文件,它是实现跟踪过程的主要文件,这里主要针对单目,并且只是截取了部分代码片段. 一.跟踪过程分析 首先构造函数中使用初始化列表对跟踪状态mState(N ...

  6. x3

    #include<stdio.h> int main() { char ch; printf("输入一个字符:\n"); scanf("%c",&a ...

  7. 内网Metasploit映射到外网

    下载frp Github项目地址:https://github.com/fatedier/frp 找到最新的releases下载,系统版本自行确认. 下载方法: wget https://github ...

  8. QAxBase: Error calling IDispatch member LineStyle: Unknown error

    word/Excel版本2007.2010.  wps也适用. //borders->dynamicCall("SetLineStyle(int,int,int)", 0, ...

  9. 记录战斗记录你,详解妖尾战斗录像系统[Unity]

    .katex { display: block; text-align: center; white-space: nowrap; } .katex-display > .katex > ...

  10. vscode10个必装的插件

    VSCode 必装的 10 个高效开发插件   本文介绍了目前前端开发最受欢迎的开发工具 VSCode 必装的 10 个开发插件,用于大大提高软件开发的效率. VSCode 的基本使用可以参考我的原创 ...