R语言 Keras Training Flags
在需要经常进行调参的情况下,可以使用 Training Flags 来快速变换参数,比起直接修改模型参数来得快而且不易出错。
https://tensorflow.rstudio.com/tools/training_flags.html
使用 flags()
library(keras)
FLAGS <- flags(
flag_integer("dense_units1", 128),
flag_numeric("dropout1", 0.4),
flag_integer("dense_units2", 128),
flag_numeric("dropout2", 0.3),
flag_integer("epochs", 30),
flag_integer("batch_size", 128),
flag_numeric("learning_rate", 0.001)
)
input <- layer_input(shape = c(784))
predictions <- input %>%
layer_dense(units = FLAGS$dense_units1, activation = 'relu') %>%
layer_dropout(rate = FLAGS$dropout1) %>%
layer_dense(units = FLAGS$dense_units2, activation = 'relu') %>%
layer_dropout(rate = FLAGS$dropout2) %>%
layer_dense(units = 10, activation = 'softmax')
model <- keras_model(input, predictions) %>% compile(
loss = 'categorical_crossentropy',
optimizer = optimizer_rmsprop(lr = FLAGS$learning_rate),
metrics = c('accuracy')
)
history <- model %>% fit(
x_train, y_train,
batch_size = FLAGS$batch_size,
epochs = FLAGS$epochs,
verbose = 1,
validation_split = 0.2
)
flags()是 keras 库的函数,不是R语言本身的函数。
使用YAML文件
flags()可以搭配YAML文件使用。按照官方教程,以为是把参数定义在YAML文件里,然后使用flags(file="flags.yml")直接读入。但是发现这样行不通,flags(file="flags.yml")得到的是一个空list。后来发现可能得这样使用才是正确的:
FLAGS <- flags(file = "flags.yml",
flag_integer("dense_units1", 128, "Dense units in first layer"),
flag_numeric("dropout1", 0.4, "Dropout after first layer"),
flag_integer("epochs", 30, "Number of epochs to train for")
)
flags.yml 中的参数优先,会覆盖掉flags()里的定义,也就是说,如果 flags.yml 里面是这样定义的:
dense_units1: 256
dropout1: 0.4
epochs: 30
那么,dense_units1这个参数的值是 256,而不是 128。
下面这种用法不正确,
FLAGS <- flags(file = "flags.yml",
)
会得到一个空list。可以认为,flags.yml其实是用来覆盖或者说修改flags()里面已有的参数定义。
R语言 Keras Training Flags的更多相关文章
- 如何在R语言中使用Logistic回归模型
在日常学习或工作中经常会使用线性回归模型对某一事物进行预测,例如预测房价.身高.GDP.学生成绩等,发现这些被预测的变量都属于连续型变量.然而有些情况下,被预测变量可能是二元变量,即成功或失败.流失或 ...
- R语言 推荐算法 recommenderlab包
recommend li_volleyball 2016年3月20日 library(recommenderlab) library(ggplot2) # data(MovieLense) dim(M ...
- R语言机器学习之caret包运用
在大数据如火如荼的时候,机器学习无疑成为了炙手可热的工具,机器学习是计算机科学和统计学的交叉学科, 旨在通过收集和分析数据的基础上,建立一系列的算法,模型对实际问题进行预测或分类. R语言无疑为我们提 ...
- R语言进行机器学习方法及实例(一)
版权声明:本文为博主原创文章,转载请注明出处 机器学习的研究领域是发明计算机算法,把数据转变为智能行为.机器学习和数据挖掘的区别可能是机器学习侧重于执行一个已知的任务,而数据发掘是在大数据中寻找有 ...
- 重磅︱文本挖掘深度学习之word2vec的R语言实现
每每以为攀得众山小,可.每每又切实来到起点,大牛们,缓缓脚步来俺笔记葩分享一下吧,please~ --------------------------- 笔者寄语:2013年末,Google发布的 w ...
- R语言︱XGBoost极端梯度上升以及forecastxgb(预测)+xgboost(回归)双案例解读
XGBoost不仅仅可以用来做分类还可以做时间序列方面的预测,而且已经有人做的很好,可以见最后的案例. 应用一:XGBoost用来做预测 ------------------------------- ...
- R+openNLP︱openNLP的六大可实现功能及其在R语言中的应用
每每以为攀得众山小,可.每每又切实来到起点,大牛们,缓缓脚步来俺笔记葩分享一下吧,please~ --------------------------- openNLP是NLP中比较好的开源工具,R语 ...
- R语言︱H2o深度学习的一些R语言实践——H2o包
每每以为攀得众山小,可.每每又切实来到起点,大牛们,缓缓脚步来俺笔记葩分享一下吧,please~ --------------------------- R语言H2o包的几个应用案例 笔者寄语:受启发 ...
- 碎片︱R语言与深度学习
笔者:受alphago影响,想看看深度学习,但是其在R语言中的应用包可谓少之又少,更多的是在matlab和python中或者是调用.整理一下目前我看到的R语言的材料: ---------------- ...
随机推荐
- CentOS和Windows互相远程桌面方法
https://blog.csdn.net/libaineu2004/article/details/49407883
- java valueOf
valueOf 方法可以将原生数值类型转化为对应的Number类型,java.lang.Number 基类包括ouble.Float.Byte.Short.Integer 以及 Long派生类, 也可 ...
- 二十四、python中sys模块
'''1.sys.argv:命令行参数List,第一个元素是程序本身路径''' import sys print (sys.argv)-------------------------------[' ...
- LeetCode 10——正则表达式匹配
1. 题目 2. 解答 在 回溯算法 中我们介绍了一种递归的思路来求解这个问题. 此外,这个问题也可以用动态规划的思路来解决.我们定义状态 \(P[i][j]\) 为子串 \(s[0, i)\) 和 ...
- c# WPF——创建带有图标的TreeView
1.使用数据模板对TreeViewItem进行更改 2.xaml中重写TreeviewItem的控件模板 3.继承TreeViewItem(TreeView中的元素),后台进行控件重写.(介绍此方法) ...
- Learn Python the hard way, ex42 物以类聚
依然少打很多剧情,并修改了很多,还好,能运行 #!urs/bin/python #coding:utf-8 from sys import exit from random import randin ...
- 【ABAP系列】SAP 关于出口(user-exit)MV50AFZ1的一些问题
公众号:SAP Technical 本文作者:matinal 原文出处:http://www.cnblogs.com/SAPmatinal/ 原文链接:[ABAP系列]SAP 关于出口(user-ex ...
- OuterXml和InnerXml
例如 <bkk> <rp fe="few" > <fe>fff</fe> </rp> </bkk> 对于fe ...
- 前端 CSS的选择器 属性选择器
属性选择器,字面意思就是根据标签中的属性,选中当前的标签. 属性选择器 通常在表单控件中 使用比较多 根据属性查找 /*用于选取带有指定属性的元素.*/ <!DOCTYPE html> & ...
- [Python3 填坑] 016 对 __getattr__ 和 __setattr__ 举例
目录 1. print( 坑的信息 ) 2. 开始填坑 2.1 __getattr__ 2.2 __setattr__ 1. print( 坑的信息 ) 挖坑时间:2019/04/07 明细 坑的编码 ...