DeepNetwork---tensorflow实现
https://github.com/zle1992/Reinforcement_Learning_Game

DeepQNetwork.py
import numpy as np
import tensorflow as tf
from abc import ABCMeta, abstractmethod
np.random.seed(1)
tf.set_random_seed(1) import logging # 引入logging模块
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') # logging.basicConfig函数对日志的输出格式及方式做相关配置
# 由于日志基本配置中级别设置为DEBUG,所以一下打印信息将会全部显示在控制台上 tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig) class DeepQNetwork(object):
__metaclass__ = ABCMeta
"""docstring for DeepQNetwork"""
def __init__(self,
n_actions,
n_features,
learning_rate,
reward_decay,
e_greedy,
replace_target_iter,
memory_size,
e_greedy_increment,
output_graph,
log_dir,
):
super(DeepQNetwork, self).__init__() self.n_actions = n_actions
self.n_features = n_features
self.learning_rate=learning_rate
self.gamma=reward_decay
self.epsilon_max=e_greedy
self.replace_target_iter=replace_target_iter
self.memory_size=memory_size
self.epsilon_increment=e_greedy_increment
self.output_graph=output_graph
self.lr =learning_rate
# total learning step
self.learn_step_counter = 0
self.log_dir = log_dir self.s = tf.placeholder(tf.float32,[None]+self.n_features,name='s')
self.s_next = tf.placeholder(tf.float32,[None]+self.n_features,name='s_next') self.r = tf.placeholder(tf.float32,[None,],name='r')
self.a = tf.placeholder(tf.int32,[None,],name='a') self.q_eval = self._build_q_net(self.s, scope='eval_net', trainable=True)
self.q_next = self._build_q_net(self.s_next, scope='target_net', trainable=False) with tf.variable_scope('q_target'):
self.q_target = self.r + self.gamma * tf.reduce_max(self.q_next, axis=1, name='Qmax_s_') # shape=(None, )
with tf.variable_scope('q_eval'):
a_indices = tf.stack([tf.range(tf.shape(self.a)[0], dtype=tf.int32), self.a], axis=1)
self.q_eval_wrt_a = tf.gather_nd(params=self.q_eval, indices=a_indices) # shape=(None, )
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(tf.squared_difference(self.q_target, self.q_eval_wrt_a, name='TD_error'))
with tf.variable_scope('train'):
self._train_op = tf.train.RMSPropOptimizer(self.lr).minimize(self.loss) t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net') with tf.variable_scope("hard_replacement"):
self.target_replace_op=[tf.assign(t,e) for t,e in zip(t_params,e_params)] self.sess = tf.Session()
if self.output_graph:
tf.summary.FileWriter(self.log_dir,self.sess.graph) self.sess.run(tf.global_variables_initializer()) self.cost_his =[] @abstractmethod
def _build_q_net(self,x,scope,trainable):
raise NotImplementedError def learn(self,data): # check to replace target parameters
if self.learn_step_counter % self.replace_target_iter == 0:
self.sess.run(self.target_replace_op)
print('\ntarget_params_replaced\n') batch_memory_s = data['s'],
batch_memory_a = data['a'],
batch_memory_r = data['r'],
batch_memory_s_ = data['s_'],
_, cost = self.sess.run(
[self._train_op, self.loss],
feed_dict={
self.s: batch_memory_s,
self.a: batch_memory_a,
self.r: batch_memory_r,
self.s_next: batch_memory_s_,
})
self.cost_his.append(cost) # increasing epsilon
self.epsilon_max = self.epsilon_max + self.epsilon_increment if self.epsilon_max < self.epsilon_max else self.epsilon_max
self.learn_step_counter += 1 def choose_action(self,s):
s = s[np.newaxis,:]
aa = np.random.uniform()
#print("epsilon_max",self.epsilon_max)
if aa < self.epsilon_max:
action_value = self.sess.run(self.q_eval,feed_dict={self.s:s})
action = np.argmax(action_value)
else:
action = np.random.randint(0,self.n_actions)
return action
Memory.py
import numpy as np
np.random.seed(1)
class Memory(object):
"""docstring for Memory"""
def __init__(self,
n_actions,
n_features,
memory_size):
super(Memory, self).__init__()
self.memory_size = memory_size
self.cnt =0 self.s = np.zeros([memory_size]+n_features)
self.a = np.zeros([memory_size,])
self.r = np.zeros([memory_size,])
self.s_ = np.zeros([memory_size]+n_features) def store_transition(self,s, a, r, s_):
#logging.info('store_transition')
index = self.cnt % self.memory_size
self.s[index] = s
self.a[index] = a
self.r[index] = r
self.s_[index] =s_
self.cnt+=1 def sample(self,n):
#logging.info('sample')
#assert self.cnt>=self.memory_size,'Memory has not been fulfilled'
N = min(self.memory_size,self.cnt)
indices = np.random.choice(N,size=n)
d ={}
d['s'] = self.s[indices][0]
d['s_'] = self.s_[indices][0]
d['r'] = self.r[indices][0]
d['a'] = self.a[indices][0]
return d
主函数
import gym
import numpy as np
import tensorflow as tf from Memory import Memory
from DeepQNetwork import DeepQNetwork np.random.seed(1)
tf.set_random_seed(1) import logging # 引入logging模块
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') # logging.basicConfig函数对日志的输出格式及方式做相关配置
# 由于日志基本配置中级别设置为DEBUG,所以一下打印信息将会全部显示在控制台上 tfconfig = tf.ConfigProto()
tfconfig.gpu_options.allow_growth = True
session = tf.Session(config=tfconfig) class DeepQNetwork4CartPole(DeepQNetwork):
"""docstring for ClassName"""
def __init__(self, **kwargs):
super(DeepQNetwork4CartPole, self).__init__(**kwargs) def _build_q_net(self,x,scope,trainable):
w_initializer, b_initializer = tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1) with tf.variable_scope(scope):
e1 = tf.layers.dense(inputs=x,
units=32,
bias_initializer = b_initializer,
kernel_initializer=w_initializer,
activation = tf.nn.relu,
trainable=trainable)
q = tf.layers.dense(inputs=e1,
units=self.n_actions,
bias_initializer = b_initializer,
kernel_initializer=w_initializer,
activation = tf.nn.sigmoid,
trainable=trainable) return q batch_size = 64 memory_size =2000
#env = gym.make('Breakout-v0') #离散
env = gym.make('CartPole-v0') #离散 n_features= list(env.observation_space.shape)
n_actions= env.action_space.n env = env.unwrapped def run(): RL = DeepQNetwork4CartPole(
n_actions=n_actions,
n_features=n_features,
learning_rate=0.01,
reward_decay=0.9,
e_greedy=0.9,
replace_target_iter=200,
memory_size=memory_size,
e_greedy_increment=None,
output_graph=True,
log_dir = 'log/DeepQNetwork4CartPole/',
) memory = Memory(n_actions,n_features,memory_size=memory_size) step = 0
ep_r = 0
for episode in range(2000):
# initial observation
observation = env.reset() while True: # RL choose action based on observation
action = RL.choose_action(observation)
# logging.debug('action')
# print(action)
# RL take action and get_collectiot next observation and reward
observation_, reward, done, info=env.step(action) # take a random action # the smaller theta and closer to center the better
x, x_dot, theta, theta_dot = observation_
r1 = (env.x_threshold - abs(x))/env.x_threshold - 0.8
r2 = (env.theta_threshold_radians - abs(theta))/env.theta_threshold_radians - 0.5
reward = r1 + r2 memory.store_transition(observation, action, reward, observation_) if (step > 200) and (step % 5 == 0): data = memory.sample(batch_size)
RL.learn(data)
#print('step:%d----reward:%f---action:%d'%(step,reward,action))
# swap observation
observation = observation_
ep_r += reward
# break while loop when end of this episode
if(episode>700):
env.render() # render on the screen
if done:
print('episode: ', episode,
'ep_r: ', round(ep_r, 2),
' epsilon: ', round(RL.epsilon_max, 2))
ep_r = 0 break
step += 1 # end of game
print('game over')
env.destroy() def main(): run() if __name__ == '__main__':
main()
#run2()
DeepNetwork---tensorflow实现的更多相关文章
- Tensorflow 官方版教程中文版
2015年11月9日,Google发布人工智能系统TensorFlow并宣布开源,同日,极客学院组织在线TensorFlow中文文档翻译.一个月后,30章文档全部翻译校对完成,上线并提供电子书下载,该 ...
- tensorflow学习笔记二:入门基础
TensorFlow用张量这种数据结构来表示所有的数据.用一阶张量来表示向量,如:v = [1.2, 2.3, 3.5] ,如二阶张量表示矩阵,如:m = [[1, 2, 3], [4, 5, 6], ...
- 用Tensorflow让神经网络自动创造音乐
#————————————————————————本文禁止转载,禁止用于各类讲座及ppt中,违者必究————————————————————————# 前几天看到一个有意思的分享,大意是讲如何用Ten ...
- tensorflow 一些好的blog链接和tensorflow gpu版本安装
pading :SAME,VALID 区别 http://blog.csdn.net/mao_xiao_feng/article/details/53444333 tensorflow实现的各种算法 ...
- tensorflow中的基本概念
本文是在阅读官方文档后的一些个人理解. 官方文档地址:https://www.tensorflow.org/versions/r0.12/get_started/basic_usage.html#ba ...
- kubernetes&tensorflow
谷歌内部--Borg Google Brain跑在数十万台机器上 谷歌电商商品分类深度学习模型跑在1000+台机器上 谷歌外部--Kubernetes(https://github.com/kuber ...
- tensorflow学习
tensorflow安装时遇到gcc: error trying to exec 'as': execvp: No such file or directory. 截止到2016年11月13号,源码编 ...
- 【转】TensorFlow练习20: 使用深度学习破解字符验证码
验证码是根据随机字符生成一幅图片,然后在图片中加入干扰象素,用户必须手动填入,防止有人利用机器人自动批量注册.灌水.发垃圾广告等等 . 验证码的作用是验证用户是真人还是机器人:设计理念是对人友好,对机 ...
- 【转】机器学习教程 十四-利用tensorflow做手写数字识别
模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...
- 【转】Ubuntu 16.04安装配置TensorFlow GPU版本
之前摸爬滚打总是各种坑,今天参考这篇文章终于解决了,甚是鸡冻\(≧▽≦)/,电脑不知道怎么的,安装不了16.04,就安装15.10再升级到16.04 requirements: Ubuntu 16.0 ...
随机推荐
- Java线程的状态分析
线程状态 1.新建状态(New):新创建了一个线程对象. 2.就绪状态(Runnable):线程对象创建后,其他线程调用了该对象的start()方法.该状态的线程位于“可运行线程池”中,变得可运行,只 ...
- 20165336 实验一 Java开发环境的熟悉
20165336 实验一 Java开发环境的熟悉 一.实验报告封面 课程:Java程序设计 班级:1653班 姓名:康志强 学号:20165336 指导教师:娄嘉鹏 实验日期:2018年4月2日 实验 ...
- JdbcTemplate中向in语句传参
spring jdbc包提供了JdbcTemplate和它的两个兄弟SimpleJdbcTemplate和NamedParameterJdbcTemplate,我们先从JdbcTemplate入手, ...
- wordpress站内搜索结果页URL伪静态如何操作
站内搜索页面的优化一直被很多人忽略,只是按cms自带的默认设置,其实搜索结果页是一块宝藏,url重写是提升的重要一步.之前我们写过帝国CMS搜索页伪静态实现方法,那么,wordpress站内搜索结果页 ...
- OC常用控件封装
#import <Foundation/Foundation.h> #import <UIKit/UIKit.h> @interface CreateUI : NSObject ...
- docker+mysql基本搭建过程
查询镜像 [root@bms-e4e3 ~]# docker search mysql INDEX NAME DESCRIPTION STARS OFFICIAL AUTOMATED docker.i ...
- OpenFace的一些了解
1.OpenFace内4个样例代码 配置学习了两个 其一: Ubantu 基本命令 Docker 安装方式.发布网站方式.查看验证安装结果命令 Openface 基本demo 实现方式.和基本原理 其 ...
- C# install-package:"xx"已拥有为“xxx”定义的依赖项
可能 nuget自身的版本落后于适配程序包的版本 Visual Studio 2013 更新 NuGet 包管理器 Ø 前言 使用 Visual Studio 中的 NuGet 包管理器下载程序时, ...
- Windows 10正式版的历史版本
1.Windows 10 1507 初版Windows 10,代号TH1,版本号10240,发布于2015年7月. 2015年7月29日,微软正式发布了Windows 10操作系统.Windows 1 ...
- VGA线 1080P之伤 <中秋节篇>
故障:通过一台笔记本,在一台高清的电视机上使用VGA线进行视频传输,分辩率最高只能显示1600*1200,为什么不是1080P? 哎,我也很郁闷,查了相关的资料~电视机是最高支持1080P分辩率的,笔 ...