pytorch kaggle 泰坦尼克生存预测

也不知道对不对,就凭着自己的思路写了一个
数据集:https://www.kaggle.com/c/titanic/data
import torch
import torch.nn as nn
import pandas as pd
import numpy as np class DataProcessing(object):
def __init__(self):
pass def get_data(self):
data_train = pd.read_csv('train.csv')
label = data_train[['Survived']]
data_test = pd.read_csv('test.csv')
# 读取指定列
gender = pd.read_csv('gender_submission.csv', usecols=[1])
return data_train, label, data_test, gender def data_processing(self, data_):
# 训练集测试集都进行相同的处理
data = data_[['Pclass', 'Sex', 'Age', 'SibSp', 'Fare', 'Cabin', 'Embarked']]
data['Age'] = data['Age'].fillna(data['Age'].mean())
data['Cabin'] = pd.factorize(data.Cabin)[0]
data.fillna(0, inplace=True)
data['Sex'] = [1 if x == 'male' else 0 for x in data.Sex]
data['p1'] = np.array(data['Pclass'] == 1).astype(np.int32)
data['p2'] = np.array(data['Pclass'] == 2).astype(np.int32)
data['p3'] = np.array(data['Pclass'] == 3).astype(np.int32)
data['e1'] = np.array(data['Embarked'] == 'S').astype(np.int32)
data['e2'] = np.array(data['Embarked'] == 'C').astype(np.int32)
data['e3'] = np.array(data['Embarked'] == 'Q').astype(np.int32)
del data['Pclass']
del data['Embarked']
return data def data(self):
# 读数据
train_data, label, test_data, gender = self.get_data()
# 处理数据
# 训练集输入数据
train = np.array(data_processing.data_processing(train_data))
# 训练集标签
train_label = np.array(label)
# 测试集
test = np.array(data_processing.data_processing(test_data))
# 测试集标签
test_label = np.array(gender) train = torch.from_numpy(train).float()
train_label = torch.tensor(train_label).float()
test = torch.tensor(test).float()
test_label = torch.tensor(test_label) return train, train_label, test, test_label class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.fc = nn.Sequential(
nn.Linear(11, 7),
nn.Sigmoid(),
nn.Linear(7, 7),
nn.Sigmoid(),
nn.Linear(7, 1),
)
self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
self.mls = nn.MSELoss() def forward(self, inputs):
# 前向传播
return self.fc(inputs) def train(self, inputs, y):
# 训练
out = self.forward(inputs)
loss = self.mls(out, y)
self.opt.zero_grad()
loss.backward()
self.opt.step()
# print(loss) def test(self, x, y):
# 测试
# 将variable张量转为numpy
# out = self.fc(x).data.numpy()
count = 0
out = self.fc(x)
sum = len(y)
for i, j in zip(out, y):
i = i.detach().numpy()
j = j.detach().numpy()
loss = abs((i - j)[0])
if loss < 0.3:
count += 1
# 误差0.3内的正确率
print(count/sum) if __name__ == '__main__':
data_processing = DataProcessing()
train_data, train_label, test_data, test_label = data_processing.data()
net = MyNet()
count = 0
for i in range(20000):
# 为了减小电脑压力,分批训练 100个训练一次 ## 2018.12.22补充:正确的做法应该是用batch
for n in range(len(train_data)//100 + 1):
batch_data = train_data[n*100: n*100 + 100]
batch_label = train_label[n*100: n*100 + 100]
net.train(train_data, train_label)
net.test(test_data, test_label) # 输出结果:0.7488038277511961
效果一般吧,不过至少出来了,hiahiahia
pytorch kaggle 泰坦尼克生存预测的更多相关文章
- 利用python进行泰坦尼克生存预测——数据探索分析
		
最近一直断断续续的做这个泰坦尼克生存预测模型的练习,这个kaggle的竞赛题,网上有很多人都分享过,而且都很成熟,也有些写的非常详细,我主要是在牛人们的基础上,按照数据挖掘流程梳理思路,然后通过练习每 ...
 - Kaggle初体验之泰坦尼特生存预测
		
Kaggle初体验之泰坦尼特生存预测 学习完了决策树的ID3.C4.5.CART算法,找一个试手的地方,Kaggle的练习赛泰坦尼特很不错,记录下 流程 首先注册一个账号,然后在顶部菜单栏Co ...
 - Kaggle  泰坦尼克
		
入门kaggle,开始机器学习应用之旅. 参看一些入门的博客,感觉pandas,sklearn需要熟练掌握,同时也学到了一些很有用的tricks,包括数据分析和机器学习的知识点.下面记录一些有趣的数据 ...
 - Kaggle泰坦尼克数据科学解决方案
		
原文地址如下: https://www.kaggle.com/startupsci/titanic-data-science-solutions --------------------------- ...
 - 逻辑回归应用之Kaggle泰坦尼克之灾(转)
		
正文:14pt 代码:15px 1 初探数据 先看看我们的数据,长什么样吧.在Data下我们train.csv和test.csv两个文件,分别存着官方给的训练和测试数据. import pandas ...
 - Spark学习笔记——泰坦尼克生还预测
		
package kaggle import org.apache.spark.SparkContext import org.apache.spark.SparkConf import org.apa ...
 - python__画图表可参考(转自:寒小阳 逻辑回归应用之Kaggle泰坦尼克之灾)
		
出处:http://blog.csdn.net/han_xiaoyang/article/details/49797143 2.背景 2.1 关于Kaggle 我是Kaggle地址,翻我牌子 亲,逼格 ...
 - Kaggle泰坦尼克-Python(建模完整流程,小白学习用)
		
参考Kernels里面评论较高的一篇文章,整理作者解决整个问题的过程,梳理该篇是用以了解到整个完整的建模过程,如何思考问题,处理问题,过程中又为何下那样或者这样的结论等! 最后得分并不是特别高,只是到 ...
 - Kaggle_泰坦尼克乘客存活预测
		
转载 逻辑回归应用之Kaggle泰坦尼克之灾 此转载只为保存!!! ————————————————版权声明:本文为CSDN博主「寒小阳」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附 ...
 
随机推荐
- dfs实现数的全排列
			
代码 #include<bits/stdc++.h> using namespace std; #define ll long long bool vis[15]; int a[15]; ...
 - iframe跨域解决方案
			
公司某个功能用的是iframe,由于跨域的原因,我们不能直接设置父级页面iframe的高度,所以用了一个中间页home来完成父级页面iframe的高度设置,这种中间页其实很多时候不好用,因为涉及到页面 ...
 - Python2和Python3中urllib库中urlencode的使用注意事项
			
前言 在Python中,我们通常使用urllib中的urlencode方法将字典编码,用于提交数据给url等操作,但是在Python2和Python3中urllib模块中所提供的urlencode的包 ...
 - springBoot项目启动类启动无法访问
			
springBoot项目启动类启动无法访问. 网上也查了一些资料,我这里总结.下不来虚的,也不废话. 解决办法: 1.若是maven项目,则找到右边Maven Projects --->Plug ...
 - 06_Hadoop分布式文件系统HDFS架构讲解
			
mr 计算框架 假如有三台机器 统领者master 01 02 03 每台机器都有过滤的应用程序 移动数据 01机== 300M >mr 移动计算 java程序传递给各个机器(mr) ...
 - Python之random模块
			
random模块 产生随机数的模块 是Python的标准模块,直接导入即可 import random 1)随机取一个整数,使用.randint()方法: import random print(ra ...
 - Tomcat connecttimeout sessiontimeout
			
IIS中的会话超时和连接超时之间有什么区别? | Adept Technologies Inc.https://www.adepttech.com/blog/?p=825 IIS中的会话超时和连接超时 ...
 - js 通过url获取里面的参数值
			
场景描述:当我们从一个页面要带有一两个值跳转到另一个页面,另一个页面要使用这些参数的时候,我们就需要通过js获取这些参数啦. 先贴上代码: function getQueryString(name) ...
 - [转帖]IP地址、子网掩码、网络号、主机号、网络地址、主机地址以及ip段/数字-如192.168.0.1/24是什么意思?
			
IP地址.子网掩码.网络号.主机号.网络地址.主机地址以及ip段/数字-如192.168.0.1/24是什么意思? 2016年03月26日 23:38:50 JeanCheng 阅读数:105674 ...
 - PL/SQL如何调试sql语句、存储过程
			
一直以来,我总是在sql的工具,比如sql server.navicat等中执行sql语句来发现问题自己写的sql中的问题,结果被问起时,让人贻笑大方! 那么如何调试成白行的存储过程?如何调试成百行s ...