作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手机游戏,因其简单的玩法但极度困难的设定迅速走红全网。随着深度学习(DL)与增强学习(RL)等前沿算法的发展,我们可以使用 Java 非常方便地训练出一个智能体来控制 Flappy Bird。

故事开始于《GitHub 上的大佬们打完招呼,会聊些什么?》,那么,今天我们就来一起看一下如何用 Java 训练出一个不死鸟。游戏项目我们使用了一个仅用 Java 基本类库编写的 FlappyBird 游戏。在训练方面,我们使用 DeepJavaLibrary 一个基于 Java 的深度学习框架来构建增强学习训练网络并进行训练。经过了差不多 300 万步(四小时)的训练后,小鸟已经可以获得最高 8000 多分的成绩,灵活穿梭于水管之间。

在本文中,我们将从原理开始一步一步实现增强学习算法并用它对游戏进行训练。如果任何一个时刻不清楚如何继续进行下去,可以参阅项目的源码。

项目地址:https://github.com/kingyuluk/RL-FlappyBird

增强学习(RL)的架构

在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。

CNN 训练简述

CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2) 的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。

训练数据

在小鸟采取动作后,我们会得到 preObservation and currentObservation 即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal 组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。

public void step(NDList action, boolean training) {
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
stepFrame();
NDList preObservation = currentObservation;
currentObservation = createObservation(currentImg);
FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
preObservation, currentObservation, action, currentReward, currentTerminal);
if (training) {
replayBuffer.addStep(step);
}
if (gameState == GAME_OVER) {
restartGame();
}
}

训练的三个周期

训练分为 3 个不同的周期以更好地生成训练数据:

  • Observe(观察) 周期:随机产生训练数据
  • Explore (探索) 周期:随机与推理动作结合更新训练数据
  • Training (训练) 周期:推理动作主导产生新数据

通过这种训练模式,我们可以更好的达到预期效果。

处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。

public NDList chooseAction(RlEnv env, boolean training) {
if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
return env.getActionSpace().randomAction();
} else return baseAgent.chooseAction(env, training);
}

训练逻辑

首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
.mul(actionInput.singletonOrThrow())
.sum(new int[]{1}));

postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:

// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
if (batchSteps[i].isTerminal()) {
targetQValue[i] = batchSteps[i].getReward();
} else {
targetQValue[i] = targetQReward.singletonOrThrow().get(i)
.max()
.mul(rewardDiscount)
.add(rewardInput.singletonOrThrow().get(i));
}
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。

卷积神经网络模型(CNN)

我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

训练过程

DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。

  • 实现 RlAgent 接口即可构建一个可以进行训练的智能体。
  • 在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。
  • 创建 ReplayBuffer 可以存储并动态更新训练数据。

在实现这些接口后,只需要调用 step 方法:

RlEnv.step(action, training);

这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment 方法中调用 step 方法,然后只需要重复执行 runEnvironment 方法,即可不断地生成用于训练的数据。

public Step[] runEnvironment(RlAgent agent, boolean training) {
// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
if (training) {
batchSteps = this.getBatch();
}
return batchSteps;
}

我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。

在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。

神经网络输入预处理

首先将图像大小 resize 成 80x80 并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。

public static NDArray imgPreprocess(BufferedImage observation) {
return NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(observation)
.toNDArray(NDManager.newBaseManager(),
Image.Flag.GRAYSCALE) ,80,80));
}

然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。

public NDList createObservation(BufferedImage currentImg) {
NDArray observation = GameUtil.imgPreprocess(currentImg);
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}

一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}

总结

这个模型在 NVIDIA T4 GPU 训练了大概 4 个小时,更新了 300 万步。训练后的小鸟已经可以完全自主控制动作灵活穿梭与管道之间。训练后的模型也同样上传到了仓库中供您测试。在此项目中 DJL 提供了强大的训练 API 以及模型库支持,使得在 Java 开发过程中得心应手。

本项目完整代码:https://github.com/kingyuluk/RL-FlappyBird

