参考文献

莫凡系列课程视频

增强学习入门之Q-Learning

关于增强学习的基本知识可以参考第二个链接,讲的挺有意思的。DQN的东西可以看第一个链接相关视频。课程中实现了Tensorflow和pytorch的示例代码。本文主要是改写成了gluon实现

Q-learning的算法流程

DQN的算法流程

对于DQN的理解:

增强学习中需要学习的东西是Q-table,决策表。而针对于state space空间太大的情形,很难甚至不可能构建这个决策表。而决策表其实就是一种映射 (s,a)->R, 那么这种映射可以通过网络来构建,于是就有了DQN

下面来看代码

import mxnet as mx

import mxnet.ndarray as nd

import mxnet.gluon as gluon

import numpy as np

import mxnet.gluon.nn as nn

import gym

BATCH_SIZE=64                                             # 训练网络时的batchsize

LR=0.01                                                         # 权重更新的学习率

EPSILON=0.9                                                  # 每次以概率选择最有策略,有点类似于生物算法的思想

GAMMA=0.5                                                    # 计算q_target是下一个状态收益对当前的影响

TARGET_REPLACE_ITER=100                            # 保存网络参数,可以理解为上一次的映射,的频率

MEMORY_CAPACITY=1000                                # 历史决策

env = gym.make('CartPole-v0')                         # 调用OpenAI.gym构建的env

env = env.unwrapped

N_ACTIONS=env.action_space.n                       # 备选策略的个数

N_STATES = env.observation_space.shape[0]    # 状态向量的长度

# 定义所需要的网络,示例仅随意设置了几层

class Net(nn.HybridBlock):
     def __init__(self,**kwargs):
         super(Net, self).__init__(**kwargs)
         with self.name_scope():
             self.fc1 = nn.Dense(16, activation='relu')
             self.fc2 = nn.Dense(32, activation='relu')
             self.fc3 = nn.Dense(16, activation='relu')
             self.out = nn.Dense(N_ACTIONS)
     def hybrid_forward(self, F, x):
         x = self.fc1(x)
         x = self.fc2(x)
         x = self.fc3(x)
         actions_value = self.out(x)
         return actions_value

# 定义网络权重的拷贝方法。主要是因为DQN learning中采用off-policy更新,也就是说需要上一次的映射图,这可以使用网络上一次的权重保存,这个用以保存权重的网络只有前向功能,类似于查表,并不更新参数,直到满足一定条件时将当前网络参数再次存储

def copy_params(src, dst):
     dst.initialize(force_reinit=True, ctx=mx.cpu())
     layer_names = ['dense0_weight', 'dense0_bias','dense1_weight','dense1_bias',
                  'dense2_weight','dense2_bias','dense3_weight','dense3_bias']
     for i in range(len(layer_names)):
         dst.get(layer_names[i]).set_data(src.get(layer_names[i]).data())

# 定义DQN类,包含网络、策略选择、保存记录等

