用 Java 训练出一只“不死鸟”

作者:Kingyu & Lanking
FlappyBird 是 2013 年推出的一款手机游戏,因其简单的玩法但极度困难的设定迅速走红全网。随着深度学习(DL)与增强学习(RL)等前沿算法的发展,我们可以使用 Java 非常方便地训练出一个智能体来控制 Flappy Bird。

故事开始于《GitHub 上的大佬们打完招呼,会聊些什么?》,那么,今天我们就来一起看一下如何用 Java 训练出一个不死鸟。游戏项目我们使用了一个仅用 Java 基本类库编写的 FlappyBird 游戏。在训练方面,我们使用 DeepJavaLibrary 一个基于 Java 的深度学习框架来构建增强学习训练网络并进行训练。经过了差不多 300 万步(四小时)的训练后,小鸟已经可以获得最高 8000 多分的成绩,灵活穿梭于水管之间。

在本文中,我们将从原理开始一步一步实现增强学习算法并用它对游戏进行训练。如果任何一个时刻不清楚如何继续进行下去,可以参阅项目的源码。
增强学习(RL)的架构
在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。
CNN 训练简述
CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2) 的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。
训练数据
在小鸟采取动作后,我们会得到 preObservation and currentObservation 即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal 组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。
public void step(NDList action, boolean training) {
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
stepFrame();
NDList preObservation = currentObservation;
currentObservation = createObservation(currentImg);
FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
preObservation, currentObservation, action, currentReward, currentTerminal);
if (training) {
replayBuffer.addStep(step);
}
if (gameState == GAME_OVER) {
restartGame();
}
}
训练的三个周期
训练分为 3 个不同的周期以更好地生成训练数据:
- Observe(观察) 周期:随机产生训练数据
- Explore (探索) 周期:随机与推理动作结合更新训练数据
- Training (训练) 周期:推理动作主导产生新数据
通过这种训练模式,我们可以更好的达到预期效果。
处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。
public NDList chooseAction(RlEnv env, boolean training) {
if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
return env.getActionSpace().randomAction();
} else return baseAgent.chooseAction(env, training);
}
训练逻辑
首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:
NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
.mul(actionInput.singletonOrThrow())
.sum(new int[]{1}));
postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:
// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
if (batchSteps[i].isTerminal()) {
targetQValue[i] = batchSteps[i].getReward();
} else {
targetQValue[i] = targetQReward.singletonOrThrow().get(i)
.max()
.mul(rewardDiscount)
.add(rewardInput.singletonOrThrow().get(i));
}
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。
卷积神经网络模型(CNN)
我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。
| layer | input shape | output shape |
|---|---|---|
| conv2d | (batchSize, 4, 80, 80) | (batchSize,4,20,20) |
| conv2d | (batchSize, 4, 20 ,20) | (batchSize, 32, 9, 9) |
| conv2d | (batchSize, 32, 9, 9) | (batchSize, 64, 7, 7) |
| linear | (batchSize, 3136) | (batchSize, 512) |
| linear | (batchSize, 512) | (batchSize, 2) |
训练过程
DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。
- 实现 RlAgent 接口即可构建一个可以进行训练的智能体。
- 在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。
- 创建 ReplayBuffer 可以存储并动态更新训练数据。
在实现这些接口后,只需要调用 step 方法:
RlEnv.step(action, training);
这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment 方法中调用 step 方法,然后只需要重复执行 runEnvironment 方法,即可不断地生成用于训练的数据。
public Step[] runEnvironment(RlAgent agent, boolean training) {
// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
if (training) {
batchSteps = this.getBatch();
}
return batchSteps;
}
我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。
在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。
神经网络输入预处理
首先将图像大小 resize 成 80x80 并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。
public static NDArray imgPreprocess(BufferedImage observation) {
return NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(observation)
.toNDArray(NDManager.newBaseManager(),
Image.Flag.GRAYSCALE) ,80,80));
}
然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。
public NDList createObservation(BufferedImage currentImg) {
NDArray observation = GameUtil.imgPreprocess(currentImg);
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}
一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。
List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}
总结
这个模型在 NVIDIA T4 GPU 训练了大概 4 个小时,更新了 300 万步。训练后的小鸟已经可以完全自主控制动作灵活穿梭与管道之间。训练后的模型也同样上传到了仓库中供您测试。在此项目中 DJL 提供了强大的训练 API 以及模型库支持,使得在 Java 开发过程中得心应手。

