# -*- coding:utf-8 -*-

import os
import numpy as np
import torch
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
import os
from PIL import Image
os.environ ['KMP_DUPLICATE_LIB_OK'] ='True'
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
fmap_block = list()
grad_block = list()
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) torch.manual_seed(1) # 设置随机种子
rmb_label = {"1": 0, "100": 1} # 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1 output_dir = os.path.join(BASE_DIR, "..", "..", "Result", "backward_hook_cam") fmap_block = list()
input_block = list()
# ============================ step 1/5 数据 ============================ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
train_dir = os.path.join(split_dir, "train")
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225] def backward_hook(module, grad_in, grad_out):
grad_block.append(grad_out[0].detach()) def farward_hook(module, input, output):
fmap_block.append(output) def show_cam_on_image(img, mask, out_dir):
heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
cam = heatmap + np.float32(img)
cam = cam / np.max(cam)
path_cam_img = os.path.join(out_dir, "cam1.jpg")
path_raw_img = os.path.join(out_dir, "raw1.jpg")
if not os.path.exists(out_dir):
os.makedirs(out_dir)
print(cam)
cv2.imwrite(path_cam_img, np.uint8(255 * cam))
cv2.imwrite(path_raw_img, np.uint8(255 * img)) def comp_class_vec(ouput_vec, index=None):
"""
计算类向量
:param ouput_vec: tensor
:param index: int,指定类别
:return: tensor
"""
if not index:
index = np.argmax(ouput_vec.cpu().data.numpy())
else:
index = np.array(index)
index = index[np.newaxis, np.newaxis]
index = torch.from_numpy(index)
one_hot = torch.zeros(1, 2).scatter_(1, index, 1)
one_hot.requires_grad = True
class_vec = torch.sum(one_hot * outputx) # one_hot = 11.8605
return class_vec def gen_cam(feature_map, grads):
"""
依据梯度和特征图,生成cam
:param feature_map: np.array, in [C, H, W]
:param grads: np.array, in [C, H, W]
:return: np.array, [H, W]
"""
cam = np.zeros(feature_map.shape[1:], dtype=np.float32) # cam shape (H, W) weights = np.mean(grads, axis=(1, 2)) # for i, w in enumerate(weights):
cam += w * feature_map[i, :, :] cam = np.maximum(cam, 0)
cam = cv2.resize(cam, (64, 64))
cam -= np.min(cam)
cam /= np.max(cam) return cam train_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.RandomCrop(64, padding=4),
transforms.RandomGrayscale(p=0.8),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) valid_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
]) # 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform) # 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # ============================ step 2/5 模型 ============================ net = LeNet(classes=2)
net.initialize_weights() # ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss() # 选择损失函数 # ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 设置学习率下降策略 # ============================ step 5/5 训练 ============================
train_curve = list() iter_count = 0 for epoch in range(MAX_EPOCH):
fmap_dict = dict() loss_mean = 0.
correct = 0.
total = 0.
net.train()
for i, data in enumerate(train_loader):
iter_count += 1
# forward
inputs, labels = data
outputs = net(inputs)
# backward optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward() # update weights
optimizer.step()
# 统计分类情况
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).squeeze().sum().numpy()
# 打印训练信息
loss_mean += loss.item()
train_curve.append(loss.item())
if (i+1) % log_interval == 0:
loss_mean = loss_mean / log_interval
print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
loss_mean = 0. scheduler.step() # 更新学习率
img = cv2.imread('100.jpg', 1) # H*W*C
x = Image.open('100.jpg').convert('RGB')
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
valid_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
x = valid_transform(x)
x.unsqueeze_(0) net.conv2.register_forward_hook(farward_hook)
net.conv2.register_backward_hook(backward_hook)
outputx = net(x)
net.zero_grad()
class_loss = comp_class_vec(outputx)
class_loss.backward()
grads_val = grad_block[0].cpu().data.numpy().squeeze()
fmap = fmap_block[0].cpu().data.numpy().squeeze()
cam = gen_cam(fmap, grads_val)
img_show = np.float32(cv2.resize(img, (64, 64))) / 255
show_cam_on_image(img_show, cam, output_dir)



