In Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (Part 2): Start Training!, we trained the agent in a rather crude way: after a time-consuming 100,000 games the results were mediocre, and the values for the initial state in particular fell well short of expectations. The main reason, I think, is that we did not update all equivalent positions together, so data efficiency was low. Every position has 7 equivalents: rotate 90°, rotate 180°, rotate 270°, horizontal flip, vertical flip, rotate 90° + horizontal flip, and rotate 90° + vertical flip, as shown in the figure below. In addition, while generating the equivalent positions we must also generate the equivalent actions; only then can the Q values be updated completely.
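To make the symmetry idea concrete, here is a minimal sketch of how the eight symmetric views of a 3×3 board can be enumerated with numpy's built-in np.rot90 and np.fliplr. This is only an illustration, not the implementation used below; the steps that follow hand-roll their own rotate/flip helpers.

import numpy as np

def all_symmetries(board):
    # Returns the 8 symmetric views of a 3x3 board: the identity,
    # rotations by 90/180/270 degrees, and the mirrored version of
    # each of those four rotations.
    views = []
    for k in range(4):
        r = np.rot90(board, k)      # rotate by k * 90 degrees
        views.append(r)
        views.append(np.fliplr(r))  # left-right mirror of that rotation
    return views

board = np.array([[1, 0, 0],
                  [0, -1, 0],
                  [0, 0, 0]])
print(len(all_symmetries(board)))   # 8: the original plus its 7 equivalents

Applying the same transform to a one-hot board that marks the chosen cell is exactly how the equivalent action is recovered in Step 2.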

Step 1: Write the rotation and flip functions

import numpy as np

def rotate(array):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    list_ = list(array)
    list_[:] = map(list, zip(*list_[::-1]))
    return np.array(list_)  # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]

def flip(array_, direction):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array)/2))
    if direction == 'vertical':  # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon':  # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):
            temp = array[:,i].copy()
            array[:,i] = array[:,-i-1]
            array[:,-i-1] = temp
    return array
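A quick sanity check (with numpy imported as np, as above) reproduces the outputs noted in the comments:

a = np.array([[1,2,3],[4,5,6],[7,8,9]])
print(rotate(a))            # [[7 4 1] [8 5 2] [9 6 3]]  -- 90° clockwise rotation
print(flip(a, 'vertical'))  # [[7 8 9] [4 5 6] [1 2 3]]  -- rows reversed (upside down)
print(flip(a, 'horizon'))   # [[3 2 1] [6 5 4] [9 8 7]]  -- columns reversed (mirror)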

Step 2: Write the function that generates equivalent states and actions

The function is genEqualStateAndAction(state, action), defined inside the Agent() class.

def genEqualStateAndAction(self, state_, action_):  # Input: np.array, tuple(x,y)
    state, action = state_.copy(), action_
    equalStates, equalActions = [], []
    # Original position
    equalStates.append(state)
    equalActions.append(action)
    # Horizontal flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    state_tf = flip(state_tf, 'horizon')
    action_state_tf = flip(action_state_tf, 'horizon')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # Vertical flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    state_tf = flip(state_tf, 'vertical')
    action_state_tf = flip(action_state_tf, 'vertical')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # Rotate 90°
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(1):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # Rotate 180°
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(2):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # Rotate 270°
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(3):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # Rotate 90° + horizontal flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(1):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    state_tf = flip(state_tf, 'horizon')
    action_state_tf = flip(action_state_tf, 'horizon')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    # Rotate 90° + vertical flip
    state_tf = state.copy()
    action_state_tf = np.zeros(state.shape)
    action_state_tf[action] = 1
    for i in range(1):
        state_tf = rotate(state_tf)
        action_state_tf = rotate(action_state_tf)
    state_tf = flip(state_tf, 'vertical')
    action_state_tf = flip(action_state_tf, 'vertical')
    index = np.where(action_state_tf == 1)
    action_tf = (int(index[0]), int(index[1]))
    equalStates.append(state_tf)
    equalActions.append(action_tf)
    return equalStates, equalActions
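For example (assuming an Agent() instance — the class is completed in Step 3 — and the Step 1 helpers in scope), answering a centre opening with the top-left corner yields eight (state, action) pairs:

agent = Agent()
state = np.array([[0, 0, 0],
                  [0, -1, 0],
                  [0, 0, 0]])
eqStates, eqActions = agent.genEqualStateAndAction(state, (0, 0))
print(len(eqStates), len(eqActions))  # 8 8
print(eqActions)
# [(0, 0), (0, 2), (2, 0), (0, 2), (2, 2), (2, 0), (0, 0), (2, 2)]

Because the centre mark is symmetric, all eight equivalent states are identical here, and only four distinct corner actions appear, each twice — exactly the duplication discussed next.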

A careful reader may ask: aren't you going to deduplicate the generated equivalent positions? No, I'm not. One reason is that deduplication would require comparing a lot of np.arrays, which is fiddly to implement and could add a lot of running time. The other reason is that updating a duplicated position several times is merely redundant, not harmful: as long as there is enough data, every entry in the Q-table converges to its value anyway, and positions that appear more often simply converge a bit faster.
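If you did want to deduplicate, one lightweight option (just a sketch, not part of the code in this post) is to key each generated pair by the same str(state) representation the Q-table already uses, which avoids element-wise array comparison:

def dedupEqualStateAndAction(eqStates, eqActions):
    # Keep only one copy of each (state, action) pair, using the same
    # str(state) form that the Q-table uses as its dictionary key.
    seen, states, actions = set(), [], []
    for S, a in zip(eqStates, eqActions):
        key = (str(S), a)
        if key not in seen:
            seen.add(key)
            states.append(S)
            actions.append(a)
    return states, actions

In practice the duplicates only mean that highly symmetric positions receive a slightly larger effective learning rate, which is consistent with the argument above.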

Step 3: Modify the relevant code in Agent()

We need to modify the methods addNewState(self, env_, currentMove) and updateQtable(self, env_, currentMove, done_). The complete code is as follows:

import gym
import random
import time
import numpy as np

# List all registered environments
# from gym import envs
# print(envs.registry.all())

def str2tuple(string):  # Input: '(1,1)'
    string2list = list(string)
    return ( int(string2list[1]), int(string2list[4]) )  # Output: (1,1)

def rotate(array):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    list_ = list(array)
    list_[:] = map(list, zip(*list_[::-1]))
    return np.array(list_)  # Output: np.array [[7,4,1],[8,5,2],[9,6,3]]

def flip(array_, direction):  # Input: np.array [[1,2,3],[4,5,6],[7,8,9]]
    array = array_.copy()
    n = int(np.floor(len(array)/2))
    if direction == 'vertical':  # Output: np.array [[7,8,9],[4,5,6],[1,2,3]]
        for i in range(n):
            temp = array[i].copy()
            array[i] = array[-i-1].copy()
            array[-i-1] = temp
    elif direction == 'horizon':  # Output: np.array [[3,2,1],[6,5,4],[9,8,7]]
        for i in range(n):
            temp = array[:,i].copy()
            array[:,i] = array[:,-i-1]
            array[:,-i-1] = temp
    return array

class Game():
    def __init__(self, env):
        self.INTERVAL = 0    # pause between moves
        self.RENDER = False  # whether to render the game
        self.first = 'blue' if random.random() > 0.5 else 'red'  # randomly decide who moves first
        self.currentMove = self.first
        self.env = env
        self.agent = Agent()

    def switchMove(self):  # switch the player to move
        move = self.currentMove
        if move == 'blue': self.currentMove = 'red'
        elif move == 'red': self.currentMove = 'blue'

    def newGame(self):  # start a new game
        self.first = 'blue' if random.random() > 0.5 else 'red'
        self.currentMove = self.first
        self.env.reset()
        self.agent.reset()

    def run(self):  # play one game
        self.env.reset()  # the environment must be reset before the first step, otherwise an error is raised
        while True:
            print(f'--currentMove: {self.currentMove}--')
            self.agent.updateQtable(self.env, self.currentMove, False)
            if self.currentMove == 'blue':
                self.agent.lastState_blue = self.env.state.copy()
            elif self.currentMove == 'red':
                self.agent.lastState_red = self.agent.overTurn(self.env.state)  # red sees the flipped state
            action = self.agent.epsilon_greedy(self.env, self.currentMove)
            if self.currentMove == 'blue':
                self.agent.lastAction_blue = action['pos']
            elif self.currentMove == 'red':
                self.agent.lastAction_red = action['pos']
            state, reward, done, info = self.env.step(action)
            if done:
                self.agent.lastReward_blue = reward
                self.agent.lastReward_red = -1 * reward
                self.agent.updateQtable(self.env, self.currentMove, True)
            else:
                if self.currentMove == 'blue':
                    self.agent.lastReward_blue = reward
                elif self.currentMove == 'red':
                    self.agent.lastReward_red = -1 * reward
            if self.RENDER: self.env.render()
            self.switchMove()
            time.sleep(self.INTERVAL)
            if done:
                self.newGame()
                if self.RENDER: self.env.render()
                time.sleep(self.INTERVAL)
                break

class Agent():
    def __init__(self):
        self.Q_table = {}
        self.EPSILON = 0.05
        self.ALPHA = 0.5
        self.GAMMA = 1  # discount factor
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None

    def reset(self):
        self.lastState_blue = None
        self.lastAction_blue = None
        self.lastReward_blue = None
        self.lastState_red = None
        self.lastAction_red = None
        self.lastReward_red = None

    def getEmptyPos(self, state):  # return the coordinates of the empty cells
        action_space = []
        for i, row in enumerate(state):
            for j, one in enumerate(row):
                if one == 0: action_space.append((i,j))
        return action_space

    def randomAction(self, env_, mark):  # randomly choose an empty cell
        actions = self.getEmptyPos(env_)
        action_pos = random.choice(actions)
        action = {'mark':mark, 'pos':action_pos}
        return action

    def overTurn(self, state):  # flip the state (swap the two players' marks)
        state_ = state.copy()
        for i, row in enumerate(state_):
            for j, one in enumerate(row):
                if one != 0: state_[i][j] *= -1
        return state_

    def genEqualStateAndAction(self, state_, action_):  # Input: np.array, tuple(x,y)
        state, action = state_.copy(), action_
        equalStates, equalActions = [], []
        # Original position
        equalStates.append(state)
        equalActions.append(action)
        # Horizontal flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # Vertical flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # Rotate 90°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # Rotate 180°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(2):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # Rotate 270°
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(3):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # Rotate 90° + horizontal flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'horizon')
        action_state_tf = flip(action_state_tf, 'horizon')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        # Rotate 90° + vertical flip
        state_tf = state.copy()
        action_state_tf = np.zeros(state.shape)
        action_state_tf[action] = 1
        for i in range(1):
            state_tf = rotate(state_tf)
            action_state_tf = rotate(action_state_tf)
        state_tf = flip(state_tf, 'vertical')
        action_state_tf = flip(action_state_tf, 'vertical')
        index = np.where(action_state_tf == 1)
        action_tf = (int(index[0]), int(index[1]))
        equalStates.append(state_tf)
        equalActions.append(action_tf)
        return equalStates, equalActions

    def addNewState(self, env_, currentMove):  # add the current state to the Q-table if it is not there yet
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state)  # flip the state when red is to move
        eqStates, eqActions = self.genEqualStateAndAction(state, (0,0))
        for one in eqStates:
            if str(one) not in self.Q_table:
                self.Q_table[str(one)] = {}
                actions = self.getEmptyPos(one)
                for action in actions:
                    self.Q_table[str(one)][str(action)] = 0

    def epsilon_greedy(self, env_, currentMove):  # ε-greedy policy
        state = env_.state if currentMove == 'blue' else self.overTurn(env_.state)  # flip the state when red is to move
        Q_Sa = self.Q_table[str(state)]
        maxAction, maxValue, otherAction = [], -100, []
        for one in Q_Sa:
            if Q_Sa[one] > maxValue:
                maxValue = Q_Sa[one]
        for one in Q_Sa:
            if Q_Sa[one] == maxValue:
                maxAction.append(str2tuple(one))
            else:
                otherAction.append(str2tuple(one))
        try:
            action_pos = random.choice(maxAction) if random.random() > self.EPSILON else random.choice(otherAction)
        except:  # handle drawing from an empty otherAction
            action_pos = random.choice(maxAction)
        action = {'mark':currentMove, 'pos':action_pos}
        return action

    def updateQtable(self, env_, currentMove, done_):
        judge = (currentMove == 'blue' and self.lastState_blue is None) or \
                (currentMove == 'red' and self.lastState_red is None)
        if judge:  # edge case 1: the agent has no previous state, i.e. this is its first move of the game, so only add the new state; no Q update is needed
            self.addNewState(env_, currentMove)
            return
        if done_:  # edge case 2: the current state S_ is terminal, so it need not be added to the Q-table; set maxQ_S_a = 0 and update both players' Q values
            for one in ['blue', 'red']:
                S = self.lastState_blue if one == 'blue' else self.lastState_red
                a = self.lastAction_blue if one == 'blue' else self.lastAction_red
                eqStates, eqActions = self.genEqualStateAndAction(S, a)
                R = self.lastReward_blue if one == 'blue' else self.lastReward_red
                # print('lastState S:\n', S)
                # print('lastAction a: ', a)
                # print('lastReward R: ', R)
                # print('\n')
                maxQ_S_a = 0
                for S, a in zip(eqStates, eqActions):
                    self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                                                   + self.ALPHA * (R + self.GAMMA * maxQ_S_a)
            return
        # in all other cases: add the current state if it is not in the Q-table yet, then update the Q value
        self.addNewState(env_, currentMove)
        S_ = env_.state if currentMove == 'blue' else self.overTurn(env_.state)
        S = self.lastState_blue if currentMove == 'blue' else self.lastState_red
        a = self.lastAction_blue if currentMove == 'blue' else self.lastAction_red
        eqStates, eqActions = self.genEqualStateAndAction(S, a)
        R = self.lastReward_blue if currentMove == 'blue' else self.lastReward_red
        # print('lastState S:\n', S)
        # print('State S_:\n', S_)
        # print('lastAction a: ', a)
        # print('lastReward R: ', R)
        # print('\n')
        Q_S_a = self.Q_table[str(S_)]
        maxQ_S_a = -100
        for one in Q_S_a:
            if Q_S_a[one] > maxQ_S_a:
                maxQ_S_a = Q_S_a[one]
        for S, a in zip(eqStates, eqActions):
            self.Q_table[str(S)][str(a)] = (1 - self.ALPHA) * self.Q_table[str(S)][str(a)] \
                                           + self.ALPHA * (R + self.GAMMA * maxQ_S_a)

env = gym.make('TicTacToeEnv-v0')
game = Game(env)
for i in range(10000):
    print('episode', i)
    game.run()
Q_table = game.agent.Q_table

Testing

With the optimization above, the agent updates 16 Q values per update step, eight times the 2 Q values per step in the previous post, Reinforcement Learning in Practice | Tabular Q-Learning for Tic-Tac-Toe (Part 2): Start Training!. So let's play just 10,000 games and see whether that matches the results we previously got from 80,000 games.

Item 1: Check the number of states in the Q-table

So-so: there are still states that have not been covered.
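For reference, the number can be read directly off the dictionary produced by the training script above:

print('number of states:', len(Q_table))
print('number of state-action values:', sum(len(v) for v in Q_table.values()))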

Item 2: Check the initial state

Opening as the first player:

This looks almost too good! Not only is the symmetry perfect, but the win/loss picture is crystal clear: the first move on an edge is solid, while the corners and the centre are mostly losing. It seems that after the optimization the overall variance of the Q values is in very good shape.

And here is the second-player opening:
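These views can be reproduced by printing the relevant Q-table entries; a sketch, assuming the environment stores the board as a 3×3 integer numpy array with blue's marks as 1 and red's as -1, so that str() of the probe board matches the stored key:

# First player to move: Q values of the empty board
empty = np.zeros((3, 3), dtype=int)
print(Q_table[str(empty)])

# Second player to move after a centre opening: the stored key is the
# overTurn()-ed board, so the opponent's mark appears as -1
reply = np.zeros((3, 3), dtype=int)
reply[1, 1] = -1
print(Q_table[str(reply)])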

Item 3: Measure the running time

The more elaborate trick does buy us a real gain, but it also makes each game take longer. How much longer? Let's run 2,000 games with the old algorithm from the previous post and 2,000 games with this post's algorithm and record the time (my CPU is an Intel(R) Core(TM) i7-9750H).
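The timing itself can be done with a plain time.time() wrapper around the loop; a minimal sketch:

start = time.time()
for i in range(2000):
    game.run()
elapsed = time.time() - start
print('2000 games took %.1f seconds' % elapsed)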

Two-sided update + synchronized update of equivalent positions:

Two-sided update only:

Less than twice the time in exchange for roughly eight times as many updates, plus lower variance; this optimization clearly pays off.

Summary

With the optimized algorithm in hand, we can now confidently scale up the training time. In the next post we will take the fully trained Q-table and use pygame to build a tic-tac-toe game with human-vs-AI, AI-vs-AI, and cheating modes. We can also run some match analysis, for example: what is the win rate when the AI plays against itself, and what is its win rate against a random policy? See you next time!