用 Java 训练出一只“不死鸟”的更多相关文章

  1. 137 Single Number II 数组中除了一个数外,其他的数都出现了三次,找出这个只出现一次的数

    给定一个整型数组,除了一个元素只出现一次外,其余每个元素都出现了三次.求出那个只出现一次的数.注意:你的算法应该具有线性的时间复杂度.你能否不使用额外的内存来实现?详见:https://leetcod ...

  2. java.sql.SQLException: 对只转发结果集的无效操作: last

    出错代码如下:static String u = "user";static String p = "psw";static String url = &quo ...

  3. 在Java中弹出位于其他类的由WindowsBuilder创建的JFrameApplicationWIndow

    我们一般在使用Java弹出窗体的时候,一般是使用Jdialog这个所谓的"对话框类".但是,如果你不是初学者或研究员,而会在使用Java进行swing项目的开发,那么你很可能用到一 ...

  4. 给出2n+1个数,其中有2n个数出现过两次,如何用最简便的方法找出里面只出现了一次的那个数(转载)

    有2n+1个数,其中有2n个数出现过两次,找出其中只出现一次的数 例如这样一组数3,3,1,2,4,2,5,5,4,其中只有1出现了1次,其他都是出现了2次,如何找出其中的1? 最简便的方法是使用异或 ...

  5. IE6多出一只猪的经典bug

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...

  6. java 弹出选择目录框(选择文件夹),获取选择的文件夹路径

    java 弹出选择目录框(选择文件夹),获取选择的文件夹路径 java 弹出选择目录框(选择文件夹),获取选择的文件夹路径:int result = 0;File file = null;String ...

  7. 2017年11月GitHub上最热门的Java项目出炉

    2017年11月GitHub上最热门的Java项目出炉~ 一起来看看这些项目你使用过哪些呢? 1分布式 RPC 服务框架 dubbohttps://github.com/alibaba/dubbo S ...

  8. Java释出的时候,AWT作为Java最弱的组件受到不小的批评

    Java释出的时候,AWT作为Java最弱的组件受到不小的批评. 最根本的缺点是AWT在原生的用户界面之上仅提供了一个非常薄的抽象层. 例如,生成一个AWT的 复选框会导致AWT直接调用下层原生例程来 ...

  9. 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器

    迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...

随机推荐

  1. 如何使用系统清理缓存软件优化MacBook

    在我们使用我们的Mac一定的时间后,总是不可避免的出现Mac内存不足的情况,所以清理垃圾软件也就成为了我们电脑里必不可少的软件.苹果软件商店中有很多各有不同的清理垃圾软件,但我们往往很难从这一大堆软件 ...

  2. 网络系列之GET与POST请求方式的区别

    作为一枚正在学习前端的 小萌新,如果下面哪里有写的不对的话,可以帮我指出来吗,谢谢 1.是基于什么前提的?如果什么前提都没有,不使用任何规范,只考虑语法和理论上的HTTP协议 那么GET和POST几乎 ...

  3. 【VUE】7.Vuex基本使用

    1. 安装Vuex npm install vuex --save 2. 导入Vuex包 import Vuex from 'vuex' Vue.use(Vuex) 3. 创建store对象 cons ...

  4. 基于Koa2+mongoDB的后端博客框架

    主要框架:koa2全家桶+mongoose+pm2. 在阅读前建议将项目克隆到本地配合食用,否则将看得云里雾里. 项目地址:https://github.com/YogurtQ/koa-server. ...

  5. Vue看板娘教程1.0

    Live2D看板娘 前言(PS:本教程使用的Vue项目) 一.下载文件 二.使用步骤 1.引入文件 2.引入js 3.修改app.vue 4.如何换模型? 更换模型的效果 5.如何换语音? 结尾(后续 ...

  6. 自学linux——1.VMware的安装及VM下centos的安装

    1.CentOS下载 网址:https://www.centos.org/download/ 网盘:https://pan.baidu.com/s/1HrtK6xNig6KC8oh6O-6fyg 提取 ...

  7. XOR性质

    异或XOR的性质: 1. 交换律 2. 结合律 3. x^x = 0 -> 偶数个异或为0 4. x^0 = x -> 奇数个异或为本身 5. 自反性:a^b^b = a^0 =a

  8. SVN报错working copy is not uptodate

    报错信息 回想了下我更改的信息:删除了一些包,增加了一些包,删除了文件,增加了文件. 解决操作:先更新,然后提交试下,又报了以下错误 解决操作:右键项目,team->show tree conf ...

  9. k8s 部署 Java 项目

    前几天安装了 k8s 并测试了自动伸缩功能(HPA),今天来部署一个简单的 Java 应用到 k8s. 开始之前需要先安装一下 ingress 插件.ingress 作为 k8s 的流量入口,有多种实 ...

  10. 编写测试用例 QQ账号6--10位自然数 某城市电话号码 126邮箱注册功能