optim.py cs231n
n如果有错误,欢迎指出,不胜感激
import numpy as np """
This file implements various first-order update rules that are commonly used for
training neural networks. Each update rule accepts current weights and the
gradient of the loss with respect to those weights and produces the next set of
weights. Each update rule has the same interface: def update(w, dw, config=None): Inputs:
- w: A numpy array giving the current weights.
- dw: A numpy array of the same shape as w giving the gradient of the
loss with respect to w.
- config: A dictionary containing hyperparameter values such as learning rate,
momentum, etc. If the update rule requires caching values over many
iterations, then config will also hold these cached values. Returns:
- next_w: The next point after the update.
- config: The config dictionary to be passed to the next iteration of the
update rule. NOTE: For most update rules, the default learning rate will probably not perform
well; however the default values of the other hyperparameters should work well
for a variety of different problems. For efficiency, update rules may perform in-place updates, mutating w and
setting next_w equal to w.
""" def sgd(w, dw, config=None):
"""
Performs vanilla stochastic gradient descent. config format:
- learning_rate: Scalar learning rate.
"""
if config is None: config = {}
config.setdefault('learning_rate', 1e-2)
w -= config['learning_rate'] * dw
return w, config def sgd_momentum(w, dw, config=None):
"""
Performs stochastic gradient descent with momentum. config format:
- learning_rate: Scalar learning rate.
- momentum: Scalar between 0 and 1 giving the momentum value.
Setting momentum = 0 reduces to sgd.
- velocity: A numpy array of the same shape as w and dw used to store a moving
average of the gradients.
"""
if config is None: config = {}
config.setdefault('learning_rate', 1e-2)
config.setdefault('momentum', 0.9)
v = config.get('velocity', np.zeros_like(w)) next_w = None
v=v*config['momentum']-config['learning_rate']*dw
next_w=w+v
config['velocity'] = v return next_w, config def rmsprop(x, dx, config=None):
"""
Uses the RMSProp update rule, which uses a moving average of squared gradient
values to set adaptive per-parameter learning rates. config format:
- learning_rate: Scalar learning rate.
- decay_rate: Scalar between 0 and 1 giving the decay rate for the squared
gradient cache.
- epsilon: Small scalar used for smoothing to avoid dividing by zero.
- cache: Moving average of second moments of gradients.
"""
if config is None: config = {}
config.setdefault('learning_rate', 1e-2)
config.setdefault('decay_rate', 0.99)
config.setdefault('epsilon', 1e-8)
config.setdefault('cache', np.zeros_like(x)) next_x = None cache=config['cache']*config['decay_rate']+(1-config['decay_rate'])*dx**2
next_x=x-config['learning_rate']*dx/np.sqrt(cache+config['epsilon'])
config['cache']=cache return next_x, config def adam(x, dx, config=None):
"""
Uses the Adam update rule, which incorporates moving averages of both the
gradient and its square and a bias correction term. config format:
- learning_rate: Scalar learning rate.
- beta1: Decay rate for moving average of first moment of gradient.
- beta2: Decay rate for moving average of second moment of gradient.
- epsilon: Small scalar used for smoothing to avoid dividing by zero.
- m: Moving average of gradient.
- v: Moving average of squared gradient.
- t: Iteration number.
"""
if config is None: config = {}
config.setdefault('learning_rate', 1e-3)
config.setdefault('beta1', 0.9)
config.setdefault('beta2', 0.999)
config.setdefault('epsilon', 1e-8)
config.setdefault('m', np.zeros_like(x))
config.setdefault('v', np.zeros_like(x))
config.setdefault('t', 0)
config['t']+=1
这个方法比较综合,各种方法的好处吧
m=config['beta1']*config['m']+(1-config['beta1'])*dx # now to change by acc
v=config['beta2']*config['v']+(1-config['beta2'])*dx**2
config['m']=m
config['v']=v
m=m/(1-config['beta1']**config['t'])
v=v/(1-config['beta2']**config['t']) next_x=x-config['learning_rate']*m/np.sqrt(v+config['epsilon']) return next_x, config
n
optim.py cs231n的更多相关文章
- cnn.py cs231n
n import numpy as np from cs231n.layers import * from cs231n.fast_layers import * from cs231n.layer_ ...
- fc_net.py cs231n
n如果有错误,欢迎指出,不胜感激 import numpy as np from cs231n.layers import * from cs231n.layer_utils import * cla ...
- layers.py cs231n
如果有错误,欢迎指出,不胜感激. import numpy as np def affine_forward(x, w, b): 第一个最简单的 affine_forward简单的前向传递,返回 ou ...
- 笔记:CS231n+assignment2(作业二)(一)
第二个作业难度很高,但做(抄)完之后收获还是很大的.... 一.Fully-Connected Neural Nets 首先是对之前的神经网络的程序进行重构,目的是可以构建任意大小的全连接的neura ...
- 深度学习原理与框架-神经网络-cifar10分类(代码) 1.np.concatenate(进行数据串接) 2.np.hstack(将数据横着排列) 3.hasattr(判断.py文件的函数是否存在) 4.reshape(维度重构) 5.tanspose(维度位置变化) 6.pickle.load(f文件读入) 7.np.argmax(获得最大值索引) 8.np.maximum(阈值比较)
横1. np.concatenate(list, axis=0) 将数据进行串接,这里主要是可以将列表进行x轴获得y轴的串接 参数说明:list表示需要串接的列表,axis=0,表示从上到下进行串接 ...
- optim.py-使用tensorflow实现一般优化算法
optim.py Project URL:https://github.com/Codsir/optim.git Based on: tensorflow, numpy, copy, inspect ...
- 深度学习之卷积神经网络(CNN)详解与代码实现(一)
卷积神经网络(CNN)详解与代码实现 本文系作者原创,转载请注明出处:https://www.cnblogs.com/further-further-further/p/10430073.html 目 ...
- 深度学习原理与框架-卷积神经网络-cifar10分类(图片分类代码) 1.数据读入 2.模型构建 3.模型参数训练
卷积神经网络:下面要说的这个网络,由下面三层所组成 卷积网络:卷积层 + 激活层relu+ 池化层max_pool组成 神经网络:线性变化 + 激活层relu 神经网络: 线性变化(获得得分值) 代码 ...
- Pytorch1.3源码解析-第一篇
pytorch$ tree -L 1 . ├── android ├── aten ├── benchmarks ├── binaries ├── c10 ├── caffe2 ├── CITATIO ...
随机推荐
- CodeForces 232C Doe Graphs(分治+搜索)
CF232C Doe Graphs 题意 题意翻译 \(Doe\)以她自己的名字来命名下面的无向图 \(D(0)\)是只有一个编号为\(1\)的结点的图. \(D(1)\)是只有两个编号分别为\(1\ ...
- Spring_关于@Resource注入为null解决办法
初学spring,我在dao层初始化c3p0的时候,使用@Resource注解新建对象是发现注入为null,告诉我 java.lang.NullPointerException. @Repositor ...
- day49作业
结合前端,django,MySQL,pymysql模块实现数据库数据动态展示到前端 效果图: 数据交互流程 urls.py代码: from django.conf.urls import url fr ...
- pyd打包补充
网上说的将python代码,通过Cython打包成pyd的教程挺多,好处也多,主要有两个: 1.隐藏代码 2.加速运行速度 补充两点: 1.打包脚本配置 __build__.py from distu ...
- mysql innodb 的 逻辑存储结构
如上图: innodb 的 逻辑存储单元分成 表空间,段,区,页 4个等级 默认情况下,一个数据库 所有变共享一个 默认的表空间(tablespan).可以指定每个表一个表空间. 一个表空间管理着 多 ...
- java swing+socket实现多人聊天程序
swing+socket实现多人聊天程序 1.准备工作 先看效果: 客户端项目结构图: 服务端项目结构图: 2.运行原理 服务端 先开一个线程serverListerner,线程中开启一个Server ...
- jquery的each()遍历和ajax传值
页面展示 JS代码部分 /*功能:删除选中用户信息数据*/ function delUser(){ $("#delU").click(function(){ var unoStr ...
- opencv4 java投影
工程下载 https://download.csdn.net/download/qq_16596909/11505994 比较适合与验证码的处理,毕竟八邻域降噪不能消除比较大的噪点,为了尽量减少噪点对 ...
- opencv读取的彩色图像,数据是GBR而不是RGB
开发久了,容易想当然 直到数据怎么也不对的时候,才想起来查一下手册 三个像素,当然没有这么大的像素,这是放大之后的 数据输出
- day67test
作业 1.按照上方 知识点总结 模块,总结今天所学知识点: 2.有以下广告数据(实际数据命名可以略做调整) ad_data = { tv: [ {img: 'img/tv/001.png', titl ...