从这里开始换一个游戏来演示：CartPole 游戏。

Deep Q Network

示例代码

 import sys
import gym
import pylab
import random
import numpy as np
from collections import deque
from keras.layers import Dense
from keras.optimizers import Adam
from keras.models import Sequential EPISODES = 300 # DQN Agent for the Cartpole
class DQNAgent:
    """DQN agent for the CartPole environment.

    The Q function is approximated with a small fully connected network.
    Training uses an experience-replay memory and a separate (fixed) target
    network, which is synchronised with the online network by
    `update_target_model`.
    """

    def __init__(self, state_size, action_size):
        # Set to True to watch CartPole while it learns (slows training down).
        self.render = True
        self.load_model = False

        # Size of the observation vector and of the discrete action space.
        self.state_size = state_size
        self.action_size = action_size

        # DQN hyper-parameters.
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.epsilon = 1.0
        self.epsilon_decay = 0.999
        self.epsilon_min = 0.01
        self.batch_size = 64
        # No training happens until the replay memory holds this many samples.
        self.train_start = 1000

        # Replay memory; once 2000 samples are stored the oldest are dropped.
        self.memory = deque(maxlen=2000)

        # Online network (trained every step) and target network (used to
        # compute the bootstrap targets).
        self.model = self.build_model()
        self.target_model = self.build_model()

        # Restore saved weights BEFORE syncing, so the target network starts
        # from the loaded weights instead of a random initialisation.
        if self.load_model:
            self.model.load_weights("./save_model/cartpole_dqn.h5")
        self.update_target_model()

    def build_model(self):
        """Build and compile the Q network.

        Input: a state vector of length `state_size`.
        Output: one linear Q value per action (`action_size` outputs).
        """
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu',
                        kernel_initializer='he_uniform'))
        model.add(Dense(24, activation='relu',
                        kernel_initializer='he_uniform'))
        model.add(Dense(self.action_size, activation='linear',
                        kernel_initializer='he_uniform'))
        model.summary()
        # Recent Keras only accepts `learning_rate`; Keras < 2.3 only `lr`.
        try:
            optimizer = Adam(learning_rate=self.learning_rate)
        except TypeError:
            optimizer = Adam(lr=self.learning_rate)
        model.compile(loss='mse', optimizer=optimizer)
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.set_weights(self.model.get_weights())

    def get_action(self, state):
        """Choose an action with an epsilon-greedy policy.

        With probability `epsilon` a uniformly random action is returned.
        Otherwise the action with the largest Q value predicted by the
        online network is returned. `state` is expected to be a batch of one
        state, shaped (1, state_size) as produced by the main loop's reshape.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        q_value = self.model.predict(state)
        return np.argmax(q_value[0])

    def append_sample(self, state, action, reward, next_state, done):
        """Store <s, a, r, s', done> in replay memory and decay epsilon."""
        self.memory.append((state, action, reward, next_state, done))
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def train_model(self):
        """Fit the online network on one random minibatch from replay memory.

        This is a no-op until at least `train_start` samples are stored.
        For each sampled transition the target of the taken action is:
        - `r` if the transition ended the episode;
        - `r + discount_factor * max_a' Q_target(s', a')` otherwise
          (off-policy Q-learning update).
        The other action entries keep the online network's own prediction,
        so they contribute zero error to the MSE loss.
        """
        if len(self.memory) < self.train_start:
            return
        batch_size = min(self.batch_size, len(self.memory))
        mini_batch = random.sample(self.memory, batch_size)

        # Each sample is (state, action, reward, next_state, done); states
        # are stored as (1, state_size) arrays by the main loop.
        update_input = np.zeros((batch_size, self.state_size))
        update_target = np.zeros((batch_size, self.state_size))
        action, reward, done = [], [], []
        for i, (s, a, r, s_next, d) in enumerate(mini_batch):
            update_input[i] = s
            action.append(a)
            reward.append(r)
            update_target[i] = s_next
            done.append(d)

        target = self.model.predict(update_input)              # (batch, actions)
        target_val = self.target_model.predict(update_target)  # (batch, actions)

        # Iterate over the samples actually drawn (local batch_size), which
        # can be smaller than self.batch_size when train_start is small.
        for i in range(batch_size):
            if done[i]:
                target[i][action[i]] = reward[i]
            else:
                target[i][action[i]] = reward[i] + self.discount_factor * (
                    np.amax(target_val[i]))

        self.model.fit(update_input, target, batch_size=self.batch_size,
                       epochs=1, verbose=0)