[个人总结]利用grad-cam实现人民币分类的更多相关文章

  1. 机器学习实战 - 读书笔记(07) - 利用AdaBoost元算法提高分类性能

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习笔记,这次是第7章 - 利用AdaBoost元算法提高分类性能. 核心思想 在使用某个特定的算法是, ...

  2. 【转载】 机器学习实战 - 读书笔记(07) - 利用AdaBoost元算法提高分类性能

    原文地址: https://www.cnblogs.com/steven-yang/p/5686473.html ------------------------------------------- ...

  3. NLP(二十二)利用ALBERT实现文本二分类

      在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...

  4. 利用RNN进行中文文本分类(数据集是复旦中文语料)

    利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) 1.训练词向量 数据预处理参考利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) ,现在我们有了分词 ...

  5. 利用CNN进行中文文本分类(数据集是复旦中文语料)

    利用TfidfVectorizer进行中文文本分类(数据集是复旦中文语料) 利用RNN进行中文文本分类(数据集是复旦中文语料) 上一节我们利用了RNN(GRU)对中文文本进行了分类,本节我们将继续使用 ...

  6. 利用AdaBoost元算法提高分类性能

    当做重要决定时,大家可能都会吸取多个专家而不只是一个人的意见.机器学习处理问题时又何尝不是如此?这就是元算法背后的思路.元算法是对其他算法进行组合的一种方式. 自举汇聚法(bootstrap aggr ...

  7. 【Python与机器学习】:利用Keras进行多类分类

    多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...

  8. 利用Spark-mllab进行聚类,分类,回归分析的代码实现(python)

    Spark作为一种开源集群计算环境,具有分布式的快速数据处理能力.而Spark中的Mllib定义了各种各样用于机器学习的数据结构以及算法.Python具有Spark的API.需要注意的是,Spark中 ...

  9. 利用logistic回归解决多分类问题

    利用logistic回归解决手写数字识别问题,数据集私聊. from scipy.io import loadmat import numpy as np import pandas as pd im ...

随机推荐

  1. 【uva 12174】Shuffle(算法效率--滑动窗口)

    题意:假设一种音乐播放器有一个乱序的功能,设定每播放S首歌为一个周期,随机播放编号为1~S的歌曲.现在给一个长度为N的部分播放记录,请统计下次随机排序所发生的时间的可能性种数.(1≤S,N≤10000 ...

  2. 大数据去重(data deduplication)方案

    数据去重(data deduplication)是大数据领域司空见惯的问题了.除了统计UV等传统用法之外,去重的意义更在于消除不可靠数据源产生的脏数据--即重复上报数据或重复投递数据的影响,使计算产生 ...

  3. OpenStack-知识点补充

    登录计算节点查看进程 [root@compute ~]# ps aux | grep kvm root 824 0.0 0.0 0 0 ? S< 10:19 0:00 [kvm-irqfd-cl ...

  4. LeetCode6 Z字形排列

    题目描述是从上到下,从左到右Z字形排列. 找规律.这种形式一般都是mod x 余数有规律.然后写的时候围绕x构造,而非判断,代码会简单一些. 设行数为r 先观察r=5的情况 发现第0行的字符原始ind ...

  5. Creative Commons : CC (知识共享署名 授权许可)

    1 https://creativecommons.org/      Keep the internet creative, free and open. Creative Commons help ...

  6. taro ENV & NODE_ENV & process.env

    taro ENV & NODE_ENV & process.env https://github.com/NervJS/taro-ui/blob/dev/src/common/util ...

  7. c++ 获取和设置 窗口标题

    有些窗口标题中含有中文,可能识别不出来,可以改一下 SetWindowTextW GetWindowTextW 设置: SetWindowTextW( (HWND)0x00370868/*窗口句柄*/ ...

  8. 10月份上线的NGK有什么不同之处?

    近日,有小道消息传出公链项目NGK即将在10月上线的消息.各大社区纷纷开始布局,市场中关于NGK项目的消息也变得更多了起来.仅是社区热度这一点,对比之下就已经优于很多项目,那么是否还有其他优势呢?让我 ...

  9. go好用的类型转换第三方组件

    Cast介绍 开源地址 https://github.com/spf13/cast Cast是什么? Cast是一个库,以一致和简单的方式在不同的go类型之间转换. Cast提供了简单的函数,可以轻松 ...

  10. Java 添加 、读取以及删除PPT幻灯片中的视频、音频文件

    在PPT中,可以操作很多种元素,如形状.图形.文字.图片.表格等,也可以插入视频或者音频文件,来丰富幻灯片的内容呈现方式.下面将介绍在Java程序中如何来添加视频.音频文件到PPT幻灯片,读取和删除幻 ...