class DQN(object):
     def __init__(self):
         self.eval_net, self.target_net = Net(), Net()
         self.eval_net.initialize()
         self.target_net.initialize()
         x=nd.random_uniform(shape=(1,N_STATES))
         _ = self.eval_net(x)
         _ = self.target_net(x)                # mxnet的延迟初始化特性
         self.learn_step_counter = 0
         self.memory_counter = 0
         self.memory = np.zeros(shape=(MEMORY_CAPACITY, N_STATES*2+2))
         # 每一行存储的是当前状态,选择的action, 当前的回报, 下一步的状态
         self.trainer = gluon.Trainer(self.eval_net.collect_params(), 'sgd',\
                                     {'learning_rate': LR,'wd':1e-4})
         self.loss_func = gluon.loss.L2Loss()
         self.cost_his=[]
     def choose_action(self, x):
         if np.random.uniform()<EPSILON:
             # EPSILON的概率选择最可能动作
             x = nd.array([x])
             actions_value = self.eval_net(x)
             action = int(nd.argmax(actions_value, axis=1).asscalar())
         else:
             action = np.random.randint(0, N_ACTIONS)
         return action
     def store_transition(self,s,a,r,s_):
         # 存储历史纪录
         transition = np.hstack((s,[a,r],s_))
         index = self.memory_counter % MEMORY_CAPACITY
         # 主要是为了循环利用存储空间
         self.memory[index,:] = transition
         self.memory_counter += 1
        
     def learn(self):
         if self.learn_step_counter % TARGET_REPLACE_ITER==0:
             # 每学习一定间隔之后,将当前的状态
             copy_params(self.eval_net.collect_params(), self.target_net.collect_params())
            
         self.learn_step_counter += 1
        
         sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
         # 随机选择一组状态
         b_memory = self.memory[sample_index,:]
       
         b_s = nd.array(b_memory[:,:N_STATES])
         b_a = nd.array(b_memory[:,N_STATES:N_STATES+1])
         b_r = nd.array(b_memory[:,N_STATES+1:N_STATES+2])
         b_s_= nd.array(b_memory[:,-N_STATES:])
         with mx.autograd.record():
             q_eval = self.eval_net(b_s) # 预估值
             with mx.autograd.pause():
                 q_next = self.target_net(b_s_) # 历史值 batch x N_ACTIONS
             q_target = b_r + GAMMA*nd.max(q_next, axis=1)
             loss = self.loss_func(q_eval, q_target)
        
         self.cost_his.append(nd.mean(loss).asscalar())
         loss.backward()
         self.trainer.step(BATCH_SIZE)
        
     def plot_cost(self):
         import matplotlib.pyplot as plt
         plt.plot(np.arange(len(self.cost_his)), self.cost_his)
         plt.ylabel('Cost')
         plt.xlabel('training steps')
         plt.show()

# 训练

dqn = DQN()

for i_episode in range(500):
     s = env.reset()
     while True:
         env.render()
         a = dqn.choose_action(s)
         s_, r, done, info = env.step(a)# 到达的状态,收益,是否结束

x,x_dot, theta, theta_dot = s_
         r1 = (env.x_threshold - abs(x))/env.x_threshold - 0.8
         r2 = (env.theta_threshold_radians - abs(theta))/env.theta_threshold_radians-0.5
         r = r1 + r2

dqn.store_transition(s,a,r,s_)
         if dqn.memory_counter > MEMORY_CAPACITY:
             dqn.learn()

if done:
             break
        
         s = s_

dqn.plot_cost()

loss曲线

训练的loss似乎并没有收敛,还在找原因

ps. 第一次使用open live writer写博客,体验很差!!!!!我需要公式、代码和图片的支持。。。。还在寻找中

