多层感知机中:

hi 以 p 的概率被丢弃,以 1-p 的概率被拉伸,除以  1 - p

import mxnet as mx
import sys
import os
import time
import gluonbook as gb
from mxnet import autograd,init
from mxnet import nd,gluon
from mxnet.gluon import data as gdata,nn
from mxnet.gluon import loss as gloss '''
# 模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784,10,256,256 W1 = nd.random.normal(scale=0.01,shape=(num_inputs,num_hiddens1))
b1 = nd.zeros(num_hiddens1) W2 = nd.random.normal(scale=0.01,shape=(num_hiddens1,num_hiddens2))
b2 = nd.zeros(num_hiddens2) W3 = nd.random.normal(scale=0.01,shape=(num_hiddens2,num_outputs))
b3 = nd.zeros(num_outputs) params = [W1,b1,W2,b2,W3,b3] for param in params:
param.attach_grad() # 定义网络 '''
# 读取数据
# fashionMNIST 28*28 转为224*224
def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(
'~', '.mxnet', 'datasets', 'fashion-mnist')):
root = os.path.expanduser(root) # 展开用户路径 '~'。
transformer = []
if resize:
transformer += [gdata.vision.transforms.Resize(resize)]
transformer += [gdata.vision.transforms.ToTensor()]
transformer = gdata.vision.transforms.Compose(transformer)
mnist_train = gdata.vision.FashionMNIST(root=root, train=True)
mnist_test = gdata.vision.FashionMNIST(root=root, train=False)
num_workers = 0 if sys.platform.startswith('win32') else 4
train_iter = gdata.DataLoader(
mnist_train.transform_first(transformer), batch_size, shuffle=True,
num_workers=num_workers)
test_iter = gdata.DataLoader(
mnist_test.transform_first(transformer), batch_size, shuffle=False,
num_workers=num_workers)
return train_iter, test_iter # 定义网络
drop_prob1,drop_prob2 = 0.2,0.5
# Gluon版
net = nn.Sequential()
net.add(nn.Dense(256,activation="relu"),
nn.Dropout(drop_prob1),
nn.Dense(256,activation="relu"),
nn.Dropout(drop_prob2),
nn.Dense(10)
)
net.initialize(init.Normal(sigma=0.01)) # 训练模型 def accuracy(y_hat, y):
return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()
def evaluate_accuracy(data_iter, net):
acc = 0
for X, y in data_iter:
acc += accuracy(net(X), y)
return acc / len(data_iter) def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
params=None, lr=None, trainer=None):
for epoch in range(num_epochs):
train_l_sum = 0
train_acc_sum = 0
for X, y in train_iter:
with autograd.record():
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
if trainer is None:
gb.sgd(params, lr, batch_size)
else:
trainer.step(batch_size) # 下一节将用到。
train_l_sum += l.mean().asscalar()
train_acc_sum += accuracy(y_hat, y)
test_acc = evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
% (epoch + 1, train_l_sum / len(train_iter),
train_acc_sum / len(train_iter), test_acc)) num_epochs = 5
lr = 0.5
batch_size = 256
loss = gloss.SoftmaxCrossEntropyLoss()
train_iter, test_iter = load_data_fashion_mnist(batch_size) trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':lr})
train(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,trainer)