def reset_env(env):
    """Reset `env` and return only the initial observation.

    gym >= 0.26 returns `(observation, info)` from `reset()`; older versions
    return the observation alone. Assumes a non-tuple observation space
    (true for CartPole).
    """
    result = env.reset()
    if isinstance(result, tuple):
        return result[0]
    return result


def step_env(env, action):
    """Step `env` and return `(next_state, reward, done, info)`.

    gym >= 0.26 returns `(obs, reward, terminated, truncated, info)`; this is
    folded into the old 4-tuple form with `done = terminated or truncated`.
    Older gym results are returned unchanged.
    """
    result = env.step(action)
    if len(result) == 5:
        next_state, reward, terminated, truncated, info = result
        return next_state, reward, terminated or truncated, info
    return result


if __name__ == "__main__":
    import os

    # In CartPole-v1 the maximum length of an episode is 500 steps.
    env = gym.make('CartPole-v1')
    # Size of the state vector and number of actions from the environment.
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    agent = DQNAgent(state_size, action_size)

    # savefig / save_weights fail if these output directories do not exist.
    os.makedirs("./save_graph", exist_ok=True)
    os.makedirs("./save_model", exist_ok=True)

    scores, episodes = [], []

    for e in range(EPISODES):
        done = False
        score = 0
        state = reset_env(env)
        state = np.reshape(state, [1, state_size])

        while not done:
            if agent.render:
                env.render()

            # Pick an action for the current state and advance one step.
            action = agent.get_action(state)
            next_state, reward, done, info = step_env(env, action)
            next_state = np.reshape(next_state, [1, state_size])
            # An action that ends the episode before the 500-step limit is
            # penalised with -100.
            reward = reward if not done or score == 499 else -100

            # Store <s, a, r, s'> in replay memory and train every step.
            agent.append_sample(state, action, reward, next_state, done)
            agent.train_model()
            score += reward
            state = next_state

            if done:
                # Sync the target network once per episode.
                agent.update_target_model()
                # Undo the -100 penalty applied on the final step before
                # recording and plotting the score.
                score = score if score == 500 else score + 100
                scores.append(score)
                episodes.append(e)
                pylab.plot(episodes, scores, 'b')
                pylab.savefig("./save_graph/cartpole_dqn.png")
                print("episode:", e, " score:", score, " memory length:",
                      len(agent.memory), " epsilon:", agent.epsilon)

                # Stop once the mean score of the last 10 episodes exceeds
                # 490 -- save first so the solved weights are not lost.
                if np.mean(scores[-min(10, len(scores)):]) > 490:
                    agent.model.save_weights("./save_model/cartpole_dqn.h5")
                    sys.exit()

        # Periodic checkpoint.
        if e % 50 == 0:
            agent.model.save_weights("./save_model/cartpole_dqn.h5")

