tensorflow 1.0 学习:参数和特征的提取
在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如:
#取出所有参与训练的参数
params=tf.trainable_variables()
print("Trainable variables:------------------------") #循环列出参数
for idx, v in enumerate(params):
print(" param {:3}: {:15} {}".format(idx, str(v.get_shape()), v.name))
这里只能查看参数的shape和name,并没有具体的值。如果要查看参数具体的值的话,必须先初始化,即:
sess=tf.Session()
sess.run(tf.global_variables_initializer())
同理,我们也可以提取图片经过训练后的值。图片经过卷积后变成了特征,要提取这些特征,必须先把图片feed进去。
具体看实例:
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 3 12:07:59 2017 @author: Administrator
""" import tensorflow as tf
from skimage import io,transform
import numpy as np #-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,100,100,3],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_') #第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
inputs=x,
filters=32,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) #第二个卷积层(50->25)
conv2=tf.layers.conv2d(
inputs=pool1,
filters=64,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) #第三个卷积层(25->12)
conv3=tf.layers.conv2d(
inputs=pool2,
filters=128,
kernel_size=[3, 3],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2) #第四个卷积层(12->6)
conv4=tf.layers.conv2d(
inputs=pool3,
filters=128,
kernel_size=[3, 3],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2) re1 = tf.reshape(pool4, [-1, 6 * 6 * 128]) #全连接层
dense1 = tf.layers.dense(inputs=re1,
units=1024,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1,
units=512,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2,
units=5,
activation=None,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss) #---------------------------网络结束---------------------------
#%%
#取出所有参与训练的参数
params=tf.trainable_variables()
print("Trainable variables:------------------------") #循环列出参数
for idx, v in enumerate(params):
print(" param {:3}: {:15} {}".format(idx, str(v.get_shape()), v.name)) #%%
#读取图片
img=io.imread('d:/cat.jpg')
#resize成100*100
img=transform.resize(img,(100,100))
#三维变四维(100,100,3)-->(1,100,100,3)
img=img[np.newaxis,:,:,:]
img=np.asarray(img,np.float32)
sess=tf.Session()
sess.run(tf.global_variables_initializer()) #提取最后一个全连接层的参数 W和b
W=sess.run(params[26])
b=sess.run(params[27]) #提取第二个全连接层的输出值作为特征
fea=sess.run(dense2,feed_dict={x:img})
最后一条语句就是提取某层的数据输出作为特征。
注意:这个程序并没有经过训练,因此提取出的参数只是初始化的参数。
tensorflow 1.0 学习:参数和特征的提取的更多相关文章
- tensorflow 1.0 学习:用CNN进行图像分类
tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化. 任务:花卉分类 版本:tensorflow 1 ...
- tensorflow 1.0 学习:参数初始化(initializer)
CNN中最重要的就是参数了,包括W,b. 我们训练CNN的最终目的就是得到最好的参数,使得目标函数取得最小值.参数的初始化也同样重要,因此微调受到很多人的重视,那么tf提供了哪些初始化参数的方法呢,我 ...
- tensorflow 1.0 学习:十图详解tensorflow数据读取机制
本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...
- tensorflow 2.0 学习(四)
这次的mnist学习加入了测试集,看看学习的准确率,代码如下 # encoding: utf-8 import tensorflow as tf import matplotlib.pyplot as ...
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)
池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...
- tensorflow 1.0 学习:卷积层
在tf1.0中,对卷积层重新进行了封装,比原来版本的卷积层有了很大的简化. 一.旧版本(1.0以下)的卷积函数:tf.nn.conv2d conv2d( input, filter, strides, ...
- tensorflow 1.0 学习:模型的保存与恢复
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- Tensorflow 2.0 学习资源
我从换了新工作才开始学习使用Tensorflow,感觉实在太难用了,sess和graph对 新手很不友好,各种API混乱不堪,这些在tf2.0都有了重大改变,2.0大量使用keras的 api,初步使 ...
随机推荐
- echarts使用踩坑实录之气泡图
最近想做一个统计文章点击率,评论率和点赞率的功能,听说echarts可以轻易完成它,于是我就选择使用echarts,考虑到我做的模块上文章是没有分类的,所以我的统计是基于一个个点,这一看嘛,感觉散点图 ...
- Java:ConcurrentLinkedQueue的实现原理分析
本文是作者原创,首发于InfoQ:http://www.infoq.com/cn/articles/ConcurrentLinkedQueue 1. 引言 在并发编程中我们有时候需要使用线程安全 ...
- 编写shell脚本kill掉占用cpu超过90%以上的程序
由于集群用户经常会不懂如何提交作业,将作业直接运行到登录节点上,这样导致登录节点的cpu及内存占用很大,导致其他用户甚至无法登录.所以就想到了一种解决方法,写一个shell脚本,常驻登录节点,监控cp ...
- 使用dockerfile,创建gitblit镜像
1. 快速使用gitblit镜像 1.1 push 镜像 # docker pull /gitblit 1.2 查看下载的镜像 # docker images | grep "gitblit ...
- MySql技术内幕之MySQL入门(1)
目录 MySql技术内幕之MySQL入门(1) 安装 关于注释 执行SQL语句 关于命令大小写 创建数据库 查看表的信息 查看更加详细的信息 查看与给定模式相匹配的列 插入数据 利用insert添加行 ...
- SSM_CRUD新手练习(10)返回分页的JSON数据
我们完成了员工的分页查询,但是现在这种做法只能适应浏览器和服务器的交互模式,但在移动互联网时代,客户端不仅仅只有浏览器,还有安卓和IOS客户端.我们的解决方式是AJAX+JSON方式来实现平台无关性. ...
- Java变成思想--多线程
Executor :线程池 CatchedThreadPool:创建与所需数量相同的线程,在回收旧线程是停止创建新县城. FixedThreadPool:创建一定数量的线程,所有任务公用这些线程. S ...
- Koa 学习笔记
开始 就像官网上说的,一切框架都从一个"Hello World"开始,首先我们新建一个 package.json,内容尽量简单: { "name": " ...
- day23_雷神_crm-day2
# 俺滴第一个项目 CRM MdelForm 实现增删改查 1. ModelForm,重写 __init__ 方法,给所有字段添加 form-control 样式. 2. ModelForm,报错错误 ...
- arp脚本
1.什么是arp?arp可以解决什么问题? ARP:是地址解析协议 arp解决我们知道一个机器(主机或者路由器)的IP地址,需要找出其相应的硬件地址 2.编写ARP脚本,抓取对应主机的mac地址 1 ...