小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于《动手学深度学习》pytorch版,也是代码较多,解释较少的一篇。不过好多方法在我以前的博客都有提,所以这次没提。还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂(前提是python语法大概了解),这是我不加很多解释的重要原因。
K折交叉验证实现
def get_k_fold_data(k, i, X, y):
# 返回第i折交叉验证时所需要的训练和验证数据,分开放,X_train为训练数据,X_valid为验证数据
assert k > 1
fold_size = X.shape[0] // k # 双斜杠表示除完后再向下取整
X_train, y_train = None, None
for j in range(k):
idx = slice(j * fold_size, (j + 1) * fold_size) #slice(start,end,step)切片函数
X_part, y_part = X[idx, :], y[idx]
if j == i:
X_valid, y_valid = X_part, y_part
elif X_train is None:
X_train, y_train = X_part, y_part
else:
X_train = torch.cat((X_train, X_part), dim=0) #dim=0增加行数,竖着连接
y_train = torch.cat((y_train, y_part), dim=0)
return X_train, y_train, X_valid, y_valid def k_fold(k, X_train, y_train, num_epochs,learning_rate, weight_decay, batch_size):
train_l_sum, valid_l_sum = 0, 0
for i in range(k):
data = get_k_fold_data(k, i, X_train, y_train) # 获取k折交叉验证的训练和验证数据
net = get_net(X_train.shape[1]) #get_net在这是一个基本的线性回归模型,方法实现见附录1
train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
weight_decay, batch_size) #train方法见后面附录2
train_l_sum += train_ls[-1]
valid_l_sum += valid_ls[-1]
if i == 0:
d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',
range(1, num_epochs + 1), valid_ls,
['train', 'valid']) #画图,且是对y求对数了,x未变。方法实现见附录3
print('fold %d, train rmse %f, valid rmse %f' % (i, train_ls[-1], valid_ls[-1]))
return train_l_sum / k, valid_l_sum / k
*args:表示接受任意长度的参数,然后存放入一个元组中;如def fun(*args) print(args),‘fruit','animal','human'作为参数传进去,输出(‘fruit','animal','human')
**kwargs:表示接受任意长的参数,然后存放入一个字典中;如
def fun(**kwargs):
for key, value in kwargs.items():
print("%s:%s" % (key,value)
fun(a=1,b=2,c=3)会输出 a=1 b=2 c=3
附录1
loss = torch.nn.MSELoss() def get_net(feature_num):
net = nn.Linear(feature_num, 1)
for param in net.parameters():
nn.init.normal_(param, mean=0, std=0.01)
return net
附录2
def train(net, train_features, train_labels, test_features, test_labels, num_epochs, learning_rate,weight_decay, batch_size):
train_ls, test_ls = [], []
dataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True) #TensorDataset和DataLoader的使用请查看我以前的博客 #这里使用了Adam优化算法
optimizer = torch.optim.Adam(params=net.parameters(), lr= learning_rate, weight_decay=weight_decay)
net = net.float()
for epoch in range(num_epochs):
for X, y in train_iter:
l = loss(net(X.float()), y.float())
optimizer.zero_grad()
l.backward()
optimizer.step()
train_ls.append(log_rmse(net, train_features, train_labels))
if test_labels is not None:
test_ls.append(log_rmse(net, test_features, test_labels))
return train_ls, test_ls
附录3
def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):
set_figsize(figsize)
plt.xlabel(x_label)
plt.ylabel(y_label)
plt.semilogy(x_vals, y_vals)
if x2_vals and y2_vals:
plt.semilogy(x2_vals, y2_vals, linestyle=':')
plt.legend(legend)
注:由于最近有其他任务,所以此博客写的匆忙,等我有时间后会丰富,也可能加详细解释。
小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)的更多相关文章
- 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播
下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())
先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...
- Python——决策树实战:california房价预测
Python——决策树实战:california房价预测 编译环境:Anaconda.Jupyter Notebook 首先,导入模块: import pandas as pd import matp ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 深度学习之PyTorch实战(1)——基础学习及搭建环境
最近在学习PyTorch框架,买了一本<深度学习之PyTorch实战计算机视觉>,从学习开始,小编会整理学习笔记,并博客记录,希望自己好好学完这本书,最后能熟练应用此框架. PyTorch ...
随机推荐
- 使用vue框架开发前端项目的步骤
前端项目的开发 1. 本地安装nodejs https://nodejs.org/en/download/ 2. 测试安装 > node -v 3. 本地安装git > git --ver ...
- 009、Java中超过了int的最大值或最小值的结果
01.代码如下: package TIANPAN; /** * 此处为文档注释 * * @author 田攀 微信382477247 */ public class TestDemo { public ...
- Java基础学习总结(二)
Java语言的特点: Java语言是简单的 Java语言是面向对象的 Java语言是跨平台(操作系统)的(即一次编写,到处运行) Java是高性能的 运行Java程序要安装和配置JDK jdk是什么? ...
- mysql limit查询入门
- 重装系统,新安装IDEA启动项目后,classnotfound:com.mysql.jdbc.Driver
我最后解决是:这个Test connection会自动帮你下载的,但是如果中途一直叫你try again,甚至到后面点这个test connection有弹窗,但是单窗里面的选项你点击后没反应,我是直 ...
- HDU - 6006 Engineer Assignment (状压dfs)
题意:n个工作,m个人完成,每个工作有ci个阶段,一个人只能选择一种工作完成,可以不选,且只能完成该工作中与自身标号相同的工作阶段,问最多能完成几种工作. 分析: 1.如果一个工作中的某个工作阶段没有 ...
- UVA - 10562 Undraw the Trees(多叉树的dfs)
题意:将多叉树转化为括号表示法. 分析:gets读取,dfs就好了.注意,样例中一行的最后一个字母后是没有空格的. #pragma comment(linker, "/STACK:10240 ...
- 51nod 1352:集合计数
1352 集合计数 基准时间限制:1 秒 空间限制:131072 KB 分值: 20 难度:3级算法题 收藏 关注 给出N个固定集合{1,N},{2,N-1},{3,N-2},...,{N-1,2 ...
- mysql基本知识的总结
Mysql基本sql知识 Navicat快捷方式: 选中当前行 在行尾:shift+home 在行首:shift+end 执行当前行:ctrl+shift+R 复制当前行:ctrl+D 显示所有数据库 ...
- cf 762D. Maximum path
天呢,好神奇的一个DP23333%%%%% 因为1.向左走1格的话相当于当前列和向左走列全选 2.想做走超过1的话可以有上下走替代.而且只能在相邻行向左. 全选的情况只能从第1行和第3行转移,相反全选 ...