import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import LabelBinarizer #load data
digits = load_digits()
X = digits.data
y = digits.target
y = LabelBinarizer().fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3) #
# add layer
#
def add_layer(inputs, in_size, out_size, n_layer, activation_function = None):
layer_name = 'layer%s' % n_layer Weights = tf.Variable(tf.random_normal([in_size, out_size]), name='W') # hang lie
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name = 'b') Wx_plus_b = tf.matmul(inputs, Weights) + biases
Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob) # if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b) tf.summary.histogram(layer_name + '/outputs', outputs)
return outputs #
# define placeholder for inputs to network
#
keep_prob = tf.placeholder(tf.float32) #
xs = tf.placeholder(tf.float32, [None, 64]) # 8x8
ys = tf.placeholder(tf.float32, [None, 10]) #
# add output layer
#
l1 = add_layer(xs, 64, 50, 'l1', activation_function = tf.nn.tanh)
prediction = add_layer(l1, 50, 10, 'l2', activation_function = tf.nn.softmax) #
# the error between prediction and real data
#
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) #loss
tf.summary.scalar('loss', cross_entropy)
train_step = tf.train.GradientDescentOptimizer(0.6).minimize(cross_entropy) sess = tf.Session()
merged = tf.summary.merge_all() #summary writer goes here
train_writer = tf.summary.FileWriter("logs/train", sess.graph)
test_writer = tf.summary.FileWriter("logs/test", sess.graph) sess.run(tf.global_variables_initializer()) for i in range(500):
#sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:1.0}) # overfitted
sess.run(train_step, feed_dict={xs:X_train, ys:y_train, keep_prob:0.5}) # keep 0.5, drop 0.5
if i% 50 == 0:
#record loss
train_result = sess.run(merged, feed_dict={xs:X_train, ys:y_train, keep_prob:1})
test_result = sess.run(merged, feed_dict={xs:X_test, ys:y_test, keep_prob:1})
train_writer.add_summary(train_result, i)
test_writer.add_summary(test_result, i)

  

莫烦TensorFlow_10 过拟合的更多相关文章

  1. tensorflow 莫烦教程

    1,感谢莫烦 2,第一个实例:用tf拟合线性函数 import tensorflow as tf import numpy as np # create data x_data = np.random ...

  2. tensorflow学习笔记-bili莫烦

    bilibili莫烦tensorflow视频教程学习笔记 1.初次使用Tensorflow实现一元线性回归 # 屏蔽警告 import os os.environ[' import numpy as ...

  3. 【莫烦Pytorch】【P1】人工神经网络VS. 生物神经网络

    滴:转载引用请注明哦[握爪] https://www.cnblogs.com/zyrb/p/9700343.html 莫烦教程是一个免费的机器学习(不限于)的学习教程,幽默风俗的语言让我们这些刚刚起步 ...

  4. 稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记。

    稍稍乱入的CNN,本文依然是学习周莫烦视频的笔记. 还有 google 在 udacity 上的 CNN 教程. CNN(Convolutional Neural Networks) 卷积神经网络简单 ...

  5. 莫烦大大TensorFlow学习笔记(9)----可视化

      一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...

  6. scikit-learn学习笔记-bili莫烦

    bilibili莫烦scikit-learn视频学习笔记 1.使用KNN对iris数据分类 from sklearn import datasets from sklearn.model_select ...

  7. 莫烦pytorch学习笔记(八)——卷积神经网络(手写数字识别实现)

    莫烦视频网址 这个代码实现了预测和可视化 import os # third-party library import torch import torch.nn as nn import torch ...

  8. 莫烦pytorch学习笔记(七)——Optimizer优化器

    各种优化器的比较 莫烦的对各种优化通俗理解的视频 import torch import torch.utils.data as Data import torch.nn.functional as ...

  9. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

随机推荐

  1. Python调用C的DLL(动态链接库)

    开发环境:mingw64位,python3.6 64位 参考博客: mingw编译dll: https://blog.csdn.net/liyuanbhu/article/details/426123 ...

  2. centos 安装python3.7

    先安装依赖包: yum -y install bzip2 bzip2-devel ncurses openssl openssl-devel openssl-static xz lzma xz-dev ...

  3. ImportError: cannot import name 'render_to_response' 解决方法

    前几天 Django 官方推出了 3.0 框架,项目在 K8S 内部署启动的时候,报了这个错:ImportError: cannot import name 'render_to_response' ...

  4. C++:class

    class 类是C++的一个重要概念,也是面向对象的一个重要内容.类的行为类似结构体,但功能比结构体的更强大.类是定义该类对象的一个模板,它告诉我们,一个类应该具有什么内容. 声明.定义 类用关键字c ...

  5. HashMap的底层原理(jdk1.7.0_79)

    前言 在Java中我们最常用的集合类毫无疑问就是Map,其中HashMap作为Map最重要的实现类在我们代码中出现的评率也是很高的. 我们对HashMap最常用的操作就是put和get了,那么你知道它 ...

  6. Knative 实践:从源代码到服务的自动化部署

    通过之前的文章,相信大家已经熟悉了 Serving.Eventing 以及 Tekton.那么在实际使用中,我们往往会遇到一些复杂的场景,这时候就需要各个组件之间进行协作处理.例如我们提交源代码之后是 ...

  7. Kubernetes Job与CronJob(离线业务)

    Kubernetes Job与CronJob(离线业务) Job Job分为普通任务(Job)  一次性执行 应用场景:离线数据处理,视频解码等业务 官方文档:https://kubernetes.i ...

  8. kmv 学习笔记 工具

    qemu:kmv的文本管理工具,包括qemu-kvm.qemu-img libvirt:是一套免费.开源的支持Linux下主流虚拟化工具的C函数库,libvirtd是运行的守护进程的名称.包括GUI: ...

  9. WPF xmlns:i="http://schemas.microsoft.com/expression/2010/interactivity"

    <Button Grid.Row="1" Content="Load Data" BorderBrush="Black" Border ...

  10. Entity Framework 导航属性(2)

    1.学校 [Table("School")] public partial class School { public School() { Students = new List ...