作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手机游戏,因其简单的玩法但极度困难的设定迅速走红全网。随着深度学习(DL)与增强学习(RL)等前沿算法的发展,我们可以使用 Java 非常方便地训练出一个智能体来控制 Flappy Bird。

故事开始于《GitHub 上的大佬们打完招呼,会聊些什么?》,那么,今天我们就来一起看一下如何用 Java 训练出一个不死鸟。游戏项目我们使用了一个仅用 Java 基本类库编写的 FlappyBird 游戏。在训练方面,我们使用 DeepJavaLibrary 一个基于 Java 的深度学习框架来构建增强学习训练网络并进行训练。经过了差不多 300 万步(四小时)的训练后,小鸟已经可以获得最高 8000 多分的成绩,灵活穿梭于水管之间。

在本文中,我们将从原理开始一步一步实现增强学习算法并用它对游戏进行训练。如果任何一个时刻不清楚如何继续进行下去,可以参阅项目的源码。

项目地址:https://github.com/kingyuluk/RL-FlappyBird

增强学习(RL)的架构

在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。

CNN 训练简述

CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2) 的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。

训练数据

在小鸟采取动作后,我们会得到 preObservation and currentObservation 即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal 组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。

public void step(NDList action, boolean training) {
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
stepFrame();
NDList preObservation = currentObservation;
currentObservation = createObservation(currentImg);
FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
preObservation, currentObservation, action, currentReward, currentTerminal);
if (training) {
replayBuffer.addStep(step);
}
if (gameState == GAME_OVER) {
restartGame();
}
}

训练的三个周期

训练分为 3 个不同的周期以更好地生成训练数据:

  • Observe(观察) 周期:随机产生训练数据
  • Explore (探索) 周期:随机与推理动作结合更新训练数据
  • Training (训练) 周期:推理动作主导产生新数据

通过这种训练模式,我们可以更好的达到预期效果。

处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。

public NDList chooseAction(RlEnv env, boolean training) {
if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
return env.getActionSpace().randomAction();
} else return baseAgent.chooseAction(env, training);
}

训练逻辑

首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
.mul(actionInput.singletonOrThrow())
.sum(new int[]{1}));

postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:

// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
if (batchSteps[i].isTerminal()) {
targetQValue[i] = batchSteps[i].getReward();
} else {
targetQValue[i] = targetQReward.singletonOrThrow().get(i)
.max()
.mul(rewardDiscount)
.add(rewardInput.singletonOrThrow().get(i));
}
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。

卷积神经网络模型(CNN)

我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

训练过程

DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。

  • 实现 RlAgent 接口即可构建一个可以进行训练的智能体。
  • 在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。
  • 创建 ReplayBuffer 可以存储并动态更新训练数据。

在实现这些接口后,只需要调用 step 方法:

RlEnv.step(action, training);

这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment 方法中调用 step 方法,然后只需要重复执行 runEnvironment 方法,即可不断地生成用于训练的数据。

public Step[] runEnvironment(RlAgent agent, boolean training) {
// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
if (training) {
batchSteps = this.getBatch();
}
return batchSteps;
}

我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。

在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。

神经网络输入预处理

首先将图像大小 resize 成 80x80 并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。

public static NDArray imgPreprocess(BufferedImage observation) {
return NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(observation)
.toNDArray(NDManager.newBaseManager(),
Image.Flag.GRAYSCALE) ,80,80));
}

然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
NDArray observation = GameUtil.imgPreprocess(currentImg);
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}

一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}

总结

这个模型在 NVIDIA T4 GPU 训练了大概 4 个小时,更新了 300 万步。训练后的小鸟已经可以完全自主控制动作灵活穿梭与管道之间。训练后的模型也同样上传到了仓库中供您测试。在此项目中 DJL 提供了强大的训练 API 以及模型库支持,使得在 Java 开发过程中得心应手。

本项目完整代码:https://github.com/kingyuluk/RL-FlappyBird

