作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手机游戏,因其简单的玩法但极度困难的设定迅速走红全网。随着深度学习(DL)与增强学习(RL)等前沿算法的发展,我们可以使用 Java 非常方便地训练出一个智能体来控制 Flappy Bird。

故事开始于《GitHub 上的大佬们打完招呼,会聊些什么?》,那么,今天我们就来一起看一下如何用 Java 训练出一个不死鸟。游戏项目我们使用了一个仅用 Java 基本类库编写的 FlappyBird 游戏。在训练方面,我们使用 DeepJavaLibrary 一个基于 Java 的深度学习框架来构建增强学习训练网络并进行训练。经过了差不多 300 万步(四小时)的训练后,小鸟已经可以获得最高 8000 多分的成绩,灵活穿梭于水管之间。

在本文中,我们将从原理开始一步一步实现增强学习算法并用它对游戏进行训练。如果任何一个时刻不清楚如何继续进行下去,可以参阅项目的源码。

项目地址:https://github.com/kingyuluk/RL-FlappyBird

增强学习(RL)的架构

在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。

CNN 训练简述

CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2) 的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。

训练数据

在小鸟采取动作后,我们会得到 preObservation and currentObservation 即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal 组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。

public void step(NDList action, boolean training) {
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
stepFrame();
NDList preObservation = currentObservation;
currentObservation = createObservation(currentImg);
FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
preObservation, currentObservation, action, currentReward, currentTerminal);
if (training) {
replayBuffer.addStep(step);
}
if (gameState == GAME_OVER) {
restartGame();
}
}

训练的三个周期

训练分为 3 个不同的周期以更好地生成训练数据:

  • Observe(观察) 周期:随机产生训练数据
  • Explore (探索) 周期:随机与推理动作结合更新训练数据
  • Training (训练) 周期:推理动作主导产生新数据

通过这种训练模式,我们可以更好的达到预期效果。

处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。

public NDList chooseAction(RlEnv env, boolean training) {
if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
return env.getActionSpace().randomAction();
} else return baseAgent.chooseAction(env, training);
}

训练逻辑

首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
.mul(actionInput.singletonOrThrow())
.sum(new int[]{1}));

postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:

// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
if (batchSteps[i].isTerminal()) {
targetQValue[i] = batchSteps[i].getReward();
} else {
targetQValue[i] = targetQReward.singletonOrThrow().get(i)
.max()
.mul(rewardDiscount)
.add(rewardInput.singletonOrThrow().get(i));
}
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。

卷积神经网络模型(CNN)

我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

训练过程

DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。

  • 实现 RlAgent 接口即可构建一个可以进行训练的智能体。
  • 在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。
  • 创建 ReplayBuffer 可以存储并动态更新训练数据。

在实现这些接口后,只需要调用 step 方法:

RlEnv.step(action, training);

这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment 方法中调用 step 方法,然后只需要重复执行 runEnvironment 方法,即可不断地生成用于训练的数据。

public Step[] runEnvironment(RlAgent agent, boolean training) {
// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
if (training) {
batchSteps = this.getBatch();
}
return batchSteps;
}

我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。

在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。

神经网络输入预处理

首先将图像大小 resize 成 80x80 并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。

public static NDArray imgPreprocess(BufferedImage observation) {
return NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(observation)
.toNDArray(NDManager.newBaseManager(),
Image.Flag.GRAYSCALE) ,80,80));
}

然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
NDArray observation = GameUtil.imgPreprocess(currentImg);
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}

一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}

总结

这个模型在 NVIDIA T4 GPU 训练了大概 4 个小时,更新了 300 万步。训练后的小鸟已经可以完全自主控制动作灵活穿梭与管道之间。训练后的模型也同样上传到了仓库中供您测试。在此项目中 DJL 提供了强大的训练 API 以及模型库支持,使得在 Java 开发过程中得心应手。

本项目完整代码:https://github.com/kingyuluk/RL-FlappyBird

