import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) #
# add layer
#
def add_layer(inputs, in_size, out_size, n_layer, activation_function = None):
layer_name = 'layer%s' % n_layer Weights = tf.Variable(tf.random_normal([in_size, out_size]), name='W') # hang lie
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = 'b') Wx_plus_b = tf.matmul(inputs, Weights) + biases
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob) # if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b) tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs #
# define placeholder for inputs to network
#
keep_prob = tf.placeholder(tf.float32) #
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #
# add output layer
#
l1 = add_layer(xs, 64, 50, 'l1', activation_function = tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function = tf.nn.softmax) #
# the error between prediction and real data
#
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) #loss
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all() #summary writer goes here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
#sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:1.0}) # overfitted
sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:0.5}) # keep 0.5, drop 0.5
if i% 50 == 0:
#record loss
train_result = sess.run(merged, feed_dict={xs:X_train, ys:y_train, keep_prob:1})
test_result = sess.run(merged, feed_dict={xs:X_test, ys:y_test, keep_prob:1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

  

莫烦TensorFlow_10 过拟合的更多相关文章

  1. tensorflow 莫烦教程

    1,感谢莫烦 2,第一个实例:用tf拟合线性函数 import tensorflow as tf import numpy as np # create data x_data = np.random ...

  2. tensorflow学习笔记-bili莫烦

    bilibili莫烦tensorflow视频教程学习笔记 1.初次使用Tensorflow实现一元线性回归 # 屏蔽警告 import os os.environ[' import numpy as ...

  3. 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络

    滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...

  4. 稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记。

    稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记. 还有 google 在 udacity 上的 CNN 教程. CNN(Convolutional Neural Networks) 卷积神经网络简单 ...

  5. 莫烦大大TensorFlow学习笔记(9)----可视化

      一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...

  6. scikit-learn学习笔记-bili莫烦

    bilibili莫烦scikit-learn视频学习笔记 1.使用KNN对iris数据分类 from sklearn import datasets from sklearn.model_select ...

  7. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  8. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  9. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

随机推荐

  1. layui中form表单渲染的问题

    layui 官网的这部分文档介绍:http://www.layui.com/doc/modules/form.html#render 注意:针对的是表单元素,input select  textare ...

  2. webapi使用ExceptionFilterAttribute过滤器

    文章 public class ApiExceptionFilterAttribute:ExceptionFilterAttribute { public override void OnExcept ...

  3. 树莓派4b+linux

    用Win32DiskImager烧录系统 先在boot根目录下新建ssh空文件夹来开启ssh功能,否则ssh是关闭的,用putty一直连不上,显示拒绝连接 1.联网: 初次 (实践证明:直接在sd卡根 ...

  4. BoW算法及DBoW2库简介

    由于在ORB-SLAM2中扩展图像识别模块,因此总结一下BoW算法,并对DBoW2库做简单介绍. 1. BoW算法 BoW算法即Bag of Words模型,是图像检索领域最常用的方法,也是基于内容的 ...

  5. A1093 Count PAT's (25 分)

    一.技术总结 这是一个逻辑题,题目大职意思是可以组成多少个PAT,可以以A为中心计算两边的P和T,然后数量乘积最后相加便是答案. 还有一个注意的是每次相加后记得mod,取余,不要等到最后加完再取余,会 ...

  6. iOS性能优化-数组、字典便利时间复杂

    上图是几种时间复杂度的关系,性能优化一定程度上是为了降低程序执行效率减低时间复杂度. 如下是几种时间复杂度的实例: O(1) return array[index] == value; 复制代码 O( ...

  7. Django学习笔记(19)——BBS+Blog项目开发(3)细节知识点补充

    本文将BBS+Blog项目开发中所需要的细节知识点进行补充,其中内容包括KindEditor编辑器的使用,BeautifulSoup 模块及其防XSS攻击,Django中admin管理工具的使用,me ...

  8. 如何取消 SqlDataAdapter.Fill() 的执行(转载)

    问 Scenario: We have a DataGridView which is attached to DataAdapter (datatable), we load the data in ...

  9. 在.NET Core 3.0 Preview上使用Windows窗体设计器

    支持使用基于Windows窗体应用程序的.NET Core 3.0(预览)的Windows窗体设计器 介绍 截至撰写本文时,Microsoft和社区目前正在测试.NET Core 3.0.如果您在.N ...

  10. Jenkins配置LDAP认证

    managerdn即为连接到AD的账号