WAF 强化学习
参考:https://github.com/duoergun0729/3book/tree/master/code/gym-waf
代码:
wafEnv.py
#-*- coding:utf-8 –*-
import numpy as np
import re
import random
from gym import spaces
import gym
from sklearn.model_selection import train_test_split #samples_file="xss-samples.txt"
samples_file="xss-samples-all.txt"
samples=[]
with open(samples_file) as f:
for line in f:
line = line.strip('\n')
print("Add xss sample:" + line)
samples.append(line) # 划分训练和测试集合
samples_train, samples_test = train_test_split(samples, test_size=0.4) class Xss_Manipulator(object):
def __init__(self):
self.dim = 0
self.name="" #常见免杀动作:
# 随机字符转16进制 比如: a转换成a;
# 随机字符转10进制 比如: a转换成a;
# 随机字符转10进制并假如大量0 比如: a转换成a;
# 插入注释 比如: /*abcde*/
# 插入Tab
# 插入回车
# 开头插入空格 比如: /**/
# 大小写混淆
# 插入 \00 也会被浏览器忽略 ACTION_TABLE = {
#'charTo16': 'charTo16',
#'charTo10': 'charTo10',
#'charTo10Zero': 'charTo10Zero',
'addComment': 'addComment',
'addTab': 'addTab',
'addZero': 'addZero',
'addEnter': 'addEnter',
} def charTo16(self,str,seed=None):
#print("charTo16")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#print("search --> matchObj.group() : ", matchObjs)
modify_char=random.choice(matchObjs)
#字符转ascii值ord(modify_char
#modify_char_10=ord(modify_char)
modify_char_16="&#{};".format(hex(ord(modify_char)))
#print("modify_char %s to %s" % (modify_char,modify_char_10))
#替换
str=re.sub(modify_char, modify_char_16, str,count=random.randint(1,3)) return str def charTo10(self,str,seed=None):
#print("charTo10")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#print("search --> matchObj.group() : ", matchObjs)
modify_char=random.choice(matchObjs)
#字符转ascii值ord(modify_char
#modify_char_10=ord(modify_char)
modify_char_10="&#{};".format(ord(modify_char))
#print("modify_char %s to %s" % (modify_char,modify_char_10))
#替换
str=re.sub(modify_char, modify_char_10, str) return str def charTo10Zero(self,str,seed=None):
#print("charTo10")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#print("search --> matchObj.group() : ", matchObjs)
modify_char=random.choice(matchObjs)
#字符转ascii值ord(modify_char
#modify_char_10=ord(modify_char)
modify_char_10="�{};".format(ord(modify_char))
#print("modify_char %s to %s" % (modify_char,modify_char_10))
#替换
str=re.sub(modify_char, modify_char_10, str) return str def addComment(self,str,seed=None):
#print("charTo10")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#选择替换的字符
modify_char=random.choice(matchObjs)
#生成替换的内容
#modify_char_comment="{}/*a{}*/".format(modify_char,modify_char)
modify_char_comment = "{}/*8888*/".format(modify_char) #替换
str=re.sub(modify_char, modify_char_comment, str) return str def addTab(self,str,seed=None):
#print("charTo10")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#选择替换的字符
modify_char=random.choice(matchObjs)
#生成替换的内容
modify_char_tab=" {}".format(modify_char) #替换
str=re.sub(modify_char, modify_char_tab, str) return str def addZero(self,str,seed=None):
#print("charTo10")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#选择替换的字符
modify_char=random.choice(matchObjs)
#生成替换的内容
modify_char_zero="\\00{}".format(modify_char) #替换
str=re.sub(modify_char, modify_char_zero, str) return str def addEnter(self,str,seed=None):
#print("charTo10")
matchObjs = re.findall(r'[a-qA-Q]', str, re.M | re.I)
if matchObjs:
#选择替换的字符
modify_char=random.choice(matchObjs)
#生成替换的内容
modify_char_enter="\\r\\n{}".format(modify_char) #替换
str=re.sub(modify_char, modify_char_enter, str) return str def modify(self,str, _action, seed=6): print("Do action :%s" % _action)
action_func=Xss_Manipulator().__getattribute__(_action) return action_func(str,seed) ACTION_LOOKUP = {i: act for i, act in enumerate(Xss_Manipulator.ACTION_TABLE.keys())} #<embed src="data:text/html;base64,PHNjcmlwdD5hbGVydCgxKTwvc2NyaXB0Pg==">
#a="get";b="URL(ja\"";c="vascr";d="ipt:ale";e="rt('XSS');\")";eval(a+b+c+d+e);
#"><script>alert(String.fromCharCode(66, 108, 65, 99, 75, 73, 99, 101))</script>
#<input onblur=write(XSS) autofocus><input autofocus>
#<math><a xlink:href="//jsfiddle.net/t846h/">click
#<h1><font color=blue>hellox worldss</h1>
#LOL<style>*{/*all*/color/*all*/:/*all*/red/*all*/;/[0]*IE,Safari*[0]/color:green;color:bl/*IE*/ue;}</style> class Waf_Check(object):
def __init__(self):
self.name="Waf_Check"
self.regXSS=r'(prompt|alert|confirm|expression])' \
r'|(javascript|script|eval)' \
r'|(onload|onerror|onfocus|onclick|ontoggle|onmousemove|ondrag)' \
r'|(String.fromCharCode)' \
r'|(;base64,)' \
r'|(onblur=write)' \
r'|(xlink:href)' \
r'|(color=)'
#self.regXSS = r'javascript' def check_xss(self,str):
isxss=False #忽略大小写
if re.search(self.regXSS,str,re.IGNORECASE):
isxss=True return isxss class Features(object):
def __init__(self):
self.dim = 0
self.name=""
self.dtype=np.float32 def byte_histogram(self,str):
#bytes=np.array(list(str))
bytes=[ord(ch) for ch in list(str)]
#print(bytes) h = np.bincount(bytes, minlength=256)
return np.concatenate([
[h.sum()], # total size of the byte stream
h.astype(self.dtype).flatten() / h.sum(), # normalized the histogram
]) def extract(self,str): featurevectors = [
[self.byte_histogram(str)]
]
return np.concatenate(featurevectors) class WafEnv_v0(gym.Env):
metadata = {
'render.modes': ['human', 'rgb_array'],
} def __init__(self):
self.action_space = spaces.Discrete(len(ACTION_LOOKUP)) #xss样本特征集合
#self.samples=[]
#当前处理的样本
self.current_sample=""
#self.current_state=0
self.features_extra=Features()
self.waf_checker=Waf_Check()
#根据动作修改当前样本免杀
self.xss_manipulatorer= Xss_Manipulator() self._reset() def _seed(self, num):
pass def _step(self, action): r=0
is_gameover=False
#print("current sample:%s" % self.current_sample) _action=ACTION_LOOKUP[action]
#print("action is %s" % _action) self.current_sample=self.xss_manipulatorer.modify(self.current_sample,_action)
#print("change current sample to %s" % self.current_sample) if not self.waf_checker.check_xss(self.current_sample):
#给奖励
r=10
is_gameover=True
print("Good!!!!!!!avoid waf:%s" % self.current_sample) self.observation_space=self.features_extra.extract(self.current_sample) return self.observation_space, r,is_gameover,{} def _reset(self):
self.current_sample=random.choice(samples_train)
print("reset current_sample=" + self.current_sample) self.observation_space=self.features_extra.extract(self.current_sample)
return self.observation_space def render(self, mode='human', close=False):
return
主代码:
#-*- coding:utf-8 –*-
import gym
import time
import random
import gym_waf.envs.wafEnv
import pickle
import numpy as np from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, ELU, Dropout, BatchNormalization
from keras.optimizers import Adam, SGD, RMSprop from rl.agents.dqn import DQNAgent
from rl.agents.sarsa import SarsaAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory from gym_waf.envs.wafEnv import samples_test,samples_train
# from gym_waf.envs.features import Features
from gym_waf.envs.waf import Waf_Check
from gym_waf.envs.xss_manipulator import Xss_Manipulator from keras.callbacks import TensorBoard ENV_NAME = 'Waf-v0'
#尝试的最大次数
nb_max_episode_steps_train=50
nb_max_episode_steps_test=3 ACTION_LOOKUP = {i: act for i, act in enumerate(Xss_Manipulator.ACTION_TABLE.keys())} class Features(object):
def __init__(self):
self.dim = 0
self.name=""
self.dtype=np.float32 def byte_histogram(self,str):
#bytes=np.array(list(str))
bytes=[ord(ch) for ch in list(str)]
#print(bytes) h = np.bincount(bytes, minlength=256)
return np.concatenate([
[h.sum()], # total size of the byte stream
h.astype(self.dtype).flatten() / h.sum(), # normalized the histogram
]) def extract(self,str): featurevectors = [
[self.byte_histogram(str)]
]
return np.concatenate(featurevectors) def generate_dense_model(input_shape, layers, nb_actions):
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dropout(0.1)) for layer in layers:
model.add(Dense(layer))
model.add(BatchNormalization())
model.add(ELU(alpha=1.0)) model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary()) return model def train_dqn_model(layers, rounds=10000): env = gym.make(ENV_NAME)
env.seed(1)
nb_actions = env.action_space.n
window_length = 1 print("nb_actions:")
print(nb_actions)
print("env.observation_space.shape:")
print(env.observation_space.shape) model = generate_dense_model((window_length,) + env.observation_space.shape, layers, nb_actions) policy = EpsGreedyQPolicy() memory = SequentialMemory(limit=256, ignore_episode_boundaries=False, window_length=window_length) agent = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=16,
enable_double_dqn=True, enable_dueling_network=True, dueling_type='avg',
target_model_update=1e-2, policy=policy, batch_size=16) agent.compile(RMSprop(lr=1e-3), metrics=['mae']) #tb_cb = TensorBoard(log_dir='/tmp/log', write_images=1, histogram_freq=1)
#cbks = [tb_cb]
# play the game. learn something!
#nb_max_episode_steps 一次学习周期中最大步数
agent.fit(env, nb_steps=rounds, nb_max_episode_steps=nb_max_episode_steps_train,visualize=False, verbose=2) #print("#################Start Test%################") #agent.test(env, nb_episodes=100) test_samples=samples_test features_extra = Features()
waf_checker = Waf_Check()
# 根据动作修改当前样本免杀
xss_manipulatorer = Xss_Manipulator() success=0
sum=0 shp = (1,) + tuple(model.input_shape[1:]) for sample in samples_test:
#print(sample)
sum+=1 for _ in range(nb_max_episode_steps_test): if not waf_checker.check_xss(sample) :
success+=1
print(sample)
break f = features_extra.extract(sample).reshape(shp)
act_values = model.predict(f)
action=np.argmax(act_values[0])
sample=xss_manipulatorer.modify(sample,ACTION_LOOKUP[action]) print("Sum:{} Success:{}".format(sum,success)) return agent, model if __name__ == '__main__':
agent1, model1= train_dqn_model([5, 2], rounds=1000)
model1.save('waf-v0.h5', overwrite=True)
效果:
reset current_sample=<img src=`xx:xx`onerror=alert(1)>
Do action :addEnter
Do action :addComment
Good!!!!!!!avoid waf:<img src=`xx:xx`
one/*8888*/rr
or=ale/*8888*/rt(1)>
987/1000: episode: 221, duration: 0.016s, episode steps: 2, steps per second: 122, episode reward: 10.000, mean reward: 5.000 [0.000, 10.000], mean action: 1.500 [0.000, 3.000], mean observation: 0.179 [0.000, 53.000], loss: 1.608465, mean_absolute_error: 3.369818, mean_q: 7.756353
reset current_sample=<!--<img src="--><img src=x onerror=alert(123)//">
Do action :addEnter
Do action :addEnter
Do action :addEnter
Do action :addZero
Do action :addEnter
Do action :addEnter
Do action :addEnter
Do action :addEnter
Do action :addEnter
Good!!!!!!!avoid waf:<!--<
WAF 强化学习的更多相关文章
- 【整理】强化学习与MDP
[入门,来自wiki] 强化学习是机器学习中的一个领域,强调如何基于环境而行动,以取得最大化的预期利益.其灵感来源于心理学中的行为主义理论,即有机体如何在环境给予的奖励或惩罚的刺激下,逐步形成对刺激的 ...
- 强化学习之 免模型学习(model-free based learning)
强化学习之 免模型学习(model-free based learning) ------ 蒙特卡罗强化学习 与 时序查分学习 ------ 部分节选自周志华老师的教材<机器学习> 由于现 ...
- (译) 强化学习 第一部分:Q-Learning 以及相关探索
(译) 强化学习 第一部分:Q-Learning 以及相关探索 Q-Learning review: Q-Learning 的基础要点是:有一个关于环境状态S的表达式,这些状态中可能的动作 a,然后你 ...
- 强化学习读书笔记 - 02 - 多臂老O虎O机问题
# 强化学习读书笔记 - 02 - 多臂老O虎O机问题 学习笔记: [Reinforcement Learning: An Introduction, Richard S. Sutton and An ...
- 强化学习读书笔记 - 05 - 蒙特卡洛方法(Monte Carlo Methods)
强化学习读书笔记 - 05 - 蒙特卡洛方法(Monte Carlo Methods) 学习笔记: Reinforcement Learning: An Introduction, Richard S ...
- 强化学习读书笔记 - 06~07 - 时序差分学习(Temporal-Difference Learning)
强化学习读书笔记 - 06~07 - 时序差分学习(Temporal-Difference Learning) 学习笔记: Reinforcement Learning: An Introductio ...
- 强化学习之Q-learning ^_^
许久没有更新重新拾起,献于小白 这次介绍的是强化学习 Q-learning,Q-learning也是离线学习的一种 关于Q-learning的算法详情看 传送门 下文中我们会用openai gym来做 ...
- 强化学习 - Q-learning Sarsa 和 DQN 的理解
本文用于基本入门理解. 强化学习的基本理论 : R, S, A 这些就不说了. 先设想两个场景: 一. 1个 5x5 的 格子图, 里面有一个目标点, 2个死亡点二. 一个迷宫, 一个出发点, ...
- TensorLayer官方中文文档1.7.4:API – 强化学习
API - 强化学习¶ 强化学习(增强学习)相关函数. discount_episode_rewards([rewards, gamma, mode]) Take 1D float array of ...
随机推荐
- MongoDB-3: 查询(一)
一.简介 MongoDB提供了db.collection.find() 方法可以实现根据条件查询和指定使用投影运算符返回的字段省略此参数返回匹配文档中的所有字段. 二.db.collection.fi ...
- Django HttpRequest对象详解
WSGIRequest对象 Django在接收到http请求之后,会根据http请求携带的参数以及报文信息创建一个WSGIRequest对象,并且作为视图函数第一个参数传给视图函数.也就是我们经常看到 ...
- Spring-Hello World实例
Spring Hello World实例 创建Java项目 添加Jar包 创建源文件 现在在Spring项目下创建实际的源文件.首先,要创建一个名为com.tuorialsponit的包,然后在该co ...
- Hadoop源码如何查看
如何查看hadoop源码 1解压hadoop安装压缩文件成为文件夹,再进入解压后的文件夹下的src文件夹,选中core,hdfs,mapred三个文件夹
- Django-MTV(Day66)
阅读目录 Django基本命令 视图层路由配置系统 视图层之视图函数 MTV模型 Django的MTV分别代表: Model(模型):负责业务对象与数据库的对象(ORM) Template(模板):负 ...
- BCH码
http://baike.baidu.com/link?url=CfLtm9DigwWdup-9VJP99RG65NgaVOXfrnjT61ogP7au0QOrlypq72k67B0s1Ey-Q1yD ...
- s5_day1作业
#1.使用while循环输出1 2 3 4 5 6 8 9 10 # s=0 # while s<10: # s+=1 # if s==7: # continue # print(s) # fo ...
- slf4j-api、slf4j-log4j12以及log4j之间什么关系?
几乎在每个jar包里都可以看到log4j的身影,在多个子工程构成项目中,slf4j相关的冲突时不时就跳出来让你不爽,那么slf4j-api.slf4j-log4j12还有log4j他们是什么关系?我把 ...
- python之路 线程、进程、协程、队列、python-memcache、python-redis
一.线程 Threading用于提供线程相关的操作,线程是应用程序中工作的最小单元. #!/usr/bin/env python # -*- coding:utf-8 -*- import threa ...
- 了解IE中filter属性的应用!
在设置不透明属性时,经常用opacity来增加层次感或者增加用户体验,但这个属性是css3属性,对于低级浏览器的兼容性来说就达不到预期的效果. 一般而言,我们都尽可能少用一些浏览私有属性-webkit ...