https://github.com/zle1992/Reinforcement_Learning_Game

DeepQNetwork.py
import logging
from abc import ABCMeta, abstractmethod

import numpy as np
import tensorflow as tf

np.random.seed(1)
tf.set_random_seed(1)

# logging.basicConfig configures the log output format; since the level is DEBUG,
# all of the messages below are printed to the console.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig)


class DeepQNetwork(object):
    """Abstract base class for a Deep Q-Network; subclasses implement _build_q_net."""
    __metaclass__ = ABCMeta

    def __init__(self,
                 n_actions,
                 n_features,
                 learning_rate,
                 reward_decay,
                 e_greedy,
                 replace_target_iter,
                 memory_size,
                 e_greedy_increment,
                 output_graph,
                 log_dir,
                 ):
        super(DeepQNetwork, self).__init__()
        self.n_actions = n_actions
        self.n_features = n_features
        self.learning_rate = learning_rate
        self.gamma = reward_decay
        self.epsilon_max = e_greedy
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.epsilon_increment = e_greedy_increment
        self.output_graph = output_graph
        self.lr = learning_rate
        # total learning step
        self.learn_step_counter = 0
        self.log_dir = log_dir

        # placeholders
        self.s = tf.placeholder(tf.float32, [None] + self.n_features, name='s')
        self.s_next = tf.placeholder(tf.float32, [None] + self.n_features, name='s_next')
        self.r = tf.placeholder(tf.float32, [None, ], name='r')
        self.a = tf.placeholder(tf.int32, [None, ], name='a')

        # evaluation network (trainable) and target network (frozen, hard-replaced periodically)
        self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
        self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False)

        with tf.variable_scope('q_target'):
            self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_')    # shape=(None, )
        with tf.variable_scope('q_eval'):
            a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
            self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices)    # shape=(None, )
        with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
        with tf.variable_scope('train'):
            self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss)

        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')
        with tf.variable_scope("hard_replacement"):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()
        if self.output_graph:
            tf.summary.FileWriter(self.log_dir, self.sess.graph)
        self.sess.run(tf.global_variables_initializer())
        self.cost_his = []

    @abstractmethod
    def _build_q_net(self, x, scope, trainable):
        raise NotImplementedError

    def learn(self, data):
        # check whether to replace target parameters
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)
            print('\ntarget_params_replaced\n')

        batch_memory_s = data['s']
        batch_memory_a = data['a']
        batch_memory_r = data['r']
        batch_memory_s_ = data['s_']

        _, cost = self.sess.run(
            [self._train_op, self.loss],
            feed_dict={
                self.s: batch_memory_s,
                self.a: batch_memory_a,
                self.r: batch_memory_r,
                self.s_next: batch_memory_s_,
            })
        self.cost_his.append(cost)

        # increase epsilon towards 1.0, but only when an increment schedule is given
        if self.epsilon_increment is not None:
            self.epsilon_max = min(self.epsilon_max + self.epsilon_increment, 1.0)
        self.learn_step_counter += 1

    def choose_action(self, s):
        s = s[np.newaxis, :]
        # epsilon-greedy: act greedily w.r.t. q_eval with probability epsilon_max
        if np.random.uniform() < self.epsilon_max:
            action_value = self.sess.run(self.q_eval, feed_dict={self.s: s})
            action = np.argmax(action_value)
        else:
            action = np.random.randint(0, self.n_actions)
        return action
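
The heart of the graph above is the TD target q_target = r + gamma * max_a' Q_target(s', a') and the gather_nd indexing that picks out Q_eval(s, a) for the actions actually taken. As a rough NumPy illustration of what those two ops compute (made-up numbers, not part of the repository code):

import numpy as np

# a hypothetical minibatch of 3 transitions with 2 actions (made-up values)
q_eval = np.array([[1.0, 2.0],
                   [0.5, 0.1],
                   [3.0, 4.0]])    # Q_eval(s, .) per sample
q_next = np.array([[0.2, 0.9],
                   [1.5, 1.0],
                   [0.0, 0.3]])    # Q_target(s', .) per sample
a = np.array([1, 0, 1])            # actions actually taken
r = np.array([1.0, -1.0, 0.5])     # rewards
gamma = 0.9

# equivalent of tf.gather_nd with the stacked [row, action] indices
q_eval_wrt_a = q_eval[np.arange(len(a)), a]    # -> [2.0, 0.5, 4.0]

# TD target: r + gamma * max over next-state Q values
q_target = r + gamma * q_next.max(axis=1)      # -> [1.81, 0.35, 0.77]

loss = np.mean((q_target - q_eval_wrt_a) ** 2)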
Memory.py
import numpy as np

np.random.seed(1)


class Memory(object):
    """Fixed-size replay buffer storing (s, a, r, s_) transitions."""

    def __init__(self,
                 n_actions,
                 n_features,
                 memory_size):
        super(Memory, self).__init__()
        self.memory_size = memory_size
        self.cnt = 0
        self.s = np.zeros([memory_size] + n_features)
        self.a = np.zeros([memory_size, ])
        self.r = np.zeros([memory_size, ])
        self.s_ = np.zeros([memory_size] + n_features)

    def store_transition(self, s, a, r, s_):
        # overwrite the oldest transition once the buffer is full
        index = self.cnt % self.memory_size
        self.s[index] = s
        self.a[index] = a
        self.r[index] = r
        self.s_[index] = s_
        self.cnt += 1

    def sample(self, n):
        # draw a minibatch of n transitions from the filled part of the buffer
        N = min(self.memory_size, self.cnt)
        indices = np.random.choice(N, size=n)
        d = {}
        d['s'] = self.s[indices]
        d['s_'] = self.s_[indices]
        d['r'] = self.r[indices]
        d['a'] = self.a[indices]
        return d
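
For reference, a minimal standalone use of this buffer with made-up CartPole-like shapes (2 actions, 4 state features); the dict returned by sample() lines up with the feed_dict built in DeepQNetwork.learn:

import numpy as np
from Memory import Memory

memory = Memory(n_actions=2, n_features=[4], memory_size=2000)

# store a few fake transitions (random values, for illustration only)
for _ in range(10):
    s = np.random.randn(4)
    a = np.random.randint(2)
    r = np.random.rand()
    s_ = np.random.randn(4)
    memory.store_transition(s, a, r, s_)

batch = memory.sample(5)
print(batch['s'].shape, batch['a'].shape)   # (5, 4) (5,)

Note that np.random.choice samples with replacement here, so a minibatch can contain the same transition more than once.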

Main function

import logging

import gym
import numpy as np
import tensorflow as tf

from Memory import Memory
from DeepQNetwork import DeepQNetwork

np.random.seed(1)
tf.set_random_seed(1)

# logging.basicConfig configures the log output format; since the level is DEBUG,
# all of the messages below are printed to the console.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')

tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig)


class DeepQNetwork4CartPole(DeepQNetwork):
    """Concrete DQN for CartPole: a two-layer fully connected Q-network."""

    def __init__(self, **kwargs):
        super(DeepQNetwork4CartPole, self).__init__(**kwargs)

    def _build_q_net(self, x, scope, trainable):
        w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)
        with tf.variable_scope(scope):
            e1 = tf.layers.dense(inputs=x,
                                 units=32,
                                 bias_initializer=b_initializer,
                                 kernel_initializer=w_initializer,
                                 activation=tf.nn.relu,
                                 trainable=trainable)
            q = tf.layers.dense(inputs=e1,
                                units=self.n_actions,
                                bias_initializer=b_initializer,
                                kernel_initializer=w_initializer,
                                activation=tf.nn.sigmoid,
                                trainable=trainable)
        return q


batch_size = 64
memory_size = 2000

#env = gym.make('Breakout-v0')  # discrete action space
env = gym.make('CartPole-v0')   # discrete action space
n_features = list(env.observation_space.shape)
n_actions = env.action_space.n
env = env.unwrapped


def run():
    RL = DeepQNetwork4CartPole(
        n_actions=n_actions,
        n_features=n_features,
        learning_rate=0.01,
        reward_decay=0.9,
        e_greedy=0.9,
        replace_target_iter=200,
        memory_size=memory_size,
        e_greedy_increment=None,
        output_graph=True,
        log_dir='log/DeepQNetwork4CartPole/',
    )
    memory = Memory(n_actions, n_features, memory_size=memory_size)

    step = 0
    ep_r = 0
    for episode in range(2000):
        # initial observation
        observation = env.reset()

        while True:
            # RL chooses an action based on the current observation
            action = RL.choose_action(observation)

            # RL takes the action and gets the next observation and reward
            observation_, reward, done, info = env.step(action)

            # reshape the reward: the smaller theta and the closer to the center, the better
            x, x_dot, theta, theta_dot = observation_
            r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
            r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
            reward = r1 + r2

            memory.store_transition(observation, action, reward, observation_)

            if (step > 200) and (step % 5 == 0):
                data = memory.sample(batch_size)
                RL.learn(data)

            # swap observation
            observation = observation_
            ep_r += reward

            if episode > 700:
                env.render()    # render on the screen

            # break the while loop at the end of this episode
            if done:
                print('episode: ', episode,
                      'ep_r: ', round(ep_r, 2),
                      ' epsilon: ', round(RL.epsilon_max, 2))
                ep_r = 0
                break
            step += 1

    # end of game
    print('game over')
    env.close()


def main():
    run()


if __name__ == '__main__':
    main()
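
The shaped reward in run() replaces CartPole's constant +1 per step with a signal that is largest when the cart is centred and the pole upright. A quick sanity check of the two terms, using CartPole's default thresholds (x_threshold = 2.4, theta_threshold_radians ≈ 0.21); the helper below is only for illustration and is not part of the repository:

def shaped_reward(x, theta, x_threshold=2.4, theta_threshold=0.2094):
    # same shaping as in run(): both terms peak for a centred cart and an upright pole
    r1 = (x_threshold - abs(x)) / x_threshold - 0.8
    r2 = (theta_threshold - abs(theta)) / theta_threshold - 0.5
    return r1 + r2

print(shaped_reward(0.0, 0.0))       # ~0.7  : best case, centred cart and upright pole
print(shaped_reward(2.4, 0.2094))    # ~-1.3 : worst case, at the failure thresholds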