Gluon 实现 dropout 丢弃法的更多相关文章

  1. 【神经网络】丢弃法(dropout)

    丢弃法是一种降低过拟合的方法,具体过程是在神经网络传播的过程中,随机"沉默"一些节点.这个行为让模型过度贴合训练集的难度更高. 添加丢弃层后,训练速度明显上升,在同样的轮数下测试集 ...

  2. MXNET:丢弃法

    除了前面介绍的权重衰减以外,深度学习模型常常使用丢弃法(dropout)来应对过拟合问题. 方法与原理 为了确保测试模型的确定性,丢弃法的使用只发生在训练模型时,并非测试模型时.当神经网络中的某一层使 ...

  3. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  4. dropout——gluon

    https://blog.csdn.net/lizzy05/article/details/80162060 from mxnet import nd def dropout(X, drop_prob ...

  5. 动手学深度学习14- pytorch Dropout 实现与原理

    方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...

  6. 神经网络优化算法:Dropout、梯度消失/爆炸、Adam优化算法,一篇就够了!

    1. 训练误差和泛化误差 机器学习模型在训练数据集和测试数据集上的表现.如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训练数据集上更准确时,它在测试数据集上却不⼀定更准确.这是为什么呢 ...

  7. 从头学pytorch(七):dropout防止过拟合

    上一篇讲了防止过拟合的一种方式,权重衰减,也即在loss上加上一部分\(\frac{\lambda}{2n} \|\boldsymbol{w}\|^2\),从而使得w不至于过大,即不过分偏向某个特征. ...

  8. dropout总结

    1.伯努利分布:伯努利分布亦称“零一分布”.“两点分布”.称随机变量X有伯努利分布, 参数为p(0<p<1),如果它分别以概率p和1-p取1和0为值.EX= p,DX=p(1-p). 2. ...

  9. mxnet(gluon)—— 模型、数据集、损失函数、优化子等类、接口大全

    1. 数据集 dataset_train = gluon.data.ArrayDataset(X_train, y_train) data_iter = gluon.data.DataLoader(d ...

随机推荐

  1. ASP.NET能知道的东西

    ASP.NET能知道的东西 获取服务器电脑名: Page.Server.ManchineName 获取用户信息: Page.User 获取客户端电脑名:Page.Request.UserHostNam ...

  2. DIY了一下自己blog的UI

    当年才学前端时就想改自己blog的UI,然鹅当时没看见那个“申请JS权限”,一直以为blog不能随意DIY样式,只改了少许CSS.现在重新看看设置管理选项,简单修改了一下样式(注意:修改样式之前发邮件 ...

  3. C#利用WinForm调用WebServices实现增删改查

    实习导师要求做一个项目,用Winform调用WebServices实现增删改查的功能.写下这篇博客,当做是这个项目的总结.如果您有什么建议,可以给我留言.欢迎指正. 1.首先,我接到这个项目的时候,根 ...

  4. CentOS7 一键安装KMS服务【整理】

    KMS,是 Key Management System 的缩写,也就是密钥管理系统.这里所说的 KMS,毋庸置疑就是用来激活 VOL 版本的 Windows 和 Office 的 KMS 啦.经常能在 ...

  5. log4j的简单使用

    引入jar包org.apache.log4j.Logger,项目src目录下建立一个log4j.properties配置文件 log4j.rootLogger=INFO,A1,R log4j.appe ...

  6. Iphone各个型号机型的详细参数,尺寸和dpr以及像素

    1.iPhone尺寸规格 2.单位inch(英吋) 1 inch = 2.54cm = 25.4mm 3.iPhone手机宽高 上表中的宽高(width/height)为手机的物理尺寸,包括显示屏和边 ...

  7. linux系统下php扩展的安装

    0. 这里以php安装redis扩展为例 1. 首先下载并解压redis扩展包 [root@xxx ~]# cd /usr/local/src [root@xxx src]# wget https:/ ...

  8. cf1043E. Mysterious Crime(二分 前缀和)

    题意 题目链接 Sol 考场上做完前四题的时候大概还剩半个小时吧,那时候已经困的不行了. 看了看E发现好像很可做?? 又仔细看了几眼发现这不是sb题么... 先考虑两个人,假设贡献分别为\((x, y ...

  9. Execution default-resources of goal org.apache.maven.plugins:maven-resources-plugin:2.6:resources failed: Unable to load the mojo 'resources' (or one of its required components)

    1.异常提示: Description Resource Path Location Type Execution default-resources of goal org.apache.maven ...

  10. oracle数据库基本命令

    数据库字符集: SQL> select * from nls_database_parameters where parameter='NLS_CHARACTERSET'; PARAMETER ...