最近被保研的事情搞的头大,拖了半天才勉强算结束这个了。从熟悉unbantu 16.04的环境(搭个翻墙的梯子都搞了一上午 呸!)到搭建python,pytorch环境。然后花了一个上午熟悉py的基本语法就开始强撸了,具体的过程等保研结束了再补吧,贴个代码意思一下先。

数据集用的是清洗过的MS-Celeb-1M(em...怎么清洗的之后再补吧)

python用的不熟,踩了很多坑,用pytorch的时候也是,打死不看python中文的官方文档(http://pytorch-cn.readthedocs.io/zh/latest/package_references/functional/ 真香)

后续的慢慢补吧很多细节也在完善....


import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image
from torch.optim import lr_scheduler
# ------------------ ready for the dataset ------------------
transform = transforms.Compose([
transforms.Scale(227),
transforms.CenterCrop(227),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]) class MyDateset (data.Dataset): def __init__(self,data_txt,transform):
imgs = []
with open(data_txt , 'r') as f:
for line in f:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
labelList = int(words[1])
imageList = words[0]
imgs.append((imageList, labelList)) self.transform = transform
self.imgs = imgs def __getitem__(self, index):
image_dir,target = self.imgs[index]
image = Image.open(image_dir)
image = transform(image) return image, target def __len__(self):
return len(self.imgs) train_data = MyDateset("/home/fuckman/FaceImage/Traindata.txt", transform)
train_loader = DataLoader(train_data,batch_size = 128 ,shuffle=True,num_workers= 8,drop_last=False) # for img,label in train_data:
# print(img.size(),label) text_data = MyDateset("/home/fuckman/FaceImage/Testdata.txt", transform)
test_loader = DataLoader(dataset=text_data, batch_size = 128 ,shuffle=False, num_workers= 8, drop_last=False) # print(train_data.__len__()) # --------------- creat the net and train -------------------- class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(3, 96, 11, 4, 0),
torch.nn.ReLU(),
torch.nn.MaxPool2d(3, 2)
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(96, 256, 5, 1, 2),
torch.nn.ReLU(),
torch.nn.MaxPool2d(3, 2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(256, 384, 3, 1, 1),
torch.nn.ReLU(),
)
self.conv4 = torch.nn.Sequential(
torch.nn.Conv2d(384, 384, 3, 1, 1),
torch.nn.ReLU(),
)
self.conv5 = torch.nn.Sequential(
torch.nn.Conv2d(384, 256, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(3, 2)
)
self.dense = torch.nn.Sequential(
torch.nn.Dropout(0.5),
torch.nn.Linear(9216, 4096),
torch.nn.ReLU(),
torch.nn.Dropout(0.5),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(),
torch.nn.Linear(4096,1000)
) def forward(self, x): conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
conv4_out = self.conv4(conv3_out)
conv5_out = self.conv5(conv4_out)
res = conv5_out.view(conv5_out.size(0), -1)
out = self.dense(res)
return out alexnet = Net()
alexnet.load_state_dict(torch.load('net_params.pkl'))
alexnet.cuda() # print( alexnet ) #----------------- training ---------------- # crossentryopyloss
criterion = nn.CrossEntropyLoss() # SGD with momentum
optimizer = optim.SGD(alexnet.parameters(),lr = 0.01, momentum = 0.9) # learning rate decay
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10,60], gamma=0.1) # training
for epoch in range(100):
scheduler.step()
running_loss = 0.0
for i, data in enumerate(train_loader,0):
inputs, labels = data
# print(inputs.size())
# labels have to be longTensor
# inputs, labels = Variable(inputs),Variable(labels).long()
inputs, labels = Variable(inputs.cuda()),Variable(labels.cuda()).long() optimizer.zero_grad()
# inputs should be n * c * w * h n refer to the mini-batch c refer to the num of channel
outputs = alexnet(inputs)
# print(outputs)
criterion.cuda() # outputs should be N*C labels should be N N refer to the mini-batch c refer to the num of class
# print the size of outputs and labels may help you find the question
loss = criterion(outputs, labels)
loss.backward()
optimizer.step() running_loss += loss.data[0]
if i % 100 == 99:
print('[%d, %5d] loss : %.3f' %(epoch+1,i+1,running_loss / 100))
running_loss = 0.0
if epoch % 10 == 9:
torch.save(alexnet.state_dict(), 'net_params.pkl')
print("success") print("Finished Training")
torch.save(alexnet.state_dict(), 'net_params.pkl') # ----------------- Test ------------------- correct =0
total = 0
for i, data in enumerate(test_loader,0):
images, labels = data
labels = labels.cuda()
# outputs = alexnet(Variable(images))
outputs = alexnet(Variable(images.cuda()))
_, predicted = torch.max(outputs.data, 1) # max_value and the index of max
total += labels.size(0)
new_label = labels.int()
#print(predicted)
#print(labels)
new_predic = predicted.int()
correct += (new_predic == new_label).sum() print('Accuracy of the network on the 1000 test images: %d %%' % (100 * correct / total))
 

