Sequential模型接口

如果刚开始学习Sequential模型,请首先移步这里阅读文档,本节内容是Sequential的API和参数介绍。

常用Sequential属性

  • model.layers是添加到模型上的层的list

Sequential模型方法

add

add(self, layer)

向模型中添加一个层

  • layer: Layer对象

pop

pop(self)

弹出模型最后的一层,无返回值


compile

compile(self, optimizer, loss, metrics=None, sample_weight_mode=None)

编译用来配置模型的学习过程,其参数有

  • optimizer:字符串(预定义优化器名)或优化器对象,参考优化器

  • loss:字符串(预定义损失函数名)或目标函数,参考损失函数

  • metrics:列表,包含评估模型在训练和测试时的网络性能的指标,典型用法是metrics=['accuracy']

  • sample_weight_mode:如果你需要按时间步为样本赋权(2D权矩阵),将该值设为“temporal”。默认为“None”,代表按样本赋权(1D权)。在下面fit函数的解释中有相关的参考内容。

  • weighted_metrics: metrics列表,在训练和测试过程中,这些metrics将由sample_weightclss_weight计算并赋权

  • target_tensors: 默认情况下,Keras将为模型的目标创建一个占位符,该占位符在训练过程中将被目标数据代替。如果你想使用自己的目标张量(相应的,Keras将不会在训练时期望为这些目标张量载入外部的numpy数据),你可以通过该参数手动指定。目标张量可以是一个单独的张量(对应于单输出模型),也可以是一个张量列表,或者一个name->tensor的张量字典。

  • kwargs:使用TensorFlow作为后端请忽略该参数,若使用Theano/CNTK作为后端,kwargs的值将会传递给 K.function。如果使用TensorFlow为后端,这里的值会被传给tf.Session.run

model = Sequential()
model.add(Dense(32, input_shape=(500,)))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])

模型在使用前必须编译,否则在调用fit或evaluate时会抛出异常。

fit

fit(self, x, y, batch_size=32, epochs=10, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0)

本函数将模型训练nb_epoch轮,其参数有:

  • x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array

  • y:标签,numpy array

  • batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。

  • epochs:整数,训练终止时的epoch值,训练将在达到该epoch值时停止,当没有设置initial_epoch时,它就是训练的总轮数,否则训练的总轮数为epochs - inital_epoch

  • verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

  • callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数

  • validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。

  • validation_data:形式为(X,y)的tuple,是指定的验证集。此参数将覆盖validation_spilt。

  • shuffle:布尔值或字符串,一般为布尔值,表示是否在训练过程中随机打乱输入样本的顺序。若为字符串“batch”,则是用来处理HDF5数据的特殊情况,它将在batch内部将数据打乱。

  • class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)

  • sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'

  • initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。

fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况


evaluate

evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None)

本函数按batch计算在某些输入数据上模型的误差,其参数有:

  • x:输入数据,与fit一样,是numpy array或numpy array的list

  • y:标签,numpy array

  • batch_size:整数,含义同fit的同名参数

  • verbose:含义同fit的同名参数,但只能取0或1

  • sample_weight:numpy array,含义同fit的同名参数

本函数返回一个测试误差的标量值(如果模型没有其他评价指标),或一个标量的list(如果模型还有其他的评价指标)。model.metrics_names将给出list中各个值的含义。

如果没有特殊说明,以下函数的参数均保持与fit的同名参数相同的含义

如果没有特殊说明,以下函数的verbose参数(如果有)均只能取0或1


predict

predict(self, x, batch_size=32, verbose=0)

本函数按batch获得输入数据对应的输出,其参数有:

函数的返回值是预测值的numpy array


train_on_batch

train_on_batch(self, x, y, class_weight=None, sample_weight=None)

本函数在一个batch的数据上进行一次参数更新

函数返回训练误差的标量值或标量值的list,与evaluate的情形相同。


test_on_batch

test_on_batch(self, x, y, sample_weight=None)

本函数在一个batch的样本上对模型进行评估

函数的返回与evaluate的情形相同


predict_on_batch

predict_on_batch(self, x)

本函数在一个batch的样本上对模型进行测试

函数返回模型在一个batch上的预测结果


fit_generator

fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

利用Python的生成器,逐个生成数据的batch并进行训练。生成器与模型将并行执行以提高效率。例如,该函数允许我们在CPU上进行实时的数据提升,同时在GPU上进行模型训练

函数的参数是:

  • generator:生成器函数,生成器的输出应该为:

    • 一个形如(inputs,targets)的tuple

    • 一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束

  • steps_per_epoch:整数,当生成器返回steps_per_epoch次数据时计一个epoch结束,执行下一个epoch

  • epochs:整数,数据迭代的轮数

  • verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录

  • validation_data:具有以下三种形式之一

    • 生成验证集的生成器

    • 一个形如(inputs,targets)的tuple

    • 一个形如(inputs,targets,sample_weights)的tuple

  • validation_steps: 当validation_data为生成器时,本参数指定验证集的生成器返回次数

  • class_weight:规定类别权重的字典,将类别映射为权重,常用于处理样本不均衡问题。

  • sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'

  • workers:最大进程数

  • max_q_size:生成器队列的最大容量

  • pickle_safe: 若为真,则使用基于进程的线程。由于该实现依赖多进程,不能传递non picklable(无法被pickle序列化)的参数到生成器中,因为无法轻易将它们传入子进程中。

  • initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。

函数返回一个History对象

例子:

def generate_arrays_from_file(path):
while 1:
f = open(path)
for line in f:
# create Numpy arrays of input data
# and labels, from each line in the file
x, y = process_line(line)
yield (x, y)
f.close() model.fit_generator(generate_arrays_from_file('/my_file.txt'),
samples_per_epoch=10000, epochs=10)

evaluate_generator

evaluate_generator(self, generator, steps, max_q_size=10, workers=1, pickle_safe=False)

本函数使用一个生成器作为数据源评估模型,生成器应返回与test_on_batch的输入数据相同类型的数据。该函数的参数与fit_generator同名参数含义相同,steps是生成器要返回数据的轮数。


predict_generator

predict_generator(self, generator, steps, max_q_size=10, workers=1, pickle_safe=False, verbose=0)

本函数使用一个生成器作为数据源预测模型,生成器应返回与test_on_batch的输入数据相同类型的数据。该函数的参数与fit_generator同名参数含义相同,steps是生成器要返回数据的轮数。


艾伯特(http://www.aibbt.com/)国内第一家人工智能门户

Keras官方中文文档:序贯模型API的更多相关文章

  1. Keras官方中文文档:函数式模型API

    \ 函数式模型接口 为什么叫"函数式模型",请查看"Keras新手指南"的相关部分 Keras的函数式模型为Model,即广义的拥有输入和输出的模型,我们使用M ...

  2. Keras官方中文文档:序贯模型

    快速开始序贯(Sequential)模型 序贯模型是多个网络层的线性堆叠,也就是"一条路走到黑". 可以通过向Sequential模型传递一个layer的list来构造该模型: f ...

  3. Keras官方中文文档:关于Keras模型

    关于Keras模型 Keras有两种类型的模型,序贯模型(Sequential)和函数式模型(Model),函数式模型应用更为广泛,序贯模型是函数式模型的一种特殊情况. 两类模型有一些方法是相同的: ...

  4. Keras官方中文文档:Keras安装和配置指南(Windows)

    这里需要说明一下,笔者不建议在Windows环境下进行深度学习的研究,一方面是因为Windows所对应的框架搭建的依赖过多,社区设定不完全:另一方面,Linux系统下对显卡支持.内存释放以及存储空间调 ...

  5. Keras官方中文文档:常见问题与解答

    所属分类:Keras Keras FAQ:常见问题 如何引用Keras? 如何使Keras调用GPU? 如何在多张GPU卡上使用Keras "batch", "epoch ...

  6. Keras官方中文文档:keras后端Backend

    所属分类:Keras Keras后端 什么是"后端" Keras是一个模型级的库,提供了快速构建深度学习网络的模块.Keras并不处理如张量乘法.卷积等底层操作.这些操作依赖于某种 ...

  7. Keras官方中文文档:Keras安装和配置指南(Linux)

    关于计算机的硬件配置说明 推荐配置 如果您是高校学生或者高级研究人员,并且实验室或者个人资金充沛,建议您采用如下配置: 主板:X299型号或Z270型号 CPU: i7-6950X或i7-7700K ...

  8. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  9. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

随机推荐

  1. Django搭建博客网站(一)

    Django搭建自己的博客网站(一) 简介 这个系列主要是通过使用Django这个python web框架实现一个简单的个人博客网站.对Django有疑问可以上Django官网查文档. 功能 后台管理 ...

  2. Gentoo(贱兔)Linux安装笔记

      网上对于Gentoo Linux 的教程少之又少,所以这里我将自己的安装记录贴出来 希望对正在研究Gentoo 的小伙伴们有帮助! 1.确认连接到互联网,使用net-setup工具配置网络 roo ...

  3. 阿里云pai项目使用说明

    PAI项目创建方法 购买region 进入MaxCompute,购买相应region,目前机器学习只支持华东2(GPU公测免费)以及华北2(GPU计划收费),注意选择"按量后付费" ...

  4. python 列表去重(数组)的几种方法(转)

    一.方法1  代码如下 复制代码 ids = [1,2,3,3,4,2,3,4,5,6,1] news_ids = [] for id in ids:     if id not in news_id ...

  5. [记录]MySQL读写分离(Atlas和MySQL-proxy)

    MySQL读写分离(Atlas和MySQL-proxy) 一.阿里云使用Atlas从外网访问MySQL(RDS) (同样的方式修改配置文件可以实现代理也可以实现读写分离,具体看使用场景) 1.在跳板机 ...

  6. Effective Java 第三版——33. 优先考虑类型安全的异构容器

    Tips <Effective Java, Third Edition>一书英文版已经出版,这本书的第二版想必很多人都读过,号称Java四大名著之一,不过第二版2009年出版,到现在已经将 ...

  7. Servlet中请求重定向和请求转发和include

    响应的重定向 response.sendRedirect("ShowMSgSerlet1");//请求重定向 会将后面的浏览器的url改变. 请求转发 RequestDispatc ...

  8. keras初涉笔记【一】

    安装keras依赖的库 sudo pip install numpy sudo pip install scipy sudo pip installl pyyaml sudo pipi install ...

  9. UVa230 Borrowers

    原题链接 UVa230 思路 这题输入时有一些字符串处理操作,可以利用string的substr()函数和find_last_of()函数更加方便,处理时不必更要把书名和作者对应下来,注意到原题书名的 ...

  10. JSP的内置对象以及作用域。

    最近在面试,一些基础的问题总是会被问到,虽然是基础,但是有些东西在工作中用的少,所以就有些记不清了,在面试的时候更因为紧张很容易造成原先知道的知识也会突然忘了的情况发生.所以在重新组织一下jsp的内置 ...