用 Java 训练出一只“不死鸟”的更多相关文章
- 137 Single Number II 数组中除了一个数外,其他的数都出现了三次,找出这个只出现一次的数
给定一个整型数组,除了一个元素只出现一次外,其余每个元素都出现了三次.求出那个只出现一次的数.注意:你的算法应该具有线性的时间复杂度.你能否不使用额外的内存来实现?详见:https://leetcod ...
- java.sql.SQLException: 对只转发结果集的无效操作: last
出错代码如下:static String u = "user";static String p = "psw";static String url = &quo ...
- 在Java中弹出位于其他类的由WindowsBuilder创建的JFrameApplicationWIndow
我们一般在使用Java弹出窗体的时候,一般是使用Jdialog这个所谓的"对话框类".但是,如果你不是初学者或研究员,而会在使用Java进行swing项目的开发,那么你很可能用到一 ...
- 给出2n+1个数,其中有2n个数出现过两次,如何用最简便的方法找出里面只出现了一次的那个数(转载)
有2n+1个数,其中有2n个数出现过两次,找出其中只出现一次的数 例如这样一组数3,3,1,2,4,2,5,5,4,其中只有1出现了1次,其他都是出现了2次,如何找出其中的1? 最简便的方法是使用异或 ...
- IE6多出一只猪的经典bug
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...
- java 弹出选择目录框(选择文件夹),获取选择的文件夹路径
java 弹出选择目录框(选择文件夹),获取选择的文件夹路径 java 弹出选择目录框(选择文件夹),获取选择的文件夹路径:int result = 0;File file = null;String ...
- 2017年11月GitHub上最热门的Java项目出炉
2017年11月GitHub上最热门的Java项目出炉~ 一起来看看这些项目你使用过哪些呢? 1分布式 RPC 服务框架 dubbohttps://github.com/alibaba/dubbo S ...
- Java释出的时候,AWT作为Java最弱的组件受到不小的批评
Java释出的时候,AWT作为Java最弱的组件受到不小的批评. 最根本的缺点是AWT在原生的用户界面之上仅提供了一个非常薄的抽象层. 例如,生成一个AWT的 复选框会导致AWT直接调用下层原生例程来 ...
- 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器
迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...
随机推荐
- elasticsearch 使用同义词
elasticsearch 使用同义词 使用环境 elasticsearch5.1.1 kibana5.1.1 同义词插件5.1.1 安装插件 下载对应的elasticsearch-analysis- ...
- Java反射说得透彻一些
目录 一.反射机制是什么? 二.反射的具体使用 2.1 获取对象的包名以及类名 2.2 获取Class对象 2.3 getInstance()获取指定类型的实例化对象 2.4 通过构造函数对象实例化对 ...
- transient关键字的作用以及几个疑问的解决
目录 1.从Serilizable说到transient 2.序列化属性对象的类需要实现Serilizable接口? 3.不想被序列化的字段怎么办? 4.ArrayList里面的elementData ...
- js中定时器调用函数时为什么会有引号
之前在学习的时候并没有发现的细节,关于js中,定时器的问题 这里我们写两个延时器 setTimeout(func, 0); setTimeout("func()", 0);定时器中 ...
- 使用wapiti进网站进行安全性测试
1.安装wapiti --在命令终端输入 pip install wapiti3 (因为这个结合python使用,所以安装的版本要跟python兼容,因为我的python是3.6版本,所以安装的是wa ...
- 20200221_python虚拟环境在Windows下安装配置_virtualenv不是内部或外部命令也不是可运行的程序或批处理文件
1. 使用管理员启动命令行; 2. 安装虚拟环境 a) .\pip install virtualenv -i https://pypi.douban.com/simple/ b) ...
- MongoDB去重
db.集合.aggregate([ { $group: { _id: {字段1: '$字段1',字段2: '$字段2'},count: {$sum: 1},dups: {$addToSet: '$_i ...
- Python+爬虫+xlwings发现CSDN个人博客热门文章
☞ ░ 前往老猿Python博文目录 ░ 一.引言 最近几天老猿博客的访问量出现了比较大的增长,从常规的1000-3000之间波动的范围一下子翻了将近一倍,粉丝增长从日均10-40人也增长了差不多一倍 ...
- 第10.4节 Python模块的弱封装机制
一. 引言 Python模块可以为调用者提供模块内成员的访问和调用,但某些情况下, 因为某些成员可能有特殊访问规则等原因,并不适合将模块内所有成员都提供给调用者访问,此时模块可以类似类的封装机制类似的 ...
- C++编程指南(6-7)
六.函数设计 函数是C++/C程序的基本功能单元,其重要性不言而喻.函数设计的细微缺点很容易导致该函数被错用,所以光使函数的功能正确是不够的.本章重点论述函数的接口设计和内部实现的一些规则. 函数接口 ...