[个人总结]pytorch中用checkpoint设置恢复,在恢复后的acc上升
原因是因为checkpoint设置好的确是保存了相关字段。但是其中设置的train_dataset却已经走过了epoch轮,当你再继续训练时候,train_dataset是从第一个load_data开始。
# -*- coding:utf-8 -*-
import os
import numpy as np
import torch
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
import os
from PIL import Image
os.environ ['KMP_DUPLICATE_LIB_OK'] ='True'
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
fmap_block = list()
import torch.nn.functional as F
grad_block = list()
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
torch.manual_seed(1) # 设置随机种子
rmb_label = {"1": 0, "100": 1}
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 2)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool1(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1
checkpoint_interval=5
# ============================ step 1/5 数据 ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
train_dir = os.path.join(split_dir, "train")
train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
])
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
net = Net()
criterion = nn.CrossEntropyLoss() # 选择损失函数
# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
checkpointdict = torch.load('./checkpoint4.pkl')
net.load_state_dict(checkpointdict["model_state_dict"])
optimizer.load_state_dict(checkpointdict["optimizer_state_dict"])
startepoch = checkpointdict["epoch"]
# ============================ step 5/5 训练 ============================
train_curve = list()
iter_count = 0
for epoch in range(startepoch+1,MAX_EPOCH):
loss_mean = 0.
correct = 0.
total = 0.
for counti in range(6):
for i, data in enumerate(train_loader):
if counti <5:
continue
else:
iter_count += 1
# forward
inputs, labels = data
outputs = net(inputs)
# backward
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
# update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0.
# if ((epoch + 1) % checkpoint_interval == 0):
# checkpoint = {"model_state_dict": net.state_dict(),
# "optimizer_state_dict": optimizer.state_dict(),
# "epoch": epoch}
# path_checkpoint = './checkpoint{}.pkl'.format(epoch)
# torch.save(checkpoint, path_checkpoint)
# if ((epoch + 1) % 5 == 0):
# print("退出")
# break
[个人总结]pytorch中用checkpoint设置恢复,在恢复后的acc上升的更多相关文章
- PostgreSQL CheckPoint设置(转)
今天在研究checkpoint process的问题时,顺便复习了一下checkpoint设置问题,又有新的疑惑了. checkpoint又名检查点,在oracle中checkpoint的发生意味着之 ...
- 【Android】设置android:maxLines="1"后,android:imeOptions="actionSearch"失效
android:singleLine在API LEVEL 3已经废弃,可以用android:maxLines="1"代替. 但是测试的时候发现设置android:maxLines= ...
- 解决.Net设置只读、隐藏后后台获取不到值的问题
在前台页面上放了几个textbox,用 ReadOnly=true设置不可编辑,用visible="False"设置不可见 用jquery给textbox赋值后在后台页面获取不到t ...
- 元素设置position:fixed属性后IE下宽度无法100%延伸
元素设置position:fixed属性后IE下宽度无法100%延伸 IE bug 出现条件: 1.div1设置position:fixed属性,并且想要width:100%的效果. 2.div2(下 ...
- c# winform 设置winform进入窗口后在文本框里的默认焦点
c# winform 设置winform进入窗口后在文本框里的默认焦点 进入窗口后默认聚焦到某个文本框,两种方法: ①设置tabindex 把该文本框属性里的tabIndex设为0,焦点就默认在这个文 ...
- windows设置多长时间后自动关机 分类: windows常用小技巧 2014-04-15 09:35 230人阅读 评论(0) 收藏
分二步: 第一步:点击windows键,在"搜索程序和文件"的文本框输入:cmd 第二步:输入:shutdown -s -t (设置电脑一小时后自动关机) 备注: ...
- undo丢失恢复异常恢复,运维DBA反映Oracle数据库无法启动报错ORA-01157 ORA-01110,分析原因为Oracle数据库坏块导致
本文转自 惜纷飞 大师. 模拟基表事务未提交数据库crash,undo丢失恢复异常恢复,运维DBA反映Oracle数据库无法启动报错ORA-01157 ORA-01110,分析原因为Oracle数据库 ...
- 设置vue启动项目后默认显示的页面
通过配置路由,可以设置vue项目启动后默认显示的页面.路由的path设置为path:"/",启动项目后就会显示默认的组件页面. import Vue from 'vue' impo ...
- arcgis的afcore_libfnp.dll经常被360杀毒,删除,请到恢复区恢复
arcgis的afcore_libfnp.dll经常被360杀毒,删除,请到恢复区恢复
随机推荐
- hdu 3549Flow Problem
Problem Description Network flow is a well-known difficult problem for ACMers. Given a graph, your t ...
- Codeforces Round #672 (Div. 2) D. Rescue Nibel! (思维,组合数)
题意:给你\(n\)个区间,从这\(n\)区间中选\(k\)个区间出来,要求这\(k\)个区间都要相交.问共有多少种情况. 题解:如果\(k\)个区间都要相交,最左边的区间和最右边的区间必须要相交,即 ...
- JavaScript——原型
原型中的原先设定的值不能改变!!!
- 二叉树增删改查 && 程序实现
二叉排序树定义 一棵空树,或者是具有下列性质的二叉树:(1)若左子树不空,则左子树上所有结点的值均小于它的根结点的值:(2)若右子树不空,则右子树上所有结点的值均大于它的根结点的值:(3)左.右子树也 ...
- - 迷宫问题 POJ - 3984 bfs记录路径并输出最短路径
定义一个二维数组: int maze[5][5] = { 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, ...
- 如何在windows上升级Powershell到5.1版本?
前言 此篇我们说的是Powershell5.1低版本到5.1的升级,对于Powershell6(及以上版本)可以跨平台独立安装,在windows上可与之前的版本并存. 首先要整清楚Powershell ...
- 📚C#/.NET/.NET Core推荐学习书籍(升职加薪,你值得拥有)
前言: 作为一名程序员,我们无时无刻都要考虑着如何通过不断地学习来提升自己的核心竞争力.古人有云:"书中自有黄金屋,书中只有颜如玉",说明了书籍的重要性,没错工作多年来,发现身边那 ...
- 通过js正则表达式实例学习正则表达式基本语法
正则表达式又叫规则表达式,一般用来检查字符串中是否有与规则相匹配的子串,达到可以对匹配的子串进行提取.删除.替换等操作的目的.先了解有哪些方法可以使用正则对字符串来实现这些操作: RegExpObje ...
- Leetcode(38)-报数
报数序列是指一个整数序列,按照其中的整数的顺序进行报数,得到下一个数.其前五项如下: 1. 1 2. 11 3. 21 4. 1211 5. 111221 1 被读作 "one 1&quo ...
- KafkaBroker 简析
Kafka 依赖 Zookeeper 来维护集群成员的信息: Kafka 使用 Zookeeper 的临时节点来选举 controller Zookeeper 在 broker 加入集群或退出集群时通 ...