基于tensorflow的躲避障碍物的ai训练
import pygame
import random
from pygame.locals import *
import numpy as np
from collections import deque
import tensorflow as tf # http://blog.topspeedsnail.com/archives/10116
import cv2 # http://blog.topspeedsnail.com/archives/4755
score = 0
BLACK = (0, 0, 0)
WHITE = (255, 255, 255) SCREEN_SIZE = [320, 400]
BAR_SIZE = [100, 5]
BALL_SIZE = [20, 20] # 神经网络的输出
MOVE_STAY = [1, 0, 0]
MOVE_LEFT = [0, 1, 0]
MOVE_RIGHT = [0, 0, 1] class Game(object):
def __init__(self):
pygame.init()
self.clock = pygame.time.Clock()
self.screen = pygame.display.set_mode(SCREEN_SIZE)
pygame.display.set_caption('Simple Game') self.ball_pos_x = SCREEN_SIZE[0] // 2 - BALL_SIZE[0] / 2
self.ball_pos_y = SCREEN_SIZE[1] // 2 - BALL_SIZE[1] / 2 self.ball_dir_x = -1 # -1 = left 1 = right
self.ball_dir_y = -1 # -1 = up 1 = down
self.ball_pos = pygame.Rect(self.ball_pos_x, self.ball_pos_y, BALL_SIZE[0], BALL_SIZE[1]) self.bar_pos_x = SCREEN_SIZE[0] // 2 - BAR_SIZE[0] // 2
self.bar_pos = pygame.Rect(self.bar_pos_x, SCREEN_SIZE[1] - BAR_SIZE[1], BAR_SIZE[0], BAR_SIZE[1]) self.ball2_pos_x = SCREEN_SIZE[0] // 2 - BALL_SIZE[0] / 2
self.ball2_pos_y = SCREEN_SIZE[1] // 2 - BALL_SIZE[1] / 2 self.ball2_dir_x = -1 # -1 = left 1 = right
self.ball2_dir_y = -1 # -1 = up 1 = down
self.ball2_pos = pygame.Rect(self.ball2_pos_x, self.ball2_pos_y, BALL_SIZE[0], BALL_SIZE[1]) # action是MOVE_STAY、MOVE_LEFT、MOVE_RIGHT
# ai控制棒子左右移动;返回游戏界面像素数和对应的奖励。(像素->奖励->强化棒子往奖励高的方向移动)
def step(self, action): if action == MOVE_LEFT:
self.bar_pos_x = self.bar_pos_x - 2
elif action == MOVE_RIGHT:
self.bar_pos_x = self.bar_pos_x + 2
else:
pass
if self.bar_pos_x < 0:
self.bar_pos_x = 0
if self.bar_pos_x > SCREEN_SIZE[0] - BAR_SIZE[0]:
self.bar_pos_x = SCREEN_SIZE[0] - BAR_SIZE[0] self.screen.fill(BLACK)
self.bar_pos.left = self.bar_pos_x
pygame.draw.rect(self.screen, WHITE, self.bar_pos) # if random.randint(0, 2) < 1:
# self.ball_pos.left += self.ball_dir_x * random.randint(2, 10)
# else:
# self.ball2_pos.left -= self.ball2_dir_x * random.randint(2, 10)
self.ball_pos.left += self.ball_dir_x * random.randint(2, 10)
self.ball_pos.bottom += self.ball_dir_y * random.randint(2, 10)
# if self.ball_pos.left < 0:
# self.ball_pos.left = 1
#
# if self.ball_pos.left > 320:
# self.ball_pos.left = 305
pygame.draw.rect(self.screen, WHITE, self.ball_pos) if self.ball_pos.top <= 0 or self.ball_pos.bottom >= (SCREEN_SIZE[1] - BAR_SIZE[1] + 1):
self.ball_dir_y = self.ball_dir_y * -1
if self.ball_pos.left <= 0 or self.ball_pos.right >= (SCREEN_SIZE[0]):
self.ball_dir_x = self.ball_dir_x * -1 # if random.randint(0, 2) < 1:
# self.ball2_pos.left += self.ball2_dir_x * random.randint(2, 10)
# else:
# self.ball2_pos.left -= self.ball2_dir_x * random.randint(2, 10)
self.ball2_pos.left += self.ball2_dir_x * random.randint(2, 10)
self.ball2_pos.bottom += self.ball2_dir_y * random.randint(2, 10) # if self.ball2_pos.left < 0:
# self.ball2_pos.left = 1
# if self.ball2_pos.left > 320:
# self.ball2_pos.left = 305 pygame.draw.rect(self.screen, WHITE, self.ball2_pos) if self.ball2_pos.top <= 0 or self.ball2_pos.bottom >= (SCREEN_SIZE[1] - BAR_SIZE[1] + 1):
self.ball2_dir_y = self.ball2_dir_y * -1
if self.ball2_pos.left <= 0 or self.ball2_pos.right >= (SCREEN_SIZE[0]):
self.ball2_dir_x = self.ball2_dir_x * -1 reward = 0 if (self.bar_pos.top <= self.ball_pos.bottom and (self.bar_pos.left < self.ball_pos.right and self.bar_pos.right > self.ball_pos.left)) or (self.bar_pos.top <= self.ball2_pos.bottom and (self.bar_pos.left < self.ball2_pos.right and self.bar_pos.right > self.ball2_pos.left)) :
reward = - 10 # 击中惩罚
score = +1
print(score)
elif self.bar_pos.top <= self.ball_pos.bottom and (
self.bar_pos.left > self.ball_pos.right or self.bar_pos.right < self.ball_pos.left):
reward = +1 # 躲避奖励 # 获得游戏界面像素
screen_image = pygame.surfarray.array3d(pygame.display.get_surface())
pygame.display.update()
# 返回游戏界面像素和对应的奖励
return reward, screen_image # learning_rate
LEARNING_RATE = 0.99
# 更新梯度
INITIAL_EPSILON = 1.0
FINAL_EPSILON = 0.05
# 测试观测次数
EXPLORE = 500000
OBSERVE = 50000
# 存储过往经验大小
REPLAY_MEMORY = 500000 BATCH = 100 output = 3 # 输出层神经元数。代表3种操作-MOVE_STAY:[1, 0, 0] MOVE_LEFT:[0, 1, 0] MOVE_RIGHT:[0, 0, 1]
input_image = tf.placeholder("float", [None, 80, 100, 4]) # 游戏像素
action = tf.placeholder("float", [None, output]) # 操作 # 定义CNN-卷积神经网络 参考:http://blog.topspeedsnail.com/archives/10451
def convolutional_neural_network(input_image):
weights = {'w_conv1': tf.Variable(tf.zeros([8, 8, 4, 32])),
'w_conv2': tf.Variable(tf.zeros([4, 4, 32, 64])),
'w_conv3': tf.Variable(tf.zeros([3, 3, 64, 64])),
'w_fc4': tf.Variable(tf.zeros([3456, 784])),
'w_out': tf.Variable(tf.zeros([784, output]))} biases = {'b_conv1': tf.Variable(tf.zeros([32])),
'b_conv2': tf.Variable(tf.zeros([64])),
'b_conv3': tf.Variable(tf.zeros([64])),
'b_fc4': tf.Variable(tf.zeros([784])),
'b_out': tf.Variable(tf.zeros([output]))} conv1 = tf.nn.relu(
tf.nn.conv2d(input_image, weights['w_conv1'], strides=[1, 4, 4, 1], padding="VALID") + biases['b_conv1'])
conv2 = tf.nn.relu(
tf.nn.conv2d(conv1, weights['w_conv2'], strides=[1, 2, 2, 1], padding="VALID") + biases['b_conv2'])
conv3 = tf.nn.relu(
tf.nn.conv2d(conv2, weights['w_conv3'], strides=[1, 1, 1, 1], padding="VALID") + biases['b_conv3'])
conv3_flat = tf.reshape(conv3, [-1, 3456])
fc4 = tf.nn.relu(tf.matmul(conv3_flat, weights['w_fc4']) + biases['b_fc4']) output_layer = tf.matmul(fc4, weights['w_out']) + biases['b_out']
return output_layer # 深度强化学习入门: https://www.nervanasys.com/demystifying-deep-reinforcement-learning/
# 训练神经网络
def train_neural_network(input_image):
predict_action = convolutional_neural_network(input_image) argmax = tf.placeholder("float", [None, output])
gt = tf.placeholder("float", [None]) action = tf.reduce_sum(tf.multiply(predict_action, argmax), reduction_indices=1)
cost = tf.reduce_mean(tf.square(action - gt))
optimizer = tf.train.AdamOptimizer(1e-6).minimize(cost) game = Game()
D = deque() _, image = game.step(MOVE_STAY)
# 转换为灰度值
image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)
# 转换为二值
ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
input_image_data = np.stack((image, image, image, image), axis=2) with tf.Session() as sess:
sess.run(tf.initialize_all_variables()) saver = tf.train.Saver() n = 0
epsilon = INITIAL_EPSILON
while True:
action_t = predict_action.eval(feed_dict={input_image: [input_image_data]})[0] argmax_t = np.zeros([output], dtype=np.int)
if (random.random() <= INITIAL_EPSILON):
maxIndex = random.randrange(output)
else:
maxIndex = np.argmax(action_t)
argmax_t[maxIndex] = 1
if epsilon > FINAL_EPSILON:
epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE # for event in pygame.event.get(): macOS需要事件循环,否则白屏
# if event.type == QUIT:
# pygame.quit()
# sys.exit()
reward, image = game.step(list(argmax_t)) image = cv2.cvtColor(cv2.resize(image, (100, 80)), cv2.COLOR_BGR2GRAY)
ret, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
image = np.reshape(image, (80, 100, 1))
input_image_data1 = np.append(image, input_image_data[:, :, 0:3], axis=2) D.append((input_image_data, argmax_t, reward, input_image_data1)) if len(D) > REPLAY_MEMORY:
D.popleft() if n > OBSERVE:
minibatch = random.sample(D, BATCH)
input_image_data_batch = [d[0] for d in minibatch]
argmax_batch = [d[1] for d in minibatch]
reward_batch = [d[2] for d in minibatch]
input_image_data1_batch = [d[3] for d in minibatch] gt_batch = [] out_batch = predict_action.eval(feed_dict={input_image: input_image_data1_batch}) for i in range(0, len(minibatch)):
gt_batch.append(reward_batch[i] + LEARNING_RATE * np.max(out_batch[i])) optimizer.run(feed_dict={gt: gt_batch, argmax: argmax_batch, input_image: input_image_data_batch}) input_image_data = input_image_data1
n = n + 1 if n % 10000 == 0:
saver.save(sess, 'game.cpk', global_step=n) # 保存模型 print(n, "epsilon:", epsilon, " ", "action:", maxIndex, " ", "reward:", reward) train_neural_network(input_image)
基于tensorflow的躲避障碍物的ai训练的更多相关文章
- AI学习---基于TensorFlow的案例[实现线性回归的训练]
线性回归原理复习 1)构建模型 |_> y = w1x1 + w2x2 + -- + wnxn + b 2)构造损失函数 | ...
- 云原生的弹性 AI 训练系列之一:基于 AllReduce 的弹性分布式训练实践
引言 随着模型规模和数据量的不断增大,分布式训练已经成为了工业界主流的 AI 模型训练方式.基于 Kubernetes 的 Kubeflow 项目,能够很好地承载分布式训练的工作负载,业已成为了云原生 ...
- 手写数字识别 ----在已经训练好的数据上根据28*28的图片获取识别概率(基于Tensorflow,Python)
通过: 手写数字识别 ----卷积神经网络模型官方案例详解(基于Tensorflow,Python) 手写数字识别 ----Softmax回归模型官方案例详解(基于Tensorflow,Pytho ...
- 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(二)
前言 已完成数据预处理工作,具体参照: 基于TensorFlow Object Detection API进行迁移学习训练自己的人脸检测模型(一) 设置配置文件 新建目录face_faster_rcn ...
- 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)
如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...
- ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档]
ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档] 简介 简单地说就是该有的都有了,但是总体跑起来效果还不好. 还在开发中,它工作的效果还不好.但是你可以直 ...
- 基于TensorFlow Serving的深度学习在线预估
一.前言 随着深度学习在图像.语言.广告点击率预估等各个领域不断发展,很多团队开始探索深度学习技术在业务层面的实践与应用.而在广告CTR预估方面,新模型也是层出不穷: Wide and Deep[1] ...
- 大前端技术系列:TWA技术+TensorFlow.js => 集成原生和AI功能的app
大前端技术系列:TWA技术+TensorFlow.js => 集成原生和AI功能的app ( 本文内容为melodyWxy原作,git地址:https://github.com/melodyWx ...
- 在OpenShift平台上验证NVIDIA DGX系统的分布式多节点自动驾驶AI训练
在OpenShift平台上验证NVIDIA DGX系统的分布式多节点自动驾驶AI训练 自动驾驶汽车的深度神经网络(DNN)开发是一项艰巨的工作.本文验证了DGX多节点,多GPU,分布式训练在DXC机器 ...
随机推荐
- JMeter http(s)测试脚本录制器的使用
JMeter http(s)测试脚本录制器的使用 by:授客 QQ:1033553122 http(s) Test Script Recorder允许Jmeter在你使用普通浏览器浏览web应用时,拦 ...
- 《Inside C#》笔记(二) 初识C#
一 程序的编译.构成 a) 编写C#代码一般用VS,但作者在这儿介绍了使用记事本编写C#代码并编译运行的过程,以便对VS有更深入的认识. 用记事本编写C#代码后,修改文本文件的后缀为.cs,然后用cs ...
- python txt文件数据转excel
txt content: perf.txt 2018-11-12 16:48:58 time: 16:48:58 load average: 0.62, 0.54, 0.56 mosquitto CP ...
- (其他)最常用的15大Eclipse开发快捷键技巧
转自CSDNJava我人生(陈磊兴) 原文出处 引言 做java开发的,经常会用Eclipse或者MyEclise集成开发环境,一些实用的Eclipse快捷键和使用技巧,可以在平常开发中节约出很多 ...
- tinypng upload一键压缩上传工具,告别人肉
地址 项目地址:tinypng-upload 有兴趣的可以玩一玩,因为平时经常会用到图片压缩,上传,如果你也觉得很繁琐的话,这个将会解决你的痛点. 关于 tinypng-upload 这是一个基于 e ...
- 洗礼灵魂,修炼python(21)--自定义函数(2)—函数文档,doctest模块,形参,实参,默认参数,关键字参数,收集参数,位置参数
函数文档 1.什么是函数文档: 就是放在函数体之前的一段说明,其本身是一段字符串,一个完整的函数需要带有函数文档,这样利于他人阅读,方便理解此函数的作用,能做什么运算 2.怎么查看函数文档: func ...
- MDX 脚本语句 -- Scope
在多维表达式 (MDX) 中,下列语句用于管理 MDX 脚本中的上下文.作用域和流控制. 主题 说明 calculate语句 计算子多维数据集,还可以确定子多维数据集中所包含的求解次序 case语句 ...
- GUI_鼠标事件
所有的组件都有鼠标和键盘监听器 import java.awt.Button; import java.awt.FlowLayout; import java.awt.Frame; import ja ...
- CF 633 E. Binary Table
题目链接 题目大意:给定一个棋盘,棋盘上有0或1,你可以将一整行取反或者一整列取反,要使得最后剩的1最少.\((1\le n\le 20,1\le m\le 100000)\). 一个容易想到的思路就 ...
- [NOIP2018]旅行
嘟嘟嘟 鉴于一些知道的人所知道的,不知道的人所不知道的原因,我来发NOIPday2T1的题解了. \(O(n ^ 2)\)的做法自然很暴力,枚举断边断环为链就行了. 所以我是来讲\(O(nlogn)\ ...