mxnet(gluon) 实现DQN简单小例子的更多相关文章

  1. php+jquery+ajax+json简单小例子

    直接贴代码: <html> <title>php+jquery+ajax+json简单小例子</title> <?php header("Conte ...

  2. C#利用事件与委托进行窗体间传值简单小例子

    本篇博客是利用C#委托与事件进行窗体间传值的简单小例子 委托与事件的详细解释大家可以参照张子阳的博客: http://www.tracefact.net/CSharp-Programming/Dele ...

  3. ASP.NET Cookie对象到底是毛啊?(简单小例子)

    记得刚接触asp.net的时候,就被几个概念搞的头痛不已,比如Request,Response,Session和Cookie.然后还各种在搜索引擎搜,各种问同事的,但是结果就是自己还是很懵的节奏. 那 ...

  4. 关键字Lock的简单小例子

    一.什么是Lock? Lock——字面上理解就是锁上:锁住:把……锁起来的意思: 为什么要锁?要锁干什么?——回到现实中可想象到,这个卫生间我要上,其他人不要进来!(所以我要锁住门):又或者土味情话所 ...

  5. 详细解读Android中的搜索框(一)—— 简单小例子

    这次开的是一个讲解SearchView的栏目,第一篇主要是给一个小例子,让大家对这个搜索视图有一个了解,之后再分布细化来说. 目标: 我们先来定个目标,我们通过搜索框来输入要搜索的联系人名字,输入的时 ...

  6. 关于ExpandableListView用法的一个简单小例子

    喜欢显示好友QQ那样的列表,可以展开,可以收起,在android中,以往用的比较多的是listview,虽然可以实现列表的展示,但在某些情况下,我们还是希望用到可以分组并实现收缩的列表,那就要用到an ...

  7. Ajax的简单小例子

    1.首先下载ajax.dll,一个百度一下都有下载的!自行查找. 2.把ajax.dll导入到工程.右键工程-->添加引用--->浏览,找到下载好的ajax.dll文件,点击确定,这时候在 ...

  8. SpringMVC静态文件(图片)访问+js访问 简单小例子

    项目文件布局: web.xml文件: <?xml version="1.0" encoding="UTF-8"?> <web-app vers ...

  9. MVC实现(简单小例子)

    Here I’ll demonstrate simple Spring MVC framework for building web applications. First thing first. ...

随机推荐

  1. linux 中的定时任务crontab使用方法

    linux 中的定时任务crontab使用方法: 切换到root用户,sudo su root (可以设置成不需要输入密码) sudo su - (需要输入当前帐号的密码才能进入.) crontab ...

  2. Python之路----递归函数

    1.小练一下 用map来处理字符串列表,把列表中所有人都变成sb,比方alex_sb name=['alex','wupeiqi','yuanhao','nezha'] # def func(item ...

  3. 什么是IO多路复用?Nginx的处理机制

    先来说一下什么是IO复用? IO复用解决的就是并发行的问题,比如多个用户并发访问一个WEB网站,对于服务端后台而言就会产生多个请求,处理多个请求对于中间件就会产生多个IO流对于系统的读写.那么对于IO ...

  4. bzoj1642 / P2889 [USACO07NOV]挤奶的时间Milking Time

    P2889 [USACO07NOV]挤奶的时间Milking Time 普通的dp 休息时间R其实就是把结束时间后移R个单位而已.但是终点也需要后移R位到n+R. 每个时间段按起始时间排序,蓝后跑一遍 ...

  5. 更换 nodejs npm 镜像为 淘宝 镜像

    淘宝npm镜像官方介绍文档:https://npm.taobao.org/ ,使用命令在这个官方文档里查询. 安装工具cnpm: $ npm install -g cnpm --registry=ht ...

  6. HttpClient4.5简单使用

    一.HttpClient简介 HttpClient是一个客户端的HTTP通信实现库,它不是一个浏览器.关于HTTP协议,可以搜索相关的资料.它设计的目的是发送与接收HTTP报文.它不会执行嵌入在页面中 ...

  7. cogs 2223. [SDOI2016 Round1] 生成魔咒

    ★★☆ 输入文件:menci_incantation.in 输出文件:menci_incantation.out 简单对比 时间限制:1 s 内存限制:128 MB [题目描述]魔咒串由许多魔咒字符组 ...

  8. nmap参数思维导图

    链接:https://pan.baidu.com/s/1vD0A6olQbVNmCCirpHBm0w 提取码:o994

  9. 如何修改bootstrap模态框的backdrop蒙版区域的颜色?

    参考地址: http://www.cnblogs.com/9miao/p/4988196.html 蒙板样式实现: 大家或许注意到了,在做模态弹出窗时,底部常常会有一个透明的黑色蒙层效果:在Boots ...

  10. 以太坊(Ethereum) - 节点时间未同步和区块同步失败案例分析

    背景 以太坊技术搭建的区块链网络,节点间需要保证时间一致,才能正常有序的发送交易和生成区块,使得众多节点共同维护分布式账本(区块数据+状态数据).但是,网络中节点的系统时间不一致回出现什么现象呢,我们 ...