unbuntu 16.04 MS-Celeb-1M + alexnet + pytorch的更多相关文章

  1. Unbuntu 16.04 英文环境安装中文输入法

    ubuntu 16.04 使用的是ibus输入系统,没有预装中文输入法,你要自己安装一下.以中文拼音输入法为例:1.sudo apt install ibus-pinyin2.sudo apt ins ...

  2. 如何在unbuntu 16.04上离线部署openssh

    背景:由于部署环境不能联网,为了方便文件传输,需要用到openssh.故实施步骤是,先在可以联网机器上下载离线包,然后用U盘拷贝到部署环境中. 第一步:下载离线包,下载网址:https://packa ...

  3. unbuntu 16.04.2 安装 Eclipse C++开发环境

    1.安装JAVA (1)首先添加源: sudo gedit /etc/apt/sources.list 在打开的文件中添加如下内容并保存: deb http://ppa.launchpad.net/w ...

  4. 如何在unbuntu 16.04上在线安装vsftpd

    本文涉及命令如下: # service vsftpd status //查询vsftp服务状态 # apt-get remove vsftpd //卸载vsftpd # apt-get install ...

  5. Ubuntu 16.04上源码编译和安装pytorch教程,并编写C++ Demo CMakeLists.txt | tutorial to compile and use pytorch on ubuntu 16.04

    本文首发于个人博客https://kezunlin.me/post/54e7a3d8/,欢迎阅读最新内容! tutorial to compile and use pytorch on ubuntu ...

  6. 【转】Ubuntu 16.04安装配置TensorFlow GPU版本

    之前摸爬滚打总是各种坑,今天参考这篇文章终于解决了,甚是鸡冻\(≧▽≦)/,电脑不知道怎么的,安装不了16.04,就安装15.10再升级到16.04 requirements: Ubuntu 16.0 ...

  7. Solution: Win 10 和 Ubuntu 16.04 LTS双系统, Win 10 不能从grub启动

    今年2月份在一台装了Windows的机器上装了Unbuntu 14.04 LTS (双系统, dual-boot, 现已升级到 16.04 LTS). 然而开机时要从grub启动 Windows (选 ...

  8. [转]Ubuntu 16.04建议安装

    Ubuntu 16.04发布了,带来了很多新特性,同样也依然带着很多不习惯的东西,所以装完系统后还要进行一系列的优化. 1.删除libreoffice libreoffice虽然是开源的,但是Java ...

  9. 安装Ubuntu 16.04后要做的事

    Ubuntu 16.04发布了,带来了很多新特性,同样也依然带着很多不习惯的东西,所以装完系统后还要进行一系列的优化. 1.删除libreoffice libreoffice虽然是开源的,但是Java ...

随机推荐

  1. 计蒜客——Reversion Count

    Reversion Count 解析:题目数字的长度最大为 99,因此使用字符串处理,那么必然这些数存在某些规律.对于数 a (XYZW) 和数 b (WZYX),原式 = (1000X + 100Y ...

  2. 框架 get 请求乱码

    解决方案: 在 tomcat 配置文件中添加 URIEncoding="utf-8"

  3. MAC 隐藏功能

    finder 类: shift+ cmd + G  (去指定路径) cmd+↑ (返回) cmd+↓(打开当前选中的文件,如果没有选中的则去选中第一个) cmd+ o (打开当前选中的文件) 以下这些 ...

  4. ubuntu之路——day13 只用python的numpy在较为底层的阶段实现单隐含层神经网络

    首先感谢这位博主整理的Andrew Ng的deeplearning.ai的相关作业:https://blog.csdn.net/u013733326/article/details/79827273 ...

  5. mysql增删改查sql语句

    未经允许,禁止转载!!!未经允许,禁止转载!!! 创建表   create table 表名删除表    drop table 表名修改表名   rename table 旧表名 to 新表名字创建数 ...

  6. 针对nginx,来具体聊聊正向代理与反向代理 (转载)

    https://www.sohu.com/a/235704408_468627 先来说说什么是代理服务器? 所谓代理服务器就是位于发起请求的客户端与原始服务器端之间的一台跳板服务器,正向代理可以隐藏客 ...

  7. 接口测试中模拟post四种请求数据

    https://www.jianshu.com/p/3b6d7aa2043a 一.背景介绍 在日常的接口测试工作中,模拟接口请求通常有两种方法,fiddler模拟和HttpClient模拟. Fidd ...

  8. [转]详解Linux(centos7)下安装OpenSSL安装图文方法

    OpenSSL是一个开源的ssl技术,由于我需要使用php相关功能,需要获取https的文件所以必须安装这个东西了,下面我整理了两种关于OpenSSL安装配置方法. 安装环境:  操作系统:CentO ...

  9. Shell获取字符串长度的多种方法总结

    摘自:https://www.jb51.net/article/121290.htm 前言 我们在日常工作中,对于求字符串操作在shell脚本中很常用,实现的方法有很多种,下面就来给大家归纳.汇总了求 ...

  10. (CSDN迁移) 输入一个链表,从尾到头打印链表每个节点的值

    题目描述 输入一个链表,从尾到头打印链表每个节点的值. 思路1. 翻转链表,使用java自带的翻转函数或者从头到尾依次改变链表的节点指针 /** * public class ListNode { * ...