深度增强学习--Actor Critic
Actor Critic value-based和policy-based的结合
import sys
import gym
import pylab
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam EPISODES = 1000 # A2C(Advantage Actor-Critic) agent for the Cartpole
# actor-critic算法结合了value-based和policy-based方法
class A2CAgent:
def __init__(self, state_size, action_size):
# if you want to see Cartpole learning, then change to True
self.render = True
self.load_model = False
# get size of state and action
self.state_size = state_size
self.action_size = action_size
self.value_size = 1 # These are hyper parameters for the Policy Gradient
self.discount_factor = 0.99
self.actor_lr = 0.001
self.critic_lr = 0.005 # create model for policy network
self.actor = self.build_actor()
self.critic = self.build_critic() if self.load_model:
self.actor.load_weights("./save_model/cartpole_actor.h5")
self.critic.load_weights("./save_model/cartpole_critic.h5") # approximate policy and value using Neural Network
# actor: state is input and probability of each action is output of model
def build_actor(self):#actor网络:state-->action
actor = Sequential()
actor.add(Dense(24, input_dim=self.state_size, activation='relu',
kernel_initializer='he_uniform'))
actor.add(Dense(self.action_size, activation='softmax',
kernel_initializer='he_uniform'))
actor.summary()
# See note regarding crossentropy in cartpole_reinforce.py
actor.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=self.actor_lr))
return actor # critic: state is input and value of state is output of model
def build_critic(self):#critic网络:state-->value,Q值
critic = Sequential()
critic.add(Dense(24, input_dim=self.state_size, activation='relu',
kernel_initializer='he_uniform'))
critic.add(Dense(self.value_size, activation='linear',
kernel_initializer='he_uniform'))
critic.summary()
critic.compile(loss="mse", optimizer=Adam(lr=self.critic_lr))
return critic # using the output of policy network, pick action stochastically
def get_action(self, state):
policy = self.actor.predict(state, batch_size=1).flatten()#根据actor网络预测下一步动作
return np.random.choice(self.action_size, 1, p=policy)[0] # update policy network every episode
def train_model(self, state, action, reward, next_state, done):
target = np.zeros((1, self.value_size))#(1,1)
advantages = np.zeros((1, self.action_size))#(1, 2) value = self.critic.predict(state)[0]#critic网络预测的当前q值
next_value = self.critic.predict(next_state)[0]#critic网络预测的下一个q值 '''
理解下面部分
'''
if done:
advantages[0][action] = reward - value
target[0][0] = reward
else:
advantages[0][action] = reward + self.discount_factor * (next_value) - value#acotr网络
target[0][0] = reward + self.discount_factor * next_value#critic网络 self.actor.fit(state, advantages, epochs=1, verbose=0)
self.critic.fit(state, target, epochs=1, verbose=0) if __name__ == "__main__":
# In case of CartPole-v1, maximum length of episode is 500
env = gym.make('CartPole-v1')
# get size of state and action from environment
state_size = env.observation_space.shape[0]
action_size = env.action_space.n # make A2C agent
agent = A2CAgent(state_size, action_size)
scores, episodes = [], [] for e in range(EPISODES):
done = False
score = 0
state = env.reset()
state = np.reshape(state, [1, state_size]) while not done:
if agent.render:
env.render() action = agent.get_action(state)
next_state, reward, done, info = env.step(action)
next_state = np.reshape(next_state, [1, state_size])
# if an action make the episode end, then gives penalty of -100
reward = reward if not done or score == 499 else -100 agent.train_model(state, action, reward, next_state, done)#每执行一次action训练一次 score += reward
state = next_state if done:
# every episode, plot the play time
score = score if score == 500.0 else score + 100
scores.append(score)
episodes.append(e)
pylab.plot(episodes, scores, 'b')
pylab.savefig("./save_graph/cartpole_a2c.png")
print("episode:", e, " score:", score) # if the mean of scores of last 10 episode is bigger than 490
# stop training
if np.mean(scores[-min(10, len(scores)):]) > 490:
sys.exit() # save the model
if e % 50 == 0:
agent.actor.save_weights("./save_model/cartpole_actor.h5")
agent.critic.save_weights("./save_model/cartpole_critic.h5")
深度增强学习--Actor Critic的更多相关文章
- 深度增强学习--DDPG
DDPG DDPG介绍2 ddpg输出的不是行为的概率, 而是具体的行为, 用于连续动作 (continuous action) 的预测 公式推导 推导 代码实现的gym的pendulum游戏,这个游 ...
- 深度增强学习--A3C
A3C 它会创建多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数. 并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 所 ...
- 深度增强学习--DPPO
PPO DPPO介绍 PPO实现 代码DPPO
- 深度增强学习--DQN的变形
DQN的变形 double DQN prioritised replay dueling DQN
- 深度增强学习--Policy Gradient
前面都是value based的方法,现在看一种直接预测动作的方法 Policy Based Policy Gradient 一个介绍 karpathy的博客 一个推导 下面的例子实现的REINFOR ...
- 深度增强学习--Deep Q Network
从这里开始换个游戏演示,cartpole游戏 Deep Q Network 实例代码 import sys import gym import pylab import random import n ...
- 常用增强学习实验环境 II (ViZDoom, Roboschool, TensorFlow Agents, ELF, Coach等) (转载)
原文链接:http://blog.csdn.net/jinzhuojun/article/details/78508203 前段时间Nature上发表的升级版Alpha Go - AlphaGo Ze ...
- 马里奥AI实现方式探索 ——神经网络+增强学习
[TOC] 马里奥AI实现方式探索 --神经网络+增强学习 儿时我们都曾有过一个经典游戏的体验,就是马里奥(顶蘑菇^v^),这次里约奥运会闭幕式,日本作为2020年东京奥运会的东道主,安倍最后也已经典 ...
- 增强学习 | AlphaGo背后的秘密
"敢于尝试,才有突破" 2017年5月27日,当今世界排名第一的中国棋手柯洁与AlphaGo 2.0的三局对战落败.该事件标志着最新的人工智能技术在围棋竞技领域超越了人类智能,借此 ...
随机推荐
- 网络基础(osi、协议)
*互联网协议 人和人沟通需要一套共同的标准,英语就是普遍的一种,计算机如果需要进行联网互通,也需要一种统一的标准,如果所有的计算机都遵守这种标准,就会实现网络的互联. 1.一系列统一的标准,这些标准称 ...
- Android USB Camera(1) : 调试记录【转】
转自:http://blog.csdn.net/eternity9255/article/details/53069037 版权声明:本文为博主原创文章,未经博主允许不得转载. 目录(?)[-] 前言 ...
- 某dp题
[NOI联考by ysy]庆典 2016年6月17日1,1040 [题目描述] 战狂在昌和帝国的首都法法城召开了庆典,向一万名最杰出的士兵分发了用魔法猪做的猪肉饺子,士兵们吃了猪肉饺子后,战斗力大幅提 ...
- python变现实现新浪微博登陆
新浪微博的登陆现在是越来越那个了,以前的模拟浏览器登陆新浪微博貌似也越来不管用了 登陆信息由以前的form变成了现在javascript,javascript的加载居然用了一个javascript的函 ...
- KVM(一)简介及安装
1. KVM 介绍 1.0 虚拟化简史 其中,KVM 全称是 基于内核的虚拟机(Kernel-based Virtual Machine),它是一个 Linux 的一个内核模块,该内核模块使得 Lin ...
- 【转载】Window 窗口层次关系
相信在Window 下面编程的很多兄弟们都不是很清楚Window 中窗口的层次关系是怎么样的,这个东西很久已经研究过一下,后来又忘记了,今天又一次遇到了这个问题,所以便整理一下.下面就说说Window ...
- LPD Office插件使用指南
LPD Office插件已经发布至Azure上,您可以在本机Outlook和Office Online使用该插件 一:在Outlook中使用 LPD Office插件 打开Outlook应用,并点击“ ...
- [python] 如何将unicode字符串转换为中文
答案:(http://stackoverflow.com/) ps:这个网站解决了我好多问题啊,大家多上 >>>s='\u9648\u4f1f\u9706\u5176\u5b9e\u ...
- STL心得
熟悉c++版算法竞赛程序框架 理解变量引用的原理 熟练掌握string和stringstream 熟练掌握c++结构体的定义和使用,包括构造函数和静态成员变量 了解常见的可重载运算符,包括四则运算,赋 ...
- 使用Nginx+uWSGI部署Django项目
1.linux安装python3环境 参考链接:https://www.cnblogs.com/zzqit/p/10087680.html 2.安装uwsgi pip3 install uwsgi l ...