https://github.com/zle1992/Reinforcement_Learning_Game

DeepQNetwork.py
import numpy as np
import tensorflow as tf
from abc import ABCMeta, abstractmethod

np.random.seed(1)
tf.set_random_seed(1)

import logging  # set up the logging module
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
# with the level set to DEBUG, all log messages below are printed to the console

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True  # let GPU memory grow on demand
session = tf.Session(config=tfconfig)


class DeepQNetwork(object):
    __metaclass__ = ABCMeta
    """Abstract DQN base class: subclasses implement _build_q_net."""

    def __init__(self,
                 n_actions,
                 n_features,
                 learning_rate,
                 reward_decay,
                 e_greedy,
                 replace_target_iter,
                 memory_size,
                 e_greedy_increment,
                 output_graph,
                 log_dir,
                 ):
        super(DeepQNetwork, self).__init__()
        self.n_actions = n_actions
        self.n_features = n_features
        self.learning_rate = learning_rate
        self.gamma = reward_decay
        self.epsilon_max = e_greedy
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.epsilon_increment = e_greedy_increment
        self.output_graph = output_graph
        self.lr = learning_rate
        # total learning steps
        self.learn_step_counter = 0
        self.log_dir = log_dir

        # placeholders for state, next state, reward and action
        self.s = tf.placeholder(tf.float32, [None] + self.n_features, name='s')
        self.s_next = tf.placeholder(tf.float32, [None] + self.n_features, name='s_next')
        self.r = tf.placeholder(tf.float32, [None, ], name='r')
        self.a = tf.placeholder(tf.int32, [None, ], name='a')

        # online (eval) network and frozen target network share the same architecture
        self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
        self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)

        with tf.variable_scope('q_target'):
            self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')  # shape=(None,)
        with tf.variable_scope('q_eval'):
            a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
            self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)  # shape=(None,)
        with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
        with tf.variable_scope('train'):
            self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

        # ops that copy the eval-network weights into the target network
        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
        with tf.variable_scope('hard_replacement'):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()
        if self.output_graph:
            tf.summary.FileWriter(self.log_dir, self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        self.cost_his = []

    @abstractmethod
    def _build_q_net(self, x, scope, trainable):
        raise NotImplementedError

    def learn(self, data):
        # periodically copy eval-net parameters into the target net
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)
            print('\ntarget_params_replaced\n')

        batch_memory_s = data['s']
        batch_memory_a = data['a']
        batch_memory_r = data['r']
        batch_memory_s_ = data['s_']

        _, cost = self.sess.run(
            [self._train_op, self.loss],
            feed_dict={
                self.s: batch_memory_s,
                self.a: batch_memory_a,
                self.r: batch_memory_r,
                self.s_next: batch_memory_s_,
            })
        self.cost_his.append(cost)

        # increase epsilon (a no-op when e_greedy_increment is None)
        if self.epsilon_increment is not None:
            self.epsilon_max = min(self.epsilon_max + self.epsilon_increment, 1.0)
        self.learn_step_counter += 1

    def choose_action(self, s):
        s = s[np.newaxis, :]
        aa = np.random.uniform()
        # print("epsilon_max", self.epsilon_max)
        if aa < self.epsilon_max:
            # exploit: pick the action with the highest predicted Q-value
            action_value = self.sess.run(self.q_eval, feed_dict={self.s: s})
            action = np.argmax(action_value)
        else:
            # explore: pick a random action
            action = np.random.randint(0, self.n_actions)
        return action
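
In graph terms, the base class wires up the standard DQN objective: the target network supplies the TD target y = r + gamma * max_a' Q_target(s', a'), the eval network supplies Q_eval(s, a) for the action actually taken (selected with gather_nd), and the loss is the mean squared TD error, minimized with RMSProp. Every replace_target_iter learning steps the eval weights are hard-copied into the target network. Everything architecture-specific is deferred to subclasses through the abstract _build_q_net(x, scope, trainable) hook, so the same base class can drive different networks (an MLP for CartPole, a convolutional net for image observations, and so on).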
Memory.py
import numpy as np

np.random.seed(1)


class Memory(object):
    """A simple circular replay buffer for (s, a, r, s_) transitions."""

    def __init__(self,
                 n_actions,
                 n_features,
                 memory_size):
        super(Memory, self).__init__()
        self.memory_size = memory_size
        self.cnt = 0

        self.s = np.zeros([memory_size] + n_features)
        self.a = np.zeros([memory_size, ])
        self.r = np.zeros([memory_size, ])
        self.s_ = np.zeros([memory_size] + n_features)

    def store_transition(self, s, a, r, s_):
        # logging.info('store_transition')
        # overwrite the oldest entry once the buffer is full
        index = self.cnt % self.memory_size
        self.s[index] = s
        self.a[index] = a
        self.r[index] = r
        self.s_[index] = s_
        self.cnt += 1

    def sample(self, n):
        # logging.info('sample')
        # assert self.cnt >= self.memory_size, 'Memory has not been filled yet'
        # sample n transitions (with replacement) from those stored so far
        N = min(self.memory_size, self.cnt)
        indices = np.random.choice(N, size=n)
        d = {}
        d['s'] = self.s[indices]
        d['s_'] = self.s_[indices]
        d['r'] = self.r[indices]
        d['a'] = self.a[indices]
        return d
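
A quick way to see the buffer's contract, as a minimal sketch (the 4-feature state, the 2 actions and the random transitions below are illustrative placeholders, not part of the repo):

import numpy as np
from Memory import Memory

mem = Memory(n_actions=2, n_features=[4], memory_size=2000)

# store a handful of fake (s, a, r, s_) transitions
for _ in range(10):
    s = np.random.randn(4)
    s_ = np.random.randn(4)
    mem.store_transition(s, np.random.randint(2), 1.0, s_)

batch = mem.sample(5)
print(batch['s'].shape, batch['a'].shape)  # -> (5, 4) (5,)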

Main script

import gym
import numpy as np
import tensorflow as tf

from Memory import Memory
from DeepQNetwork import DeepQNetwork

np.random.seed(1)
tf.set_random_seed(1)

import logging  # set up the logging module
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
# with the level set to DEBUG, all log messages below are printed to the console

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig)


class DeepQNetwork4CartPole(DeepQNetwork):
    """DQN with a small fully connected Q-network for CartPole."""

    def __init__(self, **kwargs):
        super(DeepQNetwork4CartPole, self).__init__(**kwargs)

    def _build_q_net(self, x, scope, trainable):
        w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)

        with tf.variable_scope(scope):
            e1 = tf.layers.dense(inputs=x,
                                 units=32,
                                 bias_initializer=b_initializer,
                                 kernel_initializer=w_initializer,
                                 activation=tf.nn.relu,
                                 trainable=trainable)
            q = tf.layers.dense(inputs=e1,
                                units=self.n_actions,
                                bias_initializer=b_initializer,
                                kernel_initializer=w_initializer,
                                activation=tf.nn.sigmoid,
                                trainable=trainable)
        return q


batch_size = 64
memory_size = 2000

#env = gym.make('Breakout-v0')  # discrete action space
env = gym.make('CartPole-v0')   # discrete action space
n_features = list(env.observation_space.shape)
n_actions = env.action_space.n
env = env.unwrapped


def run():
    RL = DeepQNetwork4CartPole(
        n_actions=n_actions,
        n_features=n_features,
        learning_rate=0.01,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=200,
        memory_size=memory_size,
        e_greedy_increment=None,
        output_graph=True,
        log_dir='log/DeepQNetwork4CartPole/',
        )

    memory = Memory(n_actions, n_features, memory_size=memory_size)

    step = 0
    ep_r = 0
    for episode in range(2000):
        # initial observation
        observation = env.reset()

        while True:
            # RL chooses an action based on the current observation
            action = RL.choose_action(observation)

            # RL takes the action and gets the next observation and reward
            observation_, reward, done, info = env.step(action)

            # reward shaping: the smaller theta and the closer to center, the better
            x, x_dot, theta, theta_dot = observation_
            r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
            r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
            reward = r1 + r2

            memory.store_transition(observation, action, reward, observation_)

            if (step > 200) and (step % 5 == 0):
                data = memory.sample(batch_size)
                RL.learn(data)
                #print('step:%d----reward:%f---action:%d' % (step, reward, action))

            # swap observation
            observation = observation_
            ep_r += reward

            if episode > 700:
                env.render()  # render on the screen

            # break the while loop at the end of this episode
            if done:
                print('episode: ', episode,
                      'ep_r: ', round(ep_r, 2),
                      ' epsilon: ', round(RL.epsilon_max, 2))
                ep_r = 0
                break
            step += 1

    # end of game
    print('game over')
    env.close()


def main():
    run()


if __name__ == '__main__':
    main()
    #run2()
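
One detail worth noting is the reward shaping inside the loop: the environment's raw reward (a constant +1 per surviving step) is replaced by r1 + r2, which grades how centered the cart is and how upright the pole is. As a rough sanity check with CartPole's thresholds (x_threshold = 2.4, theta_threshold_radians of about 0.21): with the cart centered and the pole vertical (x = 0, theta = 0), r1 = 1 - 0.8 = 0.2 and r2 = 1 - 0.5 = 0.5, so the shaped reward is 0.7; as either quantity approaches its failure threshold, r1 falls toward -0.8 and r2 toward -0.5, so the reward drops toward -1.3. This graded signal makes it much easier for the network to distinguish good states from bad ones than the flat +1 would.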