用 Java 训练出一只“不死鸟”的更多相关文章

  1. 137 Single Number II 数组中除了一个数外,其他的数都出现了三次,找出这个只出现一次的数

    给定一个整型数组,除了一个元素只出现一次外,其余每个元素都出现了三次.求出那个只出现一次的数.注意:你的算法应该具有线性的时间复杂度.你能否不使用额外的内存来实现?详见:https://leetcod ...

  2. java.sql.SQLException: 对只转发结果集的无效操作: last

    出错代码如下:static String u = "user";static String p = "psw";static String url = &quo ...

  3. 在Java中弹出位于其他类的由WindowsBuilder创建的JFrameApplicationWIndow

    我们一般在使用Java弹出窗体的时候,一般是使用Jdialog这个所谓的"对话框类".但是,如果你不是初学者或研究员,而会在使用Java进行swing项目的开发,那么你很可能用到一 ...

  4. 给出2n+1个数,其中有2n个数出现过两次,如何用最简便的方法找出里面只出现了一次的那个数(转载)

    有2n+1个数,其中有2n个数出现过两次,找出其中只出现一次的数 例如这样一组数3,3,1,2,4,2,5,5,4,其中只有1出现了1次,其他都是出现了2次,如何找出其中的1? 最简便的方法是使用异或 ...

  5. IE6多出一只猪的经典bug

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...

  6. java 弹出选择目录框(选择文件夹),获取选择的文件夹路径

    java 弹出选择目录框(选择文件夹),获取选择的文件夹路径 java 弹出选择目录框(选择文件夹),获取选择的文件夹路径:int result = 0;File file = null;String ...

  7. 2017年11月GitHub上最热门的Java项目出炉

    2017年11月GitHub上最热门的Java项目出炉~ 一起来看看这些项目你使用过哪些呢? 1分布式 RPC 服务框架 dubbohttps://github.com/alibaba/dubbo S ...

  8. Java释出的时候,AWT作为Java最弱的组件受到不小的批评

    Java释出的时候,AWT作为Java最弱的组件受到不小的批评. 最根本的缺点是AWT在原生的用户界面之上仅提供了一个非常薄的抽象层. 例如,生成一个AWT的 复选框会导致AWT直接调用下层原生例程来 ...

  9. 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器

    迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...

随机推荐

  1. word教程字体和段落设置

    放大/缩小字号:1.选中文字-点击"大A"或"小A" 2.同时摁着ctrl+shift+>/ctrl+shift+<即可 设置标题与正文间距:鼠标放 ...

  2. 面试官问Linux下如何编译C程序,如何回答?为你编译演示

    文章来源:嵌入式大杂烩 作者:ZhengNL Windows下常用IDE来编译,Linux下直接使用gcc来编译,编译过程是Linux嵌入式编程的基础,也是嵌入式高频基础面试问题. 一.命令行编译及各 ...

  3. java41

    2019.8.7全部回顾完毕 收获:搞懂了以前不理解的内容 学会了Markdown语法 1. 将首字母变大写 public class _02将首字母变大写 { public static void ...

  4. B. Irreducible Anagrams【CF 1290B】

    思路: 设tx为t类别字符的个数. ①对于长度小于2的t明显是"YES"②对于字符类别只有1个的t明显是"YES"③对于字符类别有2个的t,如左上图:如果str ...

  5. IdentityServer4系列 | 资源密码凭证模式

    一.前言 从上一篇关于客户端凭证模式中,我们通过创建一个认证授权访问服务,定义一个API和要访问它的客户端,客户端通过IdentityServer上请求访问令牌,并使用它来控制访问API.其中,我们也 ...

  6. 【五校联考1day2】JZOJ2020年8月12日提高组T2 我想大声告诉你

    [五校联考1day2]JZOJ2020年8月12日提高组T2 我想大声告诉你 题目 Description 因为小Y 是知名的白富美,所以自然也有很多的追求者,这一天这些追求者打算进行一次游戏来踢出一 ...

  7. 跟我一起学Redis之Redis持久化必知必会

    前言 Redis是出了名的速度快,那是因为在内存中进行数据存储和操作:如果仅仅是在内存中进行数据存储,那就会导致以下问题: 数据随进程退出而消失:当服务器断电或Redis Server进程退出时,内存 ...

  8. 基于 MongoDB 动态字段设计的探索

    一.业务需求 假设某学校课程系统,不同专业课程不同 (可以动态增删),但是需要根据专业不同显示该专业学生的各科课程的成绩,如下: 专业 姓名 高等数学 数据结构 计算机 张三 90 85 计算机 李四 ...

  9. 20200513_安装windows sql server 2012 _ ws功能 NetFx3时出错,错误代码:-2146498298

    这是没有安装.net Framework 3.5造成的 1. 下载个.net Framework 3.5, 放到任意目录下, 比如C: 2. 打开添加windows 功能 3. 直接下一步: 4. 勾 ...

  10. 老猿学5G:融合计费场景的Nchf_ConvergedCharging_Create、Update和Release融合计费消息交互过程

    ☞ ░ 前往老猿Python博文目录 ░ 一.Nchf_ConvergedCharging_Create交互过程 Nchf_ConvergedCharging_Create 服务为CTF向CHF请求提 ...