tensorflow 1.0 学习:参数和特征的提取
在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如:
#取出所有参与训练的参数
params=tf.trainable_variables()
print("Trainable variables:------------------------") #循环列出参数
for idx, v in enumerate(params):
print(" param {:3}: {:15} {}".format(idx, str(v.get_shape()), v.name))
这里只能查看参数的shape和name,并没有具体的值。如果要查看参数具体的值的话,必须先初始化,即:
sess=tf.Session()
sess.run(tf.global_variables_initializer())
同理,我们也可以提取图片经过训练后的值。图片经过卷积后变成了特征,要提取这些特征,必须先把图片feed进去。
具体看实例:
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 3 12:07:59 2017 @author: Administrator
""" import tensorflow as tf
from skimage import io,transform
import numpy as np #-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,100,100,3],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_') #第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
inputs=x,
filters=32,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) #第二个卷积层(50->25)
conv2=tf.layers.conv2d(
inputs=pool1,
filters=64,
kernel_size=[5, 5],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) #第三个卷积层(25->12)
conv3=tf.layers.conv2d(
inputs=pool2,
filters=128,
kernel_size=[3, 3],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2) #第四个卷积层(12->6)
conv4=tf.layers.conv2d(
inputs=pool3,
filters=128,
kernel_size=[3, 3],
padding="same",
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2) re1 = tf.reshape(pool4, [-1, 6 * 6 * 128]) #全连接层
dense1 = tf.layers.dense(inputs=re1,
units=1024,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1,
units=512,
activation=tf.nn.relu,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2,
units=5,
activation=None,
kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
kernel_regularizer=tf.nn.l2_loss) #---------------------------网络结束---------------------------
#%%
#取出所有参与训练的参数
params=tf.trainable_variables()
print("Trainable variables:------------------------") #循环列出参数
for idx, v in enumerate(params):
print(" param {:3}: {:15} {}".format(idx, str(v.get_shape()), v.name)) #%%
#读取图片
img=io.imread('d:/cat.jpg')
#resize成100*100
img=transform.resize(img,(100,100))
#三维变四维(100,100,3)-->(1,100,100,3)
img=img[np.newaxis,:,:,:]
img=np.asarray(img,np.float32)
sess=tf.Session()
sess.run(tf.global_variables_initializer()) #提取最后一个全连接层的参数 W和b
W=sess.run(params[26])
b=sess.run(params[27]) #提取第二个全连接层的输出值作为特征
fea=sess.run(dense2,feed_dict={x:img})
最后一条语句就是提取某层的数据输出作为特征。
注意:这个程序并没有经过训练,因此提取出的参数只是初始化的参数。
tensorflow 1.0 学习:参数和特征的提取的更多相关文章
- tensorflow 1.0 学习:用CNN进行图像分类
tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化. 任务:花卉分类 版本:tensorflow 1 ...
- tensorflow 1.0 学习:参数初始化(initializer)
CNN中最重要的就是参数了,包括W,b. 我们训练CNN的最终目的就是得到最好的参数,使得目标函数取得最小值.参数的初始化也同样重要,因此微调受到很多人的重视,那么tf提供了哪些初始化参数的方法呢,我 ...
- tensorflow 1.0 学习:十图详解tensorflow数据读取机制
本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...
- tensorflow 2.0 学习(四)
这次的mnist学习加入了测试集,看看学习的准确率,代码如下 # encoding: utf-8 import tensorflow as tf import matplotlib.pyplot as ...
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)
池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...
- tensorflow 1.0 学习:卷积层
在tf1.0中,对卷积层重新进行了封装,比原来版本的卷积层有了很大的简化. 一.旧版本(1.0以下)的卷积函数:tf.nn.conv2d conv2d( input, filter, strides, ...
- tensorflow 1.0 学习:模型的保存与恢复
将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...
- Tensorflow 2.0 学习资源
我从换了新工作才开始学习使用Tensorflow,感觉实在太难用了,sess和graph对 新手很不友好,各种API混乱不堪,这些在tf2.0都有了重大改变,2.0大量使用keras的 api,初步使 ...
随机推荐
- Springboot & Mybatis 构建restful 服务
Springboot & Mybatis 构建restful 服务一 1 前置条件 jdk测试:java -version maven测试:命令行之行mvn -v eclipse及maven插 ...
- 爬取QQ音乐(讲解爬虫思路)
一.问题描述: 本次爬取的对象是QQmusic,为自己后面做django音乐网站的开发获取一些资源. 二.问题分析: 由于QQmusic和网易音乐的方式差不多,都是讲歌曲信息放入到播放界面播放,在其他 ...
- Atcoder Beginner Contest 115 D Christmas 模拟,递归 B
D - Christmas Time limit : 2sec / Memory limit : 1024MB Score : 400 points Problem Statement In some ...
- 关于echarts生成雷达图的一些参数介绍
export const industryFactorOption = { title: { text: '雷达图', textStyle: { color: 'rgba(221,221,221,1) ...
- java集成memcached、redis防止缓存穿透
下载相关jar,安装Memcached,安装教程:http://www.runoob.com/memcached/memcached-install.html spring配置memcached &l ...
- ext__给grid Panel设置绑定事件
使用面板来展示详情信息 1.创建一个面板 (双击添加) 2.给该面板设置itemid的值为:detailPanel 3.给面板设置模板 4.添加下面的内容 id:{id}</br> nam ...
- flask-cookie & session
Cookie @app.route('/') def hello_world(): name=request.cookies.get('Name') # 获取cookie resp = Respon ...
- centos jdk 配置及版本切换
一. 环境变量: /etc/profile JAVA_HOME=/usr/lib/jdk1.8.0_91JRE_HOME=/usr/lib/jdk1.8.0_91/jreCLASS_PATH=.:$J ...
- JavaScript自定义鼠标右键菜单
下面为JavaScript代码 window.onload = function () { //好友列表 var f = 0; //判断指定id的元素在页面中是否存在 if (document.get ...
- navibar记录
@import (reference) "kmc-common.less"; .kmc{ font-family: PingFangSC-Reguxlar; font-weight ...