How much training data do you need?

 

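//@樵夫上校: 0. Empirically, the 10X rule (training data should be 10 times the number of model parameters) applies to most models, including shallow networks. 1. Linear models can use the 10X rule of thumb, where the number of model parameters is the number of features left after feature selection (PCA and similar methods). 2. For neural networks, the 10X rule can be treated as a lower bound on the amount of training data.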

The quality and amount of training data are often the single most dominant factor that determines the performance of a model. Once you have the training data angle covered, the rest usually follows. But exactly how much training data do you need? The correct answer is: it depends. It depends on the task you are trying to perform, the performance you want to achieve, the input features you have, the noise in the training data, the noise in your extracted features, the complexity of your model, and so on. The way to find out how all these variables interact is to train your model on varying amounts of training data and plot learning curves. But this requires you to already have a decent amount of training data to construct interesting plots. What do you do when you are just starting out? Or when you suspect you have too little training data and want to estimate how big a problem you have?

So instead of the dead-accurate “correct” answer to the problem, how about an estimate, a practical rule of thumb? One way out is to take an empirical approach as follows. First, automatically generate a lot of logistic regression problems. For each generated problem, study the relationship between the amount of training data and the performance of the trained models. Finally, observe this relationship over a range of problems and generalize it to a simple rule.

 

Here is the code to generate a range of logistic regression problems and study the effect of varying the amount of training data. The code is based on TensorFlow. Running the code doesn’t require any special software or hardware (TensorFlow is open sourced by Google), and I was able to run the entire experiment on my laptop. Upon running, the code spits out the graph below.

 

The x-axis is the ratio of the number of training samples to the number of model parameters. The y-axis is the f-score of the trained model. The curves in different colors correspond to models that differ in the number of parameters. For example, the red curve, which corresponds to a model with 128 parameters, indicates how the f-score changes as one varies the number of training samples to 128 x 1, 128 x 2 and so on.
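The experiment behind such a plot can be sketched without TensorFlow. What follows is a minimal pure-Python illustration, not the original code: the synthetic data generator, the noise level, the gradient-descent settings, and all function names (`make_problem`, `train_logreg`, `f_score`, `learning_curve`) are assumptions chosen for illustration.

```python
import math
import random


def sigmoid(z):
    # Numerically safe logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def make_problem(n_params, n_samples, true_w, rng, noise=0.5):
    """Draw Gaussian features and label them with a noisy linear rule (assumed setup)."""
    xs, ys = [], []
    for _ in range(n_samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(n_params)]
        score = sum(w * v for w, v in zip(true_w, x)) + rng.gauss(0.0, noise)
        xs.append(x)
        ys.append(1 if score > 0 else 0)
    return xs, ys


def train_logreg(xs, ys, epochs=200, lr=0.1):
    """Fit one weight per feature by batch gradient descent on the log loss."""
    n_params = len(xs[0])
    w = [0.0] * n_params
    for _ in range(epochs):
        grad = [0.0] * n_params
        for x, y in zip(xs, ys):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x))) - y
            for j in range(n_params):
                grad[j] += err * x[j]
        w = [wi - lr * g / len(xs) for wi, g in zip(w, grad)]
    return w


def f_score(w, xs, ys):
    """F1 score of the linear classifier w on the labeled examples."""
    tp = fp = fn = 0
    for x, y in zip(xs, ys):
        pred = 1 if sum(wi * xi for wi, xi in zip(w, x)) > 0 else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1 and y == 0:
            fp += 1
        elif pred == 0 and y == 1:
            fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def learning_curve(n_params, ratios, seed=0, n_test=500):
    """Map each samples-to-parameters ratio to the f-score on held-out data."""
    rng = random.Random(seed)
    true_w = [rng.gauss(0.0, 1.0) for _ in range(n_params)]
    test_x, test_y = make_problem(n_params, n_test, true_w, rng)
    curve = {}
    for ratio in ratios:
        train_x, train_y = make_problem(n_params, n_params * ratio, true_w, rng)
        curve[ratio] = f_score(train_logreg(train_x, train_y), test_x, test_y)
    return curve


curve = learning_curve(n_params=8, ratios=[1, 2, 5, 10])
for ratio, score in curve.items():
    print(f"ratio {ratio:>2}: f-score {score:.3f}")
```

Calling `learning_curve` with several values of `n_params` and plotting each result against the ratio would give one curve per model size. The exact numbers depend on the assumed noise level and optimizer settings, so they should not be expected to match the original graph.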

The first observation is that the f-score curves don’t vary as the number of parameters scales. This is expected given that the models are linear, and it’s good to see that some hidden non-linearity doesn’t creep in. Of course, larger models need more training data, but for a given ratio of the number of training samples to the number of model parameters you get the same performance. The second observation is that when the ratio of training samples to model parameters is 10:1, the f-score lands in the vicinity of 0.85, which we take as the definition of a well performing model. This leads us to the rule of 10, namely that the amount of training data you need for a well performing model is 10x the number of parameters in the model.

The rule of 10 transforms the problem of estimating the amount of training data required into that of knowing the number of parameters in the model, so it deserves some discussion. For linear models such as logistic regression, the number of parameters equals the number of input features, since the model assigns one parameter to each feature. However, there could be some complications:

  • Your features may be sparse, so counting the number of features may not be straightforward.
  • Due to regularization and feature selection techniques a lot of features may be discarded, so the real feature count is much smaller than the number of raw features that are input to the model.

One way to tackle the issue is to observe that you don’t really need labeled data to get an estimate of the number of features; even unlabeled examples are sufficient for that purpose. For example, given a large corpus of text, you can generate histograms of word frequencies to understand your feature space before beginning to label the data for training. Given the histogram, you can discard the words in the long tail to get an estimate of the real feature count, which in turn gives an estimate of the amount of training data you need by applying the rule of 10.
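As a rough illustration of this idea, the sketch below counts word frequencies in an unlabeled corpus, drops the long tail, and applies the rule of 10. The function name `estimate_training_size`, the whitespace tokenization, and the `min_count` cutoff are assumptions made for the example; where to cut the tail is a judgment call that depends on your corpus.

```python
from collections import Counter


def estimate_training_size(documents, min_count=2, ratio=10):
    """Estimate the feature count from unlabeled text, then apply the rule of 10.

    min_count is a hypothetical cutoff: words seen fewer times are treated
    as long-tail features and discarded.
    """
    counts = Counter(word for doc in documents for word in doc.lower().split())
    kept = [word for word, count in counts.items() if count >= min_count]
    n_features = len(kept)
    return n_features, ratio * n_features


docs = [
    "the cat sat on the mat",
    "the dog sat on the log",
    "a cat and a dog",
]
n_features, n_samples = estimate_training_size(docs)
print(n_features, n_samples)  # 6 features kept -> 60 training examples
```

On a real corpus the same function would run over many more documents, and inspecting `counts.most_common()` first is a reasonable way to pick the cutoff.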

Neural networks pose a different set of problems than linear models like logistic regression. To get the number of parameters in a neural network you need to:

  • Count the number of parameters used in the embedding layer if your input is sparse (see the TensorFlow tutorial on word embeddings for example).
  • Count the number of edges in your network.

The problem is that the relationship between the parameters in a neural network is no longer linear, so the empirical study we did based on logistic regression doesn’t really apply anymore. In such cases you can treat the rule of 10 as a lower bound on the amount of training data needed.
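The counting steps above can be sketched for a simple fully connected network. The helper `count_parameters` and the layer sizes are hypothetical; the optional bias terms are an addition beyond the edge count described here and are off by default.

```python
def count_parameters(vocab_size, embed_dim, layer_sizes, include_bias=False):
    """Count embedding weights plus edges between consecutive dense layers."""
    embedding = vocab_size * embed_dim
    # Each pair of adjacent layers is fully connected: a * b edges.
    edges = sum(a * b for a, b in zip(layer_sizes, layer_sizes[1:]))
    # Bias terms (one per non-input unit) are an optional extra, not part of the edge count.
    biases = sum(layer_sizes[1:]) if include_bias else 0
    return embedding + edges + biases


# Hypothetical network: 10,000-word vocabulary, 64-dim embeddings, layers 64 -> 32 -> 1.
n_params = count_parameters(vocab_size=10_000, embed_dim=64, layer_sizes=[64, 32, 1])
min_samples = 10 * n_params  # rule of 10, used here only as a lower bound
print(n_params, min_samples)  # 642080 6420800
```

In this example the embedding layer dominates the count, which is typical when the input vocabulary is large.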

Despite the complications above, in my experience the rule of 10 seems to work across a wide range of problems, including shallow neural nets. However, when in doubt, plug your own model and assumptions into the TensorFlow code and run the simulation to study its effects. Please feel free to share if you gain any insight in the process.

 
