Keras class_weight和sample_weight用法
搬运: https://stackoverflow.com/questions/57610804/when-is-the-timing-to-use-sample-weights-in-keras
import tensorflow as tf
import numpy as np
data_size = 100
input_size=3
classes=3
x_train = np.random.rand(data_size ,input_size)
y_train= np.random.randint(0,classes,data_size )
#sample_weight_train = np.random.rand(data_size)
x_val = np.random.rand(data_size ,input_size)
y_val= np.random.randint(0,classes,data_size )
#sample_weight_val = np.random.rand(data_size )
inputs = tf.keras.layers.Input(shape=(input_size))
pred=tf.keras.layers.Dense(classes, activation='softmax')(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=pred)
loss = tf.keras.losses.sparse_categorical_crossentropy
metrics = tf.keras.metrics.sparse_categorical_accuracy
model.compile(loss=loss , metrics=[metrics], optimizer='adam')
# Make model static, so we can compare it between different scenarios
for layer in model.layers:
layer.trainable = False
# base model no weights (same result as without class_weights)
# model.fit(x=x_train,y=y_train, validation_data=(x_val,y_val))
class_weights={0:1.,1:1.,2:1.}
model.fit(x=x_train,y=y_train, class_weight=class_weights, validation_data=(x_val,y_val))
# which outputs:
> loss: 1.1882 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1965 - val_sparse_categorical_accuracy: 0.3100
#changing the class weights to zero, to check which loss and metric that is affected
class_weights={0:0,1:0,2:0}
model.fit(x=x_train,y=y_train, class_weight=class_weights, validation_data=(x_val,y_val))
# which outputs:
> loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1945 - val_sparse_categorical_accuracy: 0.3100
#changing the sample_weights to zero, to check which loss and metric that is affected
sample_weight_train = np.zeros(100)
sample_weight_val = np.zeros(100)
model.fit(x=x_train,y=y_train,sample_weight=sample_weight_train, validation_data=(x_val,y_val,sample_weight_val))
# which outputs:
> loss: 0.0000e+00 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.1931 - val_sparse_categorical_accuracy: 0.3100
class_weight: output 变量的权重
sample_weight: data sample 的权重
Keras class_weight和sample_weight用法的更多相关文章
- keras中TimeDistributed的用法
TimeDistributed这个层还是比较难理解的.事实上通过这个层我们可以实现从二维像三维的过渡,甚至通过这个层的包装,我们可以实现图像分类视频分类的转化. 考虑一批32个样本,其中每个样本是一个 ...
- Keras 学习之旅(一)
软件环境(Windows): Visual Studio Anaconda CUDA MinGW-w64 conda install -c anaconda mingw libpython CNTK ...
- Keras官方中文文档:函数式模型API
\ 函数式模型接口 为什么叫"函数式模型",请查看"Keras新手指南"的相关部分 Keras的函数式模型为Model,即广义的拥有输入和输出的模型,我们使用M ...
- Keras官方中文文档:序贯模型API
Sequential模型接口 如果刚开始学习Sequential模型,请首先移步这里阅读文档,本节内容是Sequential的API和参数介绍. 常用Sequential属性 model.layers ...
- keras系列︱Sequential与Model模型、keras基本结构功能(一)
引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...
- Python机器学习笔记:深入理解Keras中序贯模型和函数模型
先从sklearn说起吧,如果学习了sklearn的话,那么学习Keras相对来说比较容易.为什么这样说呢? 我们首先比较一下sklearn的机器学习大致使用流程和Keras的大致使用流程: skl ...
- 深度学习(六)keras常用函数学习
原文作者:aircraft 原文链接:https://www.cnblogs.com/DOMLX/p/9769301.html Keras是什么? Keras:基于Theano和TensorFlow的 ...
- Keras(一)Sequential与Model模型、Keras基本结构功能
keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...
- Keras Model Sequential模型接口
Sequential 模型 API 在阅读这片文档前,请先阅读 Keras Sequential 模型指引. Sequential 模型方法 compile compile(optimizer, lo ...
随机推荐
- java:easyui(jQueryEasyUI,分页)
1.介绍: jQuery EasyUI是一组基于jQuery的UI插件集合体,而jQuery EasyUI的目标就是帮助web开发者更轻松的打造出功能丰富并且美观的UI界面.开发者不需要编写复杂的ja ...
- leaflet的入门开发(一)
2016年9月27日—1.0leaflet,最快的,最稳定和严谨的leaflet,终于出来了! leaflet是领先的开源JavaScript库为移动设备设计的互动地图.重33 KB的JS,所有映射大 ...
- linux中权限对文件和目录的影响?
######### rwx 权限对文件和目录的含义 ############ 代表字符 权限 对文件的含义 对 目录 ...
- js多张图片合成一张图,canvas(海报图,将二维码和背景图合并) -----vue
思路:vue中图片合并 首先准备好要合并的背景图,和请求后得到的二维码, canvas画图,将两张背景图和一张二维码用canvas画出来, 将canvas再转为img 注意canvas和图片的清晰图和 ...
- Akka系列(三):监管与容错
前言...... Akka作为一种成熟的生产环境并发解决方案,必须拥有一套完善的错误异常处理机制,本文主要讲讲Akka中的监管和容错. 监管 看过我上篇文章的同学应该对Actor系统的工作流程有了一定 ...
- CDH6.2扩容
参考: yum方式扩容: https://www.cnblogs.com/yinzhengjie/articles/11104776.html 二进制包方式扩容: https://www.cnblog ...
- [转帖]LSB
LSB 简介 冯 锐2006 年 8 月 07 日发布 https://www.ibm.com/developerworks/cn/linux/l-lsb-intr/ 学习一下 之前 不知道LSB_R ...
- ORA-01406:提取的列值被截断 ; SQL Server :将截断字符串或二进制数据
oracle 数据库可以正常连接,表数据也可以正常读取, 但在程序中相同的位置,有时会报错,有时不会报错,有的电脑会报错,有的不会 报错内容为 ORA-01406:提取的列值被截断 查了网上提供的一些 ...
- 洛谷 P2801 教主的魔法 题解
题面 刚看到这道题的时候用了个树状数组优化前缀和差分的常数优化竟然AC了?(这数据也太水了吧~) 本人做的第一道分块题,调试了好久好久,最后竟然没想到二分上还会出错!(一定要注意)仅此纪念: #inc ...
- python 3.x报错:No module named 'cookielib'或No module named 'urllib2'
1. ModuleNotFoundError: No module named 'cookielib' Python3中,import cookielib改成 import http.coo ...