1 pandas 读 csv

import torch
from torch import nn
import numpy as np
import pandas as pd
from copy import deepcopy
device = "cuda" if torch.cuda.is_available() else "cpu" # 读 csv
data_all = pd.read_csv('./CFD_data/record_data0.csv')
# 提取某一列
colume = np.array(data_all[['colume_name']], dtype=np.float32).reshape(-1, 1)
# 提取某一个值
value = data[data['食物种类']=='主食']['卡路里'].item()
# 数据操作
c = np.concatenate([a[1:], b[:-1]], axis=1)
c = torch.cat([a, b], axis=1)
# 存 csv
c.to_csv('./CFD_data/flow_rate.csv', index=False)

2 NN 的搭建、训练与评估

搭建:使用 nn.Sequential

# model
NN_model = nn.Sequential(
nn.Linear(6, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1),
)
# 优化器
optimizer = torch.optim.Adam(NN_model.parameters(), lr=0.001)

训练:

def NN_train(train_x, train_y, model, loss_fn, optimizer, epoches, batch_size, save_path):
"""
训练网络
输入:
train_x, train_y: 训练集
model: 网络模型
loss_fn: 损失函数
optimizer: 优化器
epoches: epoches 个数
batch_size: mini batch 大小
save_path: 模型保存路径
"""
# 切换到train模式
model.train()
losses = []
for epoch in range(epoches):
batch_loss = []
for start in range(0, len(train_x), batch_size): # mini batch
end = start + batch_size if start + batch_size < len(train_x) else len(train_x)
xx = torch.tensor(train_x[start:end], dtype=torch.float, requires_grad=True)
yy = torch.tensor(train_y[start:end], dtype=torch.float, requires_grad=True)
xx, yy = xx.to(device), yy.to(device) # 加载到 device
pred = model(xx) # 输入数据到模型里得到输出
loss = loss_fn(pred, yy) # 计算输出和标签的 loss
optimizer.zero_grad() # 清零
loss.backward() # 反向推导
optimizer.step() # 步进优化器
batch_loss.append(loss.data.numpy())
if epoch % max(1, epoches//8) == 0:
print(f"Training Error in epoch {epoch}: {np.mean(batch_loss):>8f}")
torch.save(model.state_dict(), save_path) # 保存模型

测试:

def NN_test(test_x, test_y, model, save_path, loss_fn):
"""
测试网络
输入:
test_x, test_y: 测试集
model: 网络模型
loss_fn: 损失函数
save_path: 模型保存路径
"""
model.load_state_dict(torch.load(save_path)) # 加载模型
model.eval() # 切换到测试模型
MSE_loss_fn = nn.MSELoss() # MSE loss function
test_loss, MSE = 0, 0 # 记录 loss 和 MSE
# 梯度截断
with torch.no_grad():
test_x, test_y = torch.tensor(test_x).to(device), torch.tensor(test_y).to(device) # 加载到 device
pred = model(test_x) # 输入数据到模型里得到输出
test_loss = loss_fn(pred, test_y).item() # 计算输出和标签的 loss
MSE = MSE_loss_fn(pred, test_y).item() # MSE
print(f"Test Error: \n Avg loss: {test_loss:>8f}, MSE: {MSE:>8f}\n")
print(f"Test Result: \n Prediction: {pred[:5]}, \n Y: {test_y[:5]}, \n diff: {test_y[:5]-pred[:5]}\n")

测试 ensemble model(平均值):

def NN_test_ensemble(test_x, test_y, loaded_model_list, loss_fn):
for model in loaded_model_list:
model.eval() # 切换到测试模型
MSE_loss_fn = nn.MSELoss() # MSE loss function
test_loss, MSE = 0, 0 # 记录 loss 和 MSE
# 梯度截断
with torch.no_grad():
test_x, test_y = torch.tensor(test_x).to(device), torch.tensor(test_y).to(device) # 加载到 device
pred = torch.zeros(test_y.shape)
for model in loaded_model_list:
pred += model(test_x) # 输入数据到模型里得到输出
pred /= len(loaded_model_list)
test_loss = loss_fn(pred, test_y).item() # 计算输出和标签的 loss
MSE = MSE_loss_fn(pred, test_y).item() # MSE
print(f"Test Error: \n Avg loss: {test_loss:>8f}, MSE: {MSE:>8f}\n")
print(f"Test Result: \n Prediction: {pred[:5]}, \n Y: {test_y[:5]}, \n diff: {test_y[:5]-pred[:5]}\n")

打印梯度,debug:

for name, param in model.named_parameters():
print(name, param.grad)

python · pytorch | NN 训练常用代码存档的更多相关文章

  1. PyTorch常用代码段整理合集

    PyTorch常用代码段整理合集 转自:知乎 作者:张皓 众所周知,程序猿在写代码时通常会在网上搜索大量资料,其中大部分是代码段.然而,这项工作常常令人心累身疲,耗费大量时间.所以,今天小编转载了知乎 ...

  2. PyTorch 常用代码段整理

    基础配置 检查 PyTorch 版本 torch.__version__               # PyTorch version torch.version.cuda              ...

  3. 【转载】GitHub 标星 1.2w+,超全 Python 常用代码合集,值得收藏!

    本文转自逆袭的二胖,作者二胖 今天给大家介绍一个由一个国外小哥用好几年时间维护的 Python 代码合集.简单来说就是,这个程序员小哥在几年前开始保存自己写过的 Python 代码,同时把一些自己比较 ...

  4. Pytorch之训练器设置

    Pytorch之训练器设置 引言 深度学习训练的时候有很多技巧, 但是实际用起来效果如何, 还是得亲自尝试. 这里记录了一些个人尝试不同技巧的代码. tensorboardX 说起tensorflow ...

  5. Python机器学习笔记:常用评估指标的用法

    在机器学习中,性能指标(Metrics)是衡量一个模型好坏的关键,通过衡量模型输出y_predict和y_true之间的某种“距离”得出的. 对学习器的泛化性能进行评估,不仅需要有效可行的试验估计方法 ...

  6. Python实现NN(神经网络)

    Python实现NN(神经网络) 参考自Github开源代码:https://github.com/dennybritz/nn-from-scratch 运行环境 Pyhton3 numpy(科学计算 ...

  7. python爬虫:一些常用的爬虫技巧

    python爬虫:一些常用的爬虫技巧 1.基本抓取网页 get方法: post方法: 2.使用代理IP 在开发爬虫过程中经常会遇到IP被封掉的情况,这时就需要用到代理IP; 在urllib2包中有Pr ...

  8. NSIS常用代码整理

    原文 NSIS常用代码整理 这是一些常用的NSIS代码,少轻狂特意整理出来,方便大家随时查看使用.不定期更新哦~~~ 1 ;获取操作系统盘符 2 ReadEnvStr $R0 SYSTEMDRIVE ...

  9. 第六章:Python基础の反射与常用模块解密

    本课主题 反射 Mapping 介绍和操作实战 模块介绍和操作实战 random 模块 time 和 datetime 模块 logging 模块 sys 模块 os 模块 hashlib 模块 re ...

  10. Python SQLAlchemy基本操作和常用技巧包含大量实例,非常好python

    http://www.makaidong.com/%E8%84%9A%E6%9C%AC%E4%B9%8B%E5%AE%B6/28053.shtml "Python SQLAlchemy基本操 ...

随机推荐

  1. 27、Type关键字

    1.是什么? type是go语法里额重要而且常用的关键字,type绝不只是对应于C/C++中的typeof.搞清楚type的使用,就容易理解Go语言中的核心概念struct.interface.函数等 ...

  2. 打造一个极度舒适的Chrome扩展项目开发环境

    大家好,我是 dom 哥.这是我关于 Chrome 扩展开发的系列文章,感兴趣的可以 点个小星星. Chrome 扩展能够提高浏览器的使用体验,通过自定义 UI 界面,监听浏览器事件,改变 Web 页 ...

  3. Springboot快速集成阿里云RocketMq

    前言 随着互联网的兴起,越来越多的用户开始享受科技带来的便利,对于服务的压力也日益增大,随即便有了高并发.高性能.高可用等各种解决方案,这里主要介绍RocketMq的集成方法.(文末附源码地址) 正文 ...

  4. Python——第三章:函数的返回值

    函数的返回值: 函数执行之后. 会给调用方一个结果. 这个结果就是返回值 关于return:        函数只要执行到了return. 函数就会立即停止并返回内容. 函数内的return的后续的代 ...

  5. 摆脱自研难题,AUI Kit助力企业快速搭建专属互动课堂

    本专栏将分享阿里云视频云MediaBox系列技术文章,深度剖析音视频开发利器的技术架构.技术性能.开发能效和最佳实践,一起开启音视频的开发之旅.本文为MediaBox最佳实践篇,重点从互动课堂AUI ...

  6. 《An End-to-end Model for Entity-level Relation Extraction using Multi-instance Learning》阅读笔记

    代码   原文地址   预备知识: 1.什么是MIL? 多示例学习(MIL)是一种机器学习的方法,它的特点是每个训练数据不是一个单独的实例,而是一个包含多个实例的集合(称为包).每个包有一个标签,但是 ...

  7. 快速掌握服务网格系列二:云原生、K8S、服务网格(Service Mesh)及微服务之间的关系

    快速掌握服务网格系列二:云原生.K8S.服务网格(Service Mesh)及微服务之间的关系 首先看下CNCF对云原生的定义: Cloud native technologies empower o ...

  8. C#判断字符串是否是有效的XML格式数据

    说明 在try-catch语句块中,创建XmlDocument对象,并使用LoadXml方法加载xml字符串.如果没有异常,则说明xml字符串是有效的,返回true,反之为false. 代码实现 // ...

  9. LeetCode 回溯篇(46、77、78、51)

    46. 全排列 给定一个 没有重复 数字的序列,返回其所有可能的全排列. 示例: 输入: [1,2,3] 输出: [ [1,2,3], [1,3,2], [2,1,3], [2,3,1], [3,1, ...

  10. 华为云数据库GaussDB(for openGauss):初次见面,认识一下

    摘要:本文从总体架构.主打场景.关键技术特性等方面进行介绍GaussDB(for openGauss). 1.背景介绍 3月16日,在华为云主办的GaussDB(for openGauss)系列技术第 ...