pytorch人脸识别——自己制作数据集
这是一篇面向新手的博文:因为本人也是新手,记录一下自己在做这个项目遇到的大大小小的坑。
按照下面的例子写就好了
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision.utils import make_grid
from torch.utils.data import DataLoader transform = T.Compose( #归一化
[
T.Resize(1000),
T.CenterCrop(1000),
T.ToTensor(),
T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
]
) #在这个class里重写你的数据集
class twoface(data.Dataset):
def __init__(self,root,transforms=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root,img) for img in imgs]
self.transforms = transforms def __getitem__(self, index):
img_path = self.imgs[index]
label = 1 if 'empty' in img_path.split('/')[-1] else 0 #定义标签:图片名中有label
data = Image.open(img_path)
if self.transforms:
data = self.transforms(data)
return data,label def __len__(self):
return len(self.imgs) dataset = twoface('./data1/',transforms=transform) #transform 用在这里
data_loader_train = t.utils.data.DataLoader(dataset=dataset,
batch_size=1, #batch_size 要=1
shuffle=True,
num_workers=0) data_loader_test = t.utils.data.DataLoader(dataset=dataset,
batch_size=1,
shuffle=True,
num_workers=0) class Model(t.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = t.nn.Sequential(t.nn.Conv2d(3, 6, kernel_size=20, stride=10, padding=0), #模型这块如果你图片格式不对的话要重新算
t.nn.ReLU(),
t.nn.Conv2d(6, 10, kernel_size=6, stride=1, padding=0),
t.nn.ReLU(),
t.nn.Conv2d(10, 16, kernel_size=5, stride=1, padding=0),
t.nn.ReLU(),
t.nn.MaxPool2d(stride=5, kernel_size=5))
#
self.dense = t.nn.Sequential(t.nn.Linear(18 * 18 * 16, 33),
t.nn.ReLU(),
t.nn.Linear(33, 2)
) def forward(self, x):
x = self.conv1(x)
x = x.view(-1,18 * 18 * 16)
x = self.dense(x)
return x model = Model().cuda()
print(model) cost = t.nn.CrossEntropyLoss()
optimizer = t.optim.Adam(model.parameters())
n_epochs = 5 for epoch in range(n_epochs):
running_loss = 0.0
running_correct = 0
print("Epoch {}/{}".format(epoch, n_epochs))
print("-" * 10) for datas in data_loader_train:
realone,label = datas
realone, label = Variable(realone).cuda(), Variable(label).cuda()
output = model(realone)
_,pred = t.max(output.data,1)
optimizer.zero_grad()
loss = cost(output, label)
loss.backward()
optimizer.step()
running_loss += loss.data[0]
running_correct += t.sum(pred == label) testing_correct = 0
for datak in data_loader_test:
X_test, y_test = datak
X_test, y_test = Variable(X_test).cuda(), Variable(y_test).cuda()
outputs = model(X_test)
_, pred = t.max(outputs.data, 1)
testing_correct += t.sum(pred == y_test.data)
print(
"Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy is:{:.4f}".format(running_loss / len(dataset),
100 * running_correct / len(
dataset),
100 * testing_correct / len(
dataset)))
t.save(model, 'ifempty.pkl')
运行model中的坑
from pylab import plt
import torch as t
from torch.autograd import Variable
from torchvision.utils import make_grid
from PIL import Image
import numpy as np
import os class Model(t.nn.Module): #如果不把这些乱七八糟的class和transform 重写一遍会出错 def __init__(self):
super(Model, self).__init__()
self.conv1 = t.nn.Sequential(t.nn.Conv2d(3, 6, kernel_size=20, stride=10, padding=0),
t.nn.ReLU(),
t.nn.Conv2d(6, 10, kernel_size=6, stride=1, padding=0),
t.nn.ReLU(),
t.nn.Conv2d(10, 16, kernel_size=5, stride=1, padding=0),
t.nn.ReLU(),
t.nn.MaxPool2d(stride=5, kernel_size=5))
#
self.dense = t.nn.Sequential(t.nn.Linear(18 * 18 * 16, 33),
t.nn.ReLU(),
t.nn.Linear(33, 2)
) def forward(self, x):
x = self.conv1(x)
x = x.view(-1,18 * 18 * 16)
x = self.dense(x)
return x lst = os.listdir('./recv/')
flag = 0
for i in range(len(lst)): img=Image.open('./recv/' + lst[i]) from torchvision import transforms as T
trans= T.Compose(
[
T.Resize(1000),
T.CenterCrop(1000),
T.ToTensor(),
T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
]
) a=trans(img)
b = a.numpy()
x = np.array([b])
y = t.Tensor(x)
fix_noise = Variable(y)
fix_noise = fix_noise.cuda()
net = t.load('ifempty.pkl')
output = net(fix_noise)
_,pred = t.max(output.data,1) if pred.data.item() == 1:
print('空')
flag = 1
else:
flag = 0 if flag == 0:
netj = t.load('ifzwh.pkl')
output2 = netj(fix_noise)
_, pred2 = t.max(output2.data, 1)
if pred2.data.item() == 0:
print('xxx')
else:
print('其他人')
pytorch人脸识别——自己制作数据集的更多相关文章
- pytorch 读数据接口 制作数据集 data.dataset
[吐槽] 啊,代码,你这个大猪蹄子 自己写了cifar10的数据接口,跟官方接口load的数据一样, 沾沾自喜,以为自己会写数据接口了 几天之后,突然想,自己的代码为啥有点慢呢,这数据集不大啊 用了官 ...
- pgm revert转换 成jpg 人脸识别图片
最近在搞人脸识别,下载数据集走得比较心累.很多数据集太大了.没有啥标签.先搞一个小的玩玩.还找到的是pgm灰度图.索性写了个小脚本,用来转换.同时写脚本打标签. 数据集地址:http://downlo ...
- olivettifaces数据集实现人脸识别代码
数据集: # -*- coding: utf-8 -*- """ Created on Wed Apr 24 18:21:21 2019 @author: 92958 & ...
- opencv人脸识别提取手机相册内人物充当数据集,身份识别学习(草稿)
未写完 采用C++,opencv+opencv contrib 4.1.0 对手机相册内人物opencv人脸识别,身份识别学习 最近事情多,介绍就先不介绍了 photocut.c #include & ...
- Yale数据库上的人脸识别
一.问题分析 1. 问题描述 在Yale数据集上完成以下工作:在给定的人脸库中,通过算法完成人脸识别,算法需要做到能判断出测试的人脸是否属于给定的数据集.如果属于,需要判断出测试的人脸属于数据集中的哪 ...
- 使用OpenCV和Python进行人脸识别
介绍 人脸识别是什么?或识别是什么?当你看到一个苹果时,你的大脑会立刻告诉你这是一个苹果.在这个过程中,你的大脑告诉你这是一个苹果水果,用简单的语言来说就是识别.那么什么是人脸识别呢?我肯定你猜对了. ...
- Github开源人脸识别项目face_recognition
Github开源人脸识别项目face_recognition 原文:https://www.jianshu.com/p/0b37452be63e 译者注: 本项目face_recognition是一个 ...
- OpenCV 和 Dlib 人脸识别基础
00 环境配置 Anaconda 安装 1 下载 https://repo.anaconda.com/archive/ 考虑到兼容性问题,推荐下载Anaconda3-5.2.0版本. 2 安装 3 测 ...
- Opencv摄像头实时人脸识别
Introduction 网上存在很多人脸识别的文章,这篇文章是我的一个作业,重在通过摄像头实时采集人脸信息,进行人脸检测和人脸识别,并将识别结果显示在左上角. 利用 OpenCV 实现一个实时的人脸 ...
随机推荐
- SNMP学习笔记之SNMP报文协议详解
0x00 简介 简单网络管理协议(SNMP)是TCP/IP协议簇的一个应用层协议.在1988年被制定,并被Internet体系结构委员会(IAB)采纳作为一个短期的网络管理解决方案:由于SNMP的简单 ...
- Jquery 给Js动态新添加的元素 绑定的点击事件
//one $('.class').on("click",function(){ alert('one') }); //相当于$('.class').bind("clic ...
- JQuery判断input是否被禁用
<script src="jquery.min.js"></script> <br/><input type="text&quo ...
- 05: MySQL高级查询
MySQL其他篇 目录: 参考网站 1.1 GROUP BY分组使用 1.2 mysql中NOW(),CURDATE(),CURTIME()的使用 1.3 DATEDIFF() 函数 1.4 DATE ...
- CentOS7.3防火墙firewalld简单配置
今天安装了centos7.3, 想用iptables的save功能保存规则的时候发现跟rhel不一样了, 后来度娘说centos用的是firewalld而不是iptables了, 平时工作都是用re ...
- STM32.SPI(25Q16)
1.首先认识下W25Q16DVSIG, SOP8 SPI FLASH 16MBIT 2MB(4096个字节) (里面可以放字库,图片,也可以程序掉电不丢失数据放里面) 例程讲解: ① 1.用到SPI ...
- Python3基础 file with 配合文件操作
Python : 3.7.0 OS : Ubuntu 18.04.1 LTS IDE : PyCharm 2018.2.4 Conda ...
- Java8的新特性,二进制序列转十进制数字
package kata_007_二进制序列转十进制int; /** * java8 Lambda表达式转换binary序列->十进制数 */ import java.util.ArrayLis ...
- 51nod 1412 AVL树的种类(经典dp)
http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1412 题意: 思路: 经典dp!!!可惜我想不到!! $dp[i][k] ...
- HDU 2955 Robberies(0-1背包)
http://acm.hdu.edu.cn/showproblem.php?pid=2955 题意:一个抢劫犯要去抢劫银行,给出了几家银行的资金和被抓概率,要求在被抓概率不大于给出的被抓概率的情况下, ...