深度增强学习--Deep Q Network的更多相关文章

  1. AlphaGo的前世今生(一)Deep Q Network and Game Search Tree:Road to AI Revolution

    这一个专题将会是有关AlphaGo的前世今生以及其带来的AI革命,总共分成三节.本人水平有限,如有错误还望指正.如需转载,须征得本人同意. Road to AI Revolution(通往AI革命之路 ...

  2. 强化学习系列之:Deep Q Network (DQN)

    文章目录 [隐藏] 1. 强化学习和深度学习结合 2. Deep Q Network (DQN) 算法 3. 后续发展 3.1 Double DQN 3.2 Prioritized Replay 3. ...

  3. Deep Q Network(DQN)原理解析

    1. 前言 在前面的章节中我们介绍了时序差分算法(TD)和Q-Learning,当状态和动作空间是离散且维数不高时可使用Q-Table储存每个状态动作对的Q值,而当状态和动作空间是高维连续时,使用Q- ...

  4. 【转】【强化学习】Deep Q Network(DQN)算法详解

    原文地址:https://blog.csdn.net/qq_30615903/article/details/80744083 DQN(Deep Q-Learning)是将深度学习deeplearni ...

  5. 深度增强学习--DDPG

    DDPG DDPG介绍2 ddpg输出的不是行为的概率, 而是具体的行为, 用于连续动作 (continuous action) 的预测 公式推导 推导 代码实现的gym的pendulum游戏,这个游 ...

  6. 深度增强学习--A3C

    A3C 它会创建多个并行的环境, 让多个拥有副结构的 agent 同时在这些并行环境上更新主结构中的参数. 并行中的 agent 们互不干扰, 而主结构的参数更新受到副结构提交更新的不连续性干扰, 所 ...

  7. 深度增强学习--Actor Critic

    Actor Critic value-based和policy-based的结合 实例代码 import sys import gym import pylab import numpy as np ...

  8. 深度增强学习--Policy Gradient

    前面都是value based的方法,现在看一种直接预测动作的方法 Policy Based Policy Gradient 一个介绍 karpathy的博客 一个推导 下面的例子实现的REINFOR ...

  9. 深度增强学习--DPPO

    PPO DPPO介绍 PPO实现 代码DPPO

随机推荐

  1. [Leetcode Week10]Minimum Time Difference

    Minimum Time Difference 题解 原创文章,拒绝转载 题目来源:https://leetcode.com/problems/minimum-time-difference/desc ...

  2. Swift, Playgrounds, and XCPlayground

    http://www.codeschool.com/blog/2014/12/12/swift-playgrounds-xcplayground/ Swift, Playgrounds, and XC ...

  3. Linux-进程间通信(二): FIFO

    1. FIFO: FIFO也被称为命名管道,因其通过路径关系绑定,可以用于任意进程间通信,而普通无名管道只能用于有共同祖先的进程间直接通信; 命名管道也是半双工的,open管道的时候不要以读写方式打开, ...

  4. 解决Eclipse明明有错误,却不能显示错误红叉的方法,eclipse不能显示错误

    出现这种情况的原因是java文件的错误太多,eclipse停止编译.解决方法如下 1.勾选自动编译功能 2.clean工程 3.取消“abort build when build path erro ...

  5. 【SQL】宿主语言接口

    一般情况下,SQL语句是嵌套在宿主语言(如C语言)中的.有两种嵌套方式: 1.调用层接口(CLI):提供一些库,库中的函数和方法实现SQL的调用 2.直接嵌套SQL:在代码中嵌套SQL语句,提交给预处 ...

  6. Spring boot 文件路径读取异常

    在开发代码中,有一段需要获取resources目录下的一个配置文件(这里写作test.xml). 这段代码在ide中没有任何问题,但是一打成jar包发布到线上,这段代码就会报找不到对应文件的错误. 按 ...

  7. MATLAB多项式及多项式拟合

    多项式均表示为数组形式,数组元素为多项式降幂系数 1.      polyval函数 求多项式在某一点或某几个点的值. p = [1,1,1];%x^2+x+1 x = [-1,0,1];y = po ...

  8. 非负权值有向图上的单源最短路径算法之Dijkstra算法

    问题的提法是:给定一个没有负权值的有向图和其中一个点src作为源点(source),求从点src到其余各点的最短路径及路径长度.求解该问题的算法一般为Dijkstra算法. 假设图顶点个数为n,则针对 ...

  9. logging 日志两种使用方法(转)

    下面我们使用logging的代码来说明: 使用basicConfig()函数对logging进行简单的配置: import logging; # 使用basicConfig()函数,可选参数有f ...

  10. 洛谷 P1577 切绳子【二分答案】

    题目描述 有N条绳子,它们的长度分别为Li.如果从它们中切割出K条长度相同的 绳子,这K条绳子每条最长能有多长?答案保留到小数点后2位. 输入输出格式 输入格式: 第一行两个整数N和K,接下来N行,描 ...