Actor Critic value-based和policy-based的结合

实例代码

 import sys
import gym
import pylab
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam EPISODES = 1000 # A2C(Advantage Actor-Critic) agent for the Cartpole
# actor-critic算法结合了value-based和policy-based方法
class A2CAgent:
def __init__(self, state_size, action_size):
# if you want to see Cartpole learning, then change to True
self.render = True
self.load_model = False
# get size of state and action
self.state_size = state_size
self.action_size = action_size
self.value_size = 1 # These are hyper parameters for the Policy Gradient
self.discount_factor = 0.99
self.actor_lr = 0.001
self.critic_lr = 0.005 # create model for policy network
self.actor = self.build_actor()
self.critic = self.build_critic() if self.load_model:
self.actor.load_weights("./save_model/cartpole_actor.h5")
self.critic.load_weights("./save_model/cartpole_critic.h5") # approximate policy and value using Neural Network
# actor: state is input and probability of each action is output of model
def build_actor(self):#actor网络:state-->action
actor = Sequential()
actor.add(Dense(24, input_dim=self.state_size, activation='relu',
kernel_initializer='he_uniform'))
actor.add(Dense(self.action_size, activation='softmax',
kernel_initializer='he_uniform'))
actor.summary()
# See note regarding crossentropy in cartpole_reinforce.py
actor.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=self.actor_lr))
return actor # critic: state is input and value of state is output of model
def build_critic(self):#critic网络:state-->value,Q值
critic = Sequential()
critic.add(Dense(24, input_dim=self.state_size, activation='relu',
kernel_initializer='he_uniform'))
critic.add(Dense(self.value_size, activation='linear',
kernel_initializer='he_uniform'))
critic.summary()
critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
return critic # using the output of policy network, pick action stochastically
def get_action(self, state):
policy = self.actor.predict(state, batch_size=1).flatten()#根据actor网络预测下一步动作
return np.random.choice(self.action_size, 1, p=policy)[0] # update policy network every episode
def train_model(self, state, action, reward, next_state, done):
target = np.zeros((1, self.value_size))#(1,1)
advantages = np.zeros((1, self.action_size))#(1, 2) value = self.critic.predict(state)[0]#critic网络预测的当前q值
next_value = self.critic.predict(next_state)[0]#critic网络预测的下一个q值 '''
理解下面部分
'''
if done:
advantages[0][action] = reward - value
target[0][0] = reward
else:
advantages[0][action] = reward + self.discount_factor * (next_value) - value#acotr网络
target[0][0] = reward + self.discount_factor * next_value#critic网络 self.actor.fit(state, advantages, epochs=1, verbose=0)
self.critic.fit(state, target, epochs=1, verbose=0) if __name__ == "__main__":
# In case of CartPole-v1, maximum length of episode is 500
env = gym.make('CartPole-v1')
# get size of state and action from environment
state_size = env.observation_space.shape[0]
action_size = env.action_space.n # make A2C agent
agent = A2CAgent(state_size, action_size)
scores, episodes = [], [] for e in range(EPISODES):
done = False
score = 0
state = env.reset()
state = np.reshape(state, [1, state_size]) while not done:
if agent.render:
env.render() action = agent.get_action(state)
next_state, reward, done, info = env.step(action)
next_state = np.reshape(next_state, [1, state_size])
# if an action make the episode end, then gives penalty of -100
reward = reward if not done or score == 499 else -100 agent.train_model(state, action, reward, next_state, done)#每执行一次action训练一次 score += reward
state = next_state if done:
# every episode, plot the play time
score = score if score == 500.0 else score + 100
scores.append(score)
episodes.append(e)
pylab.plot(episodes, scores, 'b')
pylab.savefig("./save_graph/cartpole_a2c.png")
print("episode:", e, " score:", score) # if the mean of scores of last 10 episode is bigger than 490
# stop training
if np.mean(scores[-min(10, len(scores)):]) > 490:
sys.exit() # save the model
if e % 50 == 0:
agent.actor.save_weights("./save_model/cartpole_actor.h5")
agent.critic.save_weights("./save_model/cartpole_critic.h5")

深度增强学习--Actor Critic的更多相关文章

  1. 深度增强学习--DDPG

    DDPG DDPG介绍2 ddpg输出的不是行为的概率, 而是具体的行为, 用于连续动作 (continuous action) 的预测 公式推导 推导 代码实现的gym的pendulum游戏,这个游 ...

  2. 深度增强学习--A3C

    A3C 它会创建多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数. 并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 所 ...

  3. 深度增强学习--DPPO

    PPO DPPO介绍 PPO实现 代码DPPO

  4. 深度增强学习--DQN的变形

    DQN的变形 double DQN prioritised replay dueling DQN

  5. 深度增强学习--Policy Gradient

    前面都是value based的方法,现在看一种直接预测动作的方法 Policy Based Policy Gradient 一个介绍 karpathy的博客 一个推导 下面的例子实现的REINFOR ...

  6. 深度增强学习--Deep Q Network

    从这里开始换个游戏演示,cartpole游戏 Deep Q Network 实例代码 import sys import gym import pylab import random import n ...

  7. 常用增强学习实验环境 II (ViZDoom, Roboschool, TensorFlow Agents, ELF, Coach等) (转载)

    原文链接:http://blog.csdn.net/jinzhuojun/article/details/78508203 前段时间Nature上发表的升级版Alpha Go - AlphaGo Ze ...

  8. 马里奥AI实现方式探索 ——神经网络+增强学习

    [TOC] 马里奥AI实现方式探索 --神经网络+增强学习 儿时我们都曾有过一个经典游戏的体验,就是马里奥(顶蘑菇^v^),这次里约奥运会闭幕式,日本作为2020年东京奥运会的东道主,安倍最后也已经典 ...

  9. 增强学习 | AlphaGo背后的秘密

    "敢于尝试,才有突破" 2017年5月27日,当今世界排名第一的中国棋手柯洁与AlphaGo 2.0的三局对战落败.该事件标志着最新的人工智能技术在围棋竞技领域超越了人类智能,借此 ...

随机推荐

  1. web前端 CSS基础

    简单的CSS文件 <style type="text/css"> a{ color:rebeccapurple; font-size: larger; font-wei ...

  2. git web 服务器的搭建【转】

    转自:http://blog.csdn.net/transformer_han/article/details/6450200 目录(?)[-] git服务器搭建过程 需求 硬件需求一台Ubuntu或 ...

  3. appium===出错时截图的方法,自动截图

    try: driver.find_element_by_id("kwsss").send_keys("selenium") driver.find_elemen ...

  4. Flex slider参数详细

    $(window).load(function() { $('.flexslider').flexslider({ animation: "fade", //String: Sel ...

  5. vmware的3种网络模式

    ####图片以及部分内容来源:https://note.youdao.com/share/?id=236896997b6ffbaa8e0d92eacd13abbf&type=note#/ 在安 ...

  6. 经常用到的Eclipse快捷键(更新中....)

    alt+shift+s:弹出Source选项,用于生成get,set等方法. Ctrl+E:快速显示当前Editer的下拉列表 alt+shift+r:重命名 Ctrl+Shift+→/Ctrl+Sh ...

  7. [orangehrm] 安装问题集合

    Web server allows .htaccess files # 这一项检查不通过 解决: In conf/extra/httpd-vhosts.conf, add the line Allow ...

  8. scrapy 最新版本中文文档地址

    http://scrapy-chs.readthedocs.org/zh_CN/latest/

  9. 安装和破解Quartus Ⅱ 15.0

    http://jingyan.baidu.com/article/b7001fe18d47fc0e7282dd91.html

  10. LCA+差分【CF191C】Fools and Roads

    Description 有一颗 \(n\) 个节点的树,\(k\) 次旅行,问每一条边被走过的次数. Input 第一行一个整数 \(n\) (\(2\leq n\leq 10^5\)). 接下来 \ ...