用 Java 训练出一只“不死鸟”的更多相关文章

  1. 137 Single Number II 数组中除了一个数外,其他的数都出现了三次,找出这个只出现一次的数

    给定一个整型数组,除了一个元素只出现一次外,其余每个元素都出现了三次.求出那个只出现一次的数.注意:你的算法应该具有线性的时间复杂度.你能否不使用额外的内存来实现?详见:https://leetcod ...

  2. java.sql.SQLException: 对只转发结果集的无效操作: last

    出错代码如下:static String u = "user";static String p = "psw";static String url = &quo ...

  3. 在Java中弹出位于其他类的由WindowsBuilder创建的JFrameApplicationWIndow

    我们一般在使用Java弹出窗体的时候,一般是使用Jdialog这个所谓的"对话框类".但是,如果你不是初学者或研究员,而会在使用Java进行swing项目的开发,那么你很可能用到一 ...

  4. 给出2n+1个数,其中有2n个数出现过两次,如何用最简便的方法找出里面只出现了一次的那个数(转载)

    有2n+1个数,其中有2n个数出现过两次,找出其中只出现一次的数 例如这样一组数3,3,1,2,4,2,5,5,4,其中只有1出现了1次,其他都是出现了2次,如何找出其中的1? 最简便的方法是使用异或 ...

  5. IE6多出一只猪的经典bug

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...

  6. java 弹出选择目录框(选择文件夹),获取选择的文件夹路径

    java 弹出选择目录框(选择文件夹),获取选择的文件夹路径 java 弹出选择目录框(选择文件夹),获取选择的文件夹路径:int result = 0;File file = null;String ...

  7. 2017年11月GitHub上最热门的Java项目出炉

    2017年11月GitHub上最热门的Java项目出炉~ 一起来看看这些项目你使用过哪些呢? 1分布式 RPC 服务框架 dubbohttps://github.com/alibaba/dubbo S ...

  8. Java释出的时候,AWT作为Java最弱的组件受到不小的批评

    Java释出的时候,AWT作为Java最弱的组件受到不小的批评. 最根本的缺点是AWT在原生的用户界面之上仅提供了一个非常薄的抽象层. 例如,生成一个AWT的 复选框会导致AWT直接调用下层原生例程来 ...

  9. 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器

    迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...

随机推荐

  1. jQuery 小demo 热点排名

    效果如下: 代码如下: 1 <!DOCTYPE html> 2 <html lang="en"> 3 <head> 4 <meta cha ...

  2. C语言讲义——变量(variable)

    变量(variable) 变量用于存放数据 变量是供程序操作的存储区的名字 变量有类型,该类型决定了变量占用内存的大小 字节→ C语言有以下6种简单变量类型: 类型细分: 变量在内存中需要占据空间,内 ...

  3. MapReduce怎么优雅地实现全局排序

    思考 想到全局排序,是否第一想到的是,从map端收集数据,shuffle到reduce来,设置一个reduce,再对reduce中的数据排序,显然这样和单机器并没有什么区别,要知道mapreduce框 ...

  4. 在Python中使用moviepy进行视频剪辑时输出文件报错 ‘NoneType‘ object has no attribute ‘stdout‘问题

    专栏:Python基础教程目录 专栏:使用PyQt开发图形界面Python应用 专栏:PyQt入门学习 老猿Python博文目录 老猿学5G博文目录 movipy输出文件时报错 'NoneType' ...

  5. 第十章、Qt Designer中的Spacers部件

    老猿Python博文目录 专栏:使用PyQt开发图形界面Python应用 老猿Python博客地址 一. 引言 在Designer的部件栏中,有两种类型的Spacers部件,下图中上面布局中为一个水平 ...

  6. java课后作业2019.11.04

    一.编写一个程序,指定一个文件夹,能够自动计算出其总容量 1.代码 package HomeWork; import java.io.File; public class getFileDaxiao ...

  7. Asp.NetCore之AutoMapper进阶篇

    应用场景 在上一篇文章--Asp.NetCore之AutoMapper基础篇中我们简单介绍了一些AutoMapper的基础用法以及如何在.NetCore中实现快速开发.我相信用过AutoMapper实 ...

  8. PHP代码审计分段讲解(11)

    后面的题目相对于之前的题目难度稍微提升了一些,所以对每道题进行单独的分析 27题 <?php if(!$_GET['id']) { header('Location: index.php?id= ...

  9. [GXYCTF2019] MISC杂项题

    buuoj复现 1,佛系青年 下载了之后是一个加密的txt文件和一张图片 分析图片无果,很讨厌这种脑洞题,MISC应该给一点正常的线索加部分脑洞而不是出干扰信息来故意让选手走错方向,当时比赛做这道题的 ...

  10. Docker 安装-在centos7下安装Docker(二)

    参考docker安装的方式: http://www.runoob.com/docker/centos-docker-install.html Docker中文官网安装步骤:https://docs.d ...