学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)
视觉问题引入深度神经网络后,针对端对端的训练和预测网络,可以看是特征的表达和任务的决策问题(分类,回归等)。当我们自己的训练数据量过小时,往往借助牛人已经预训练好的网络进行特征的提取,然后在后面加上自己特定任务的网络进行调优。目前,ILSVRC比赛(针对1000类的分类问题)所使用数据的训练集126万张图像,验证集5万张,测试集10万张(标注未公布),大家一般使用这个比赛的前几名的网络来搭建自己特定任务的神经网络。
本篇博文主要简单讲述怎么使用TensorFlow调用预训练好的VGG网络,其他的网络(如Alex, ResNet等)也是同样的套路。分为三个部分:第一部分下载网络架构定义以及权重参数,第二部分是如何调用预训练网络中的feature
map,第三部分给出参考资料。注:资料是学习查找整理而得,理解有误的地方,请多多指正~
一、下载网络架构定义以及权重参数
https://github.com/leihe001/tensorflow-vgg
训练和测试网络的定义
https://mega.nz/#!YU1FWJrA!O1ywiCS2IiOlUCtCpI6HTJOMrneN-Qdv3ywQP5poecM VGG16
https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs
VGG19
二、调用预训练网络中的feature map(以VGG16为例)
import inspect
import os
import numpy as np
import tensorflow as tf
import time
VGG_MEAN = [103.939, 116.779, 123.68]
class Vgg16:
def __init__(self, vgg16_npy_path=None):
if vgg16_npy_path is None:
path = inspect.getfile(Vgg16)
path = os.path.abspath(os.path.join(path, os.pardir))
path = os.path.join(path, "vgg16.npy")
vgg16_npy_path = path
print path
# 加载网络权重参数
self.data_dict = np.load(vgg16_npy_path, encoding='latin1').item()
print("npy file loaded")
def build(self, rgb):
"""
load variable from npy to build the VGG
:param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
"""
start_time = time.time()
print("build model started")
rgb_scaled = rgb * 255.0
# Convert RGB to BGR
red, green, blue = tf.split(3, 3, rgb_scaled)
assert red.get_shape().as_list()[1:] == [224, 224, 1]
assert green.get_shape().as_list()[1:] == [224, 224, 1]
assert blue.get_shape().as_list()[1:] == [224, 224, 1]
bgr = tf.concat(3, [
blue - VGG_MEAN[0],
green - VGG_MEAN[1],
red - VGG_MEAN[2],
])
assert bgr.get_shape().as_list()[1:] == [224, 224, 3]
self.conv1_1 = self.conv_layer(bgr, "conv1_1")
self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
self.pool1 = self.max_pool(self.conv1_2, 'pool1')
self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
self.pool2 = self.max_pool(self.conv2_2, 'pool2')
self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
self.pool3 = self.max_pool(self.conv3_3, 'pool3')
self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
self.pool4 = self.max_pool(self.conv4_3, 'pool4')
self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
self.pool5 = self.max_pool(self.conv5_3, 'pool5')
self.fc6 = self.fc_layer(self.pool5, "fc6")
assert self.fc6.get_shape().as_list()[1:] == [4096]
self.relu6 = tf.nn.relu(self.fc6)
self.fc7 = self.fc_layer(self.relu6, "fc7")
self.relu7 = tf.nn.relu(self.fc7)
self.fc8 = self.fc_layer(self.relu7, "fc8")
self.prob = tf.nn.softmax(self.fc8, name="prob")
self.data_dict = None
print("build model finished: %ds" % (time.time() - start_time))
def avg_pool(self, bottom, name):
return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
def max_pool(self, bottom, name):
return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
def conv_layer(self, bottom, name):
with tf.variable_scope(name):
filt = self.get_conv_filter(name)
conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
conv_biases = self.get_bias(name)
bias = tf.nn.bias_add(conv, conv_biases)
relu = tf.nn.relu(bias)
return relu
def fc_layer(self, bottom, name):
with tf.variable_scope(name):
shape = bottom.get_shape().as_list()
dim = 1
for d in shape[1:]:
dim *= d
x = tf.reshape(bottom, [-1, dim])
weights = self.get_fc_weight(name)
biases = self.get_bias(name)
# Fully connected layer. Note that the '+' operation automatically
# broadcasts the biases.
fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
return fc
def get_conv_filter(self, name):
return tf.constant(self.data_dict[name][0], name="filter")
def get_bias(self, name):
return tf.constant(self.data_dict[name][1], name="biases")
def get_fc_weight(self, name):
return tf.constant(self.data_dict[name][0], name="weights")
以上是VGG16网络的定义,假设我们现在输入图像image,打算做分割,那么我们可以使用端对端的全卷积网络进行训练和测试。针对这个任务,我们只需要输出pool5的feature map即可。
#以上你的网络定义,初始化方式,以及数据预处理... vgg = vgg16.Vgg16() vgg.build(image) feature_map = vgg.pool5 mask = yournetwork(feature_map) #以下定义loss,学习率策略,然后train...
三、参考资料
https://github.com/leihe001/tensorflow-vgg
https://github.com/leihe001/tfAlexNet
https://github.com/leihe001/tensorflow-resnet
学习TensorFlow,调用预训练好的网络(Alex, VGG, ResNet etc)的更多相关文章
- TensorFlow 调用预训练好的模型—— Python 实现
1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...
- tensorflow 使用预训练好的模型的一部分参数
vars = tf.global_variables() net_var = [var for var in vars if 'bi-lstm_secondLayer' not in var.name ...
- 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)
转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...
- zz从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史
从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 Bert最近很火,应该是最近最火爆的AI进展,网上的评价很高,那么Bert值得这么高的评价吗?我个人判断是值得.那为什么 ...
- NLP之预训练
内容是结合:https://zhuanlan.zhihu.com/p/49271699 可以直接看原文 预训练一般要从图像处理领域说起:可以先用某个训练集合比如训练集合A或者训练集合B对这个网络进行预 ...
- LUSE: 无监督数据预训练短文本编码模型
LUSE: 无监督数据预训练短文本编码模型 1 前言 本博文本应写之前立的Flag:基于加密技术编译一个自己的Python解释器,经过半个多月尝试已经成功,但考虑到安全性问题就不公开了,有兴趣的朋友私 ...
- 第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)
在介绍这一节之前,需要你对slim模型库有一些基本了解,具体可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用.数据集处理,这一节我们会详细介绍slim模型库下面的一些函数的使用 ...
- 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现
现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...
- 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现
现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...
随机推荐
- No mapping found for HTTP request with URI [/user/login.do] in DispatcherServlet with name 'dispatcher'错误
1.警告的相关信息 七月 24, 2017 3:53:04 下午 org.springframework.web.servlet.DispatcherServlet noHandlerFound警告: ...
- C#使用AutoMapper6.2.2.0进行对象映射
先说说DTO DTO是个什么东东? DTO(Data Transfer Object)就是数据传输对象,说白了就是一个对象,只不过里边全是数据而已. 为什么要用DTO? 1.DTO更注重数据,对领域对 ...
- [LeetCode] Falling Squares 下落的方块
On an infinite number line (x-axis), we drop given squares in the order they are given. The i-th squ ...
- [LeetCode] Add One Row to Tree 二叉树中增加一行
Given the root of a binary tree, then value v and depth d, you need to add a row of nodes with value ...
- 实验吧_NSCTF web200&FALSE(代码审计)
挺简单的一个代码审计,这里只要倒序解密就行了,这里给一下python版的wp import codecs import base64 strs = 'a1zLbgQsCESEIqRLwuQAyMwLy ...
- 空间漫游(SAC大佬的测试)
题目描述由于球哥和巨佬嘉诚交了很多保护费,我们有钱进行一次 d 维空间漫游.d 维空间中有 d 个正交坐标轴,可以用这些坐标轴来描述你在空间中的位置和移动的方向.例如,d = 1 时,空间是一个数轴, ...
- 计蒜客NOIP模拟赛(2) D2T1 劫富济贫
[问题描述] 吕弗·普自小从英国长大,受到骑士精神的影响,吕弗·普的梦想便是成为一位劫富济贫的骑士. 吕弗·普拿到了一份全国富豪的名单(不在名单上的都是穷人),上面写着所有富豪的名字以及他们的总资产, ...
- [SCOI2012]滑雪与时间胶囊
题目描述 a180285非常喜欢滑雪.他来到一座雪山,这里分布着MMM条供滑行的轨道和NNN个轨道之间的交点(同时也是景点),而且每个景点都有一编号iii(1≤i≤N1 \le i \le N1≤i≤ ...
- [bzoj4881][Lydsy2017年5月月赛]线段游戏
来自FallDream的博客,未经允许,请勿转载,谢谢. quailty和tangjz正在玩一个关于线段的游戏.在平面上有n条线段,编号依次为1到n.其中第i条线段的两端点坐标分别为(0,i)和(1, ...
- python (3.5)字符串 持续更新中………………
# 字符串与变量连接输出 name = input("请输入姓名")age = input("请输入年龄")job = input("请输入工作&qu ...