Actor Critic value-based和policy-based的结合

实例代码

 import sys
import gym
import pylab
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam EPISODES = 1000 # A2C(Advantage Actor-Critic) agent for the Cartpole
# actor-critic算法结合了value-based和policy-based方法
class A2CAgent:
def __init__(self, state_size, action_size):
# if you want to see Cartpole learning, then change to True
self.render = True
self.load_model = False
# get size of state and action
self.state_size = state_size
self.action_size = action_size
self.value_size = 1 # These are hyper parameters for the Policy Gradient
self.discount_factor = 0.99
self.actor_lr = 0.001
self.critic_lr = 0.005 # create model for policy network
self.actor = self.build_actor()
self.critic = self.build_critic() if self.load_model:
self.actor.load_weights("./save_model/cartpole_actor.h5")
self.critic.load_weights("./save_model/cartpole_critic.h5") # approximate policy and value using Neural Network
# actor: state is input and probability of each action is output of model
def build_actor(self):#actor网络:state-->action
actor = Sequential()
actor.add(Dense(24, input_dim=self.state_size, activation='relu',
kernel_initializer='he_uniform'))
actor.add(Dense(self.action_size, activation='softmax',
kernel_initializer='he_uniform'))
actor.summary()
# See note regarding crossentropy in cartpole_reinforce.py
actor.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=self.actor_lr))
return actor # critic: state is input and value of state is output of model
def build_critic(self):#critic网络:state-->value,Q值
critic = Sequential()
critic.add(Dense(24, input_dim=self.state_size, activation='relu',
kernel_initializer='he_uniform'))
critic.add(Dense(self.value_size, activation='linear',
kernel_initializer='he_uniform'))
critic.summary()
critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
return critic # using the output of policy network, pick action stochastically
def get_action(self, state):
policy = self.actor.predict(state, batch_size=1).flatten()#根据actor网络预测下一步动作
return np.random.choice(self.action_size, 1, p=policy)[0] # update policy network every episode
def train_model(self, state, action, reward, next_state, done):
target = np.zeros((1, self.value_size))#(1,1)
advantages = np.zeros((1, self.action_size))#(1, 2) value = self.critic.predict(state)[0]#critic网络预测的当前q值
next_value = self.critic.predict(next_state)[0]#critic网络预测的下一个q值 '''
理解下面部分
'''
if done:
advantages[0][action] = reward - value
target[0][0] = reward
else:
advantages[0][action] = reward + self.discount_factor * (next_value) - value#acotr网络
target[0][0] = reward + self.discount_factor * next_value#critic网络 self.actor.fit(state, advantages, epochs=1, verbose=0)
self.critic.fit(state, target, epochs=1, verbose=0) if __name__ == "__main__":
# In case of CartPole-v1, maximum length of episode is 500
env = gym.make('CartPole-v1')
# get size of state and action from environment
state_size = env.observation_space.shape[0]
action_size = env.action_space.n # make A2C agent
agent = A2CAgent(state_size, action_size)
scores, episodes = [], [] for e in range(EPISODES):
done = False
score = 0
state = env.reset()
state = np.reshape(state, [1, state_size]) while not done:
if agent.render:
env.render() action = agent.get_action(state)
next_state, reward, done, info = env.step(action)
next_state = np.reshape(next_state, [1, state_size])
# if an action make the episode end, then gives penalty of -100
reward = reward if not done or score == 499 else -100 agent.train_model(state, action, reward, next_state, done)#每执行一次action训练一次 score += reward
state = next_state if done:
# every episode, plot the play time
score = score if score == 500.0 else score + 100
scores.append(score)
episodes.append(e)
pylab.plot(episodes, scores, 'b')
pylab.savefig("./save_graph/cartpole_a2c.png")
print("episode:", e, " score:", score) # if the mean of scores of last 10 episode is bigger than 490
# stop training
if np.mean(scores[-min(10, len(scores)):]) > 490:
sys.exit() # save the model
if e % 50 == 0:
agent.actor.save_weights("./save_model/cartpole_actor.h5")
agent.critic.save_weights("./save_model/cartpole_critic.h5")

深度增强学习--Actor Critic的更多相关文章

  1. 深度增强学习--DDPG

    DDPG DDPG介绍2 ddpg输出的不是行为的概率, 而是具体的行为, 用于连续动作 (continuous action) 的预测 公式推导 推导 代码实现的gym的pendulum游戏,这个游 ...

  2. 深度增强学习--A3C

    A3C 它会创建多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数. 并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 所 ...

  3. 深度增强学习--DPPO

    PPO DPPO介绍 PPO实现 代码DPPO

  4. 深度增强学习--DQN的变形

    DQN的变形 double DQN prioritised replay dueling DQN

  5. 深度增强学习--Policy Gradient

    前面都是value based的方法,现在看一种直接预测动作的方法 Policy Based Policy Gradient 一个介绍 karpathy的博客 一个推导 下面的例子实现的REINFOR ...

  6. 深度增强学习--Deep Q Network

    从这里开始换个游戏演示,cartpole游戏 Deep Q Network 实例代码 import sys import gym import pylab import random import n ...

  7. 常用增强学习实验环境 II (ViZDoom, Roboschool, TensorFlow Agents, ELF, Coach等) (转载)

    原文链接:http://blog.csdn.net/jinzhuojun/article/details/78508203 前段时间Nature上发表的升级版Alpha Go - AlphaGo Ze ...

  8. 马里奥AI实现方式探索 ——神经网络+增强学习

    [TOC] 马里奥AI实现方式探索 --神经网络+增强学习 儿时我们都曾有过一个经典游戏的体验,就是马里奥(顶蘑菇^v^),这次里约奥运会闭幕式,日本作为2020年东京奥运会的东道主,安倍最后也已经典 ...

  9. 增强学习 | AlphaGo背后的秘密

    "敢于尝试,才有突破" 2017年5月27日,当今世界排名第一的中国棋手柯洁与AlphaGo 2.0的三局对战落败.该事件标志着最新的人工智能技术在围棋竞技领域超越了人类智能,借此 ...

随机推荐

  1. Opencv 中透视变换函数对IplImage图像变换时出现的问题?

    最近一直在做视频稳像的项目,为了简化部分实现,使用了部分Opencv的函数,其中包括Opencv中对IplImage进行同时变换的函数cvWarpPerspective(src, dst,...) 发 ...

  2. Symfony2之创建一个简单的web应用 Symfony2——创建bundle

    bundle就像插件或者一个功能齐全的应用,我们在应用层上开发的应用的所有代码,包括:PHP文件.配置文件.图片.css文件.js文件等都会包含在bunde系统中.          可以通过两种方法 ...

  3. HDU2444(判断是否为二分图,求最大匹配)

    The Accomodation of Students Time Limit: 5000/1000 MS (Java/Others)    Memory Limit: 32768/32768 K ( ...

  4. 命令行工具PathMarker

    一直使用Guake 终端,Guake提供的其中一个功能是快速打开. 大概的意思就是,显示在终端上的数据会经过匹配,如果符合一定的规则,则可以按住ctrl,使用鼠标单击以触发指定操作. 比如对于一个文件 ...

  5. 常见协议基础知识总结--DHCP协议

    DHCP动态主机配置协议,简单点说,就是提供了自动获取ip地址的功能,基于四层的UDP协议: 以下描述此协议的整个工作流程: (1) 客户端发送discovery报文,二三层广播报文,源ip地址全0: ...

  6. appium===元素定位

    一.常用识别元素的工具 uiautomator:Android SDK自带的一个工具,在tools目录下 monitor:Android SDK自带的一个工具,在tools目录下 Appium Ins ...

  7. 搭建Spring框架,实现添加数据到数据库

    原创链接:http://www.cnblogs.com/yanqin/p/5284400.html (允许转载,但请注明原创链接) 1.在myeclipse中建立一个web项目 项目名 :spring ...

  8. pygame --- 可怜的小乌龟

    来于----@小甲鱼工作室 import pygame import sys from pygame.locals import * #初始化 pygame.init() size = width,h ...

  9. Django基础之模板

    Django模板系统 官方文档 常用语法 只需要记两种特殊符号: {{  }} 和 {% %} 变量相关的用{{ }},逻辑相关的用{% %}. 变量 {{ 变量名 }} 变量名由字母数字和下划线组成 ...

  10. C# 连接和操作SQL SERVER数据库

    用C#sqlserver实现增删改查http://www.worlduc.com/blog2012.aspx?bid=730767 using System.Data;using System.Dat ...