import tensorflow as tf

def initialize_uninitialized(sess):
global_vars = tf.global_variables()
is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f] print [str(i.name) for i in not_initialized_vars] # only for testing
if len(not_initialized_vars):
sess.run(tf.variables_initializer(not_initialized_vars))

上述代码是用于初始化剩余未被初始化的变量的函数

需要注意的是,我们一般采用tf.global_variables_initializer()作为初始化op会覆盖原来通过saver.restore()方式加载的变量状态,因此,不可采用此方法。

另外,如果采用sess.run(tf.global_variables_initializer())在 saver.restore()之前,是不起作用的,原因未知,restore函数似乎能屏蔽掉global_variables_initializer()

的初始化效果。

选择性加载变量时可以采用scope进行隔离,提取出name:var这样的键值对的字典作为saver的加载根据。如下代码:

# stage_merged.py
# transform from single frame into multi-frame enhanced single raw
from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import numpy as np
import rawpy
import glob
from model_sid_latest import network_stages_merged, network_my_unet, network_enhance_raw
import platform
from PIL import Image if platform.system() == 'Windows':
data_dir = 'D:/data/Sony/dataset/bbf-raw-selected/'
elif platform.system() == 'Linux':
data_dir = './dataset/bbf-raw-selected/'
else:
print('platform not supported!')
assert False os.environ["CUDA_VISIBLE_DEVICES"] = ""
checkpoint_dir = './model_stage_merged/'
result_dir = './out_stage_merged/'
log_dir = './log_stage_merged/'
learning_rate = 1e-4
epoch_bound = 20000
save_model_every_n_epoch = 10 if platform.system() == 'Windows':
output_every_n_steps = 1
else:
output_every_n_steps = 100 if platform.system() == 'Windows':
ckpt_enhance_raw = 'D:/model/enhance_raw/'
ckpt_raw2rgb = 'D:/model/raw2rgb-c1/'
else:
ckpt_enhance_raw = './model/enhance_raw/'
ckpt_raw2rgb = './model/raw2rgb-c1/' # BBF100-2
bbf_w = 4032
bbf_h = 3024 patch_w = 512
patch_h = 512 max_level = 1023
black_level = 64 patch_w = 512
patch_h = 512 # set up dataset
input_files = glob.glob(data_dir + '/*.dng')
input_files.sort() def preprocess(raw, bl, wl):
im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - bl, 0)
return im / (wl - bl) def pack_raw_bbf(path):
raw = rawpy.imread(path)
bl = 64
wl = 1023
im = preprocess(raw, bl, wl)
im = np.expand_dims(im, axis=2)
H = im.shape[0]
W = im.shape[1]
if raw.raw_pattern[0, 0] == 0: # CFA=RGGB
out = np.concatenate((im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 2: # BGGR
out = np.concatenate((im[1:H:2, 1:W:2, :],
im[0:H:2, 1:W:2, :],
im[0:H:2, 0:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 0: # GRBG
out = np.concatenate((im[0:H:2, 1:W:2, :],
im[0:H:2, 0:W:2, :],
im[1:H:2, 0:W:2, :],
im[1:H:2, 1:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 2: # GBRG
out = np.concatenate((im[1:H:2, 0:W:2, :],
im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :]), axis=2)
else:
assert False
wb = np.array(raw.camera_whitebalance)
wb[3] = wb[1]
wb = wb / wb[1]
out = np.minimum(out * wb, 1.0) h_, w_ = im.shape[0]//2, im.shape[1]//2
out_16bit_ = np.zeros([h_, w_, 4], dtype=np.uint16)
out_16bit_[:, :, :] = np.uint16(out[:, :, :] * (wl - bl))
del out
return out_16bit_ tf.reset_default_graph()
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input') with tf.variable_scope('enhance_raw', reuse=tf.AUTO_REUSE):
enhanced_raw = network_enhance_raw(in_im, patch_h, patch_w)
with tf.variable_scope('raw2rgb', reuse=tf.AUTO_REUSE):
gt_im = network_my_unet(enhanced_raw, patch_h, patch_w)
with tf.variable_scope('stage_merged', reuse=tf.AUTO_REUSE):
out_im = network_stages_merged(in_im, patch_h, patch_w) gt_im_cut = tf.minimum(tf.maximum(gt_im, 0.0), 1.0)
out_im_cut = tf.minimum(tf.maximum(out_im, 0.0), 1.0)
ssim_loss = 1 - tf.image.ssim_multiscale(gt_im_cut[0], out_im_cut[0], 1.0)
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(gt_im_cut - out_im_cut), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(gt_im_cut - out_im_cut), axis=-1))
G_loss = ssim_loss
# G_loss = l1_loss + l2_loss tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('L1 Loss', l1_loss)
tf.summary.scalar('L2 Loss', l2_loss) ########## LOADING MODELS #############
scope_ = 'enhance_raw'
enhance_raw_var_list = tf.global_variables(scope_)
enhance_raw_var_names = [v.name.replace(scope_+'/', '').replace(':0', '') for v in enhance_raw_var_list]
enhance_raw_map = dict()
for i in range(len(enhance_raw_var_names)):
enhance_raw_map[enhance_raw_var_names[i]] = enhance_raw_var_list[i] saver_enhance_raw = tf.train.Saver(var_list=enhance_raw_map)
ckpt = tf.train.get_checkpoint_state(ckpt_enhance_raw)
if ckpt:
saver_enhance_raw.restore(sess, ckpt.model_checkpoint_path)
print('loaded enhance_raw model: ' + ckpt.model_checkpoint_path)
else:
print('Error: failed to load enhance_raw model!')
#----------------------------------------
scope_ = 'raw2rgb'
raw2rgb_var_list = tf.global_variables(scope_)
raw2rgb_var_names = [v.name.replace(scope_+'/', '').replace(':0', '') for v in raw2rgb_var_list]
raw2rgb_map = dict()
for i in range(len(raw2rgb_var_names)):
raw2rgb_map[raw2rgb_var_names[i]] = raw2rgb_var_list[i] saver_raw2rgb = tf.train.Saver(var_list=raw2rgb_map)
ckpt = tf.train.get_checkpoint_state(ckpt_raw2rgb)
if ckpt:
saver_raw2rgb.restore(sess, ckpt.model_checkpoint_path)
print('loaded raw2rgb model: ' + ckpt.model_checkpoint_path)
else:
print('Error: failed to load raw2rgb model!')
assert False
#---------------------------------------- def initialize_uninitialized(sess):
global_vars = tf.global_variables()
bool_inits = sess.run([tf.is_variable_initialized(var) for var in global_vars])
uninit_vars = [v for (v, b) in zip(global_vars, bool_inits) if not b]
for v in uninit_vars:
print(str(v.name))
if len(uninit_vars):
sess.run(tf.variables_initializer(uninit_vars)) t_vars = tf.trainable_variables(scope='stage_merged')
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=t_vars) saver = tf.train.Saver(var_list=t_vars)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
saver.restore(sess, ckpt.model_checkpoint_path)
print('loaded ' + ckpt.model_checkpoint_path)
else:
sess.run(tf.variables_initializer(var_list=t_vars))
initialize_uninitialized(sess)
#######################################
if not os.path.isdir(result_dir):
os.mkdir(result_dir) input_images = [None] * len(input_files)
g_loss = np.zeros([500, 1]) merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph) steps = 0
st = time.time() for epoch in range(0, epoch_bound):
for ind in np.random.permutation(len(input_images)):
steps += 1
if input_images[ind] is None:
input_images[ind] = np.expand_dims(pack_raw_bbf(input_files[ind]), axis=0) # random cropping
xx = np.random.randint(0, bbf_w // 2 - patch_w)
yy = np.random.randint(0, bbf_h // 2 - patch_h)
input_patch = np.float32(input_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (
max_level - black_level) # random flipping
if np.random.randint(2, size=1)[0] == 1: # random flip
input_patch = np.flip(input_patch, axis=1)
if np.random.randint(2, size=1)[0] == 1:
input_patch = np.flip(input_patch, axis=0)
if np.random.randint(2, size=1)[0] == 1: # random transpose
input_patch = np.transpose(input_patch, (0, 2, 1, 3)) summary, _, G_current, output, gt_im_ = sess.run(
[merged, G_opt, G_loss, out_im_cut, gt_im_cut],
feed_dict={
in_im: input_patch,
lr: learning_rate})
g_loss[steps % len(g_loss)] = G_current if steps % output_every_n_steps == 0:
loss_ = np.mean(g_loss[np.where(g_loss)])
cost_ = (time.time() - st) / output_every_n_steps
st = time.time()
print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
writer.add_summary(summary, global_step=steps)
temp = np.concatenate(
(input_patch[0, :, :, :3],
gt_im_[0, 0:patch_h*2:2, 0:patch_w*2:2, :3],
output[0, 0:patch_h*2:2, 0:patch_w*2:2, :3]), axis=1)
scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255) \
.save(result_dir + '/%d_%d.jpg' % (epoch, steps)) # clean up the memory if necessary
if platform.system() == 'Windows':
input_images[ind] = None if epoch % save_model_every_n_epoch == 0:
saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
print('model saved.')

Tensorflow选择性初始化图中的变量的更多相关文章

  1. AI学习---TensorFlow框架介绍[图+会话+张量+变量OP+API]

    TensorFlow的数据流图 TensorFlow的结构分析: 图 + 会话 TensorFlow = 构图阶段(数据与操作的执行步骤被描绘出一个图) + 执行图阶段(使用回话执行构建好的图中操作) ...

  2. Tensorflow替换静态图中的OP

    import tensorflow as tf import collections from tensorflow.core.framework import tensor_shape_pb2 # ...

  3. java初始化过程中成员变量

    package day01; class Base{ int j; //1.j=0 Base(){ add(1); //2.调用子类add()方法 System.out.println(j); //4 ...

  4. 2、Tensorflow中的变量

    2.Tensorflow中的变量注意:tf中使用 变量必须先初始化下面是一个使用变量的TF代码(含注释): # __author__ = "WSX" import tensorfl ...

  5. Tensorflow中的变量

    从初识tf开始,变量这个名词就一直都很重要,因为深度模型往往所要获得的就是通过参数和函数对某一或某些具体事物的抽象表达.而那些未知的数据需要通过学习而获得,在学习的过程中它们不断变化着,最终收敛达到较 ...

  6. TensorFlow中的变量和常量

    1.TensorFlow中的变量和常量介绍 TensorFlow中的变量: import tensorflow as tf state = tf.Variable(0,name='counter') ...

  7. 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量

    1.tf.Variable([[1, 2]])  # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...

  8. tensorflow 保存训练模型ckpt 查看ckpt文件中的变量名和对应值

    TensorFlow 模型保存与恢复 一个快速完整的教程,以保存和恢复Tensorflow模型. 在本教程中,我将会解释: TensorFlow模型是什么样的? 如何保存TensorFlow模型? 如 ...

  9. c++ 类与函数中static变量初始化问题(转)

    首先static变量只有一次初始化,不管在类中还是在函数中..有这样一个函数: void Foo() { ; // initialize std::cout << a; a++; } 里的 ...

随机推荐

  1. redis的 rdb 和 aof 持久化的区别

    aof,rdb是两种 redis持久化的机制.用于crash后,redis的恢复. rdb的特性如下: Code: fork一个进程,遍历hash table,利用copy on write,把整个d ...

  2. OxyPlot Controller OxyPlot控制器

    Default input bindings The default input bindings in the PlotController are: Action Gesture Pan* Rig ...

  3. Java EE开发技术课程第五周(Applet程序组件与AJAX技术)

    1.Applet程序组件 1.1.定义: Applet是采用Java编程语言编写的小应用程序,该程序可以包含在HTML(标准通用标记语言的一个应用)页中,与在页中包含图像的方式大致相同.含有Apple ...

  4. npm install --save 、--save-dev 、-D、-S 的区别

    备注:<=> 意为等价于: 1.npm install <=> npm i --save   <=> -S --save-dev  <=> -D npm ...

  5. Linux环境配置文件的理解

    百度百科: .bashrc这个文件主要保存个人的一些个性化设置,如命令别名.路径等.也即在同一个服务器上,只对某个用户的个性化设置相关. 示例: 编辑# User specific aliases a ...

  6. 基于zigbee协议的空中下载技术(OTA)

    首先镜像服务器的解释: 镜像服务器(Mirror server)与主服务器的服务内容都是一样的,只是放在一个不同的地方,分担主机的负载. 简单来说就是和照镜子似的,能看,但不是原版的.在网上内容完全相 ...

  7. ElasticSearch(八)Elasticsearch-head 连接不上Elasticsearch的原因和解决方案

    在上篇博文里ElasticSearch(七) Elasticsearch在Centos下搭建可视化服务中已经访问到了可视化界面.然后兴奋地进行了数据提交测试,提交啊,刷新啊,就是看不到数据变化,仔细一 ...

  8. Python Iterables Iterators Generators

    container 某些对象包含其它对象的引用,这个包含其它对象引用的对象叫容器.例如list可以包含int对象,或者由其它数据类型(或数据结构)的对象组成一个list. 对其他对象的引用是容器值的一 ...

  9. 课后作业机票,赌骰子游戏,switch的使用实例

    一,课后第三题机票 package com.bd22; import java.util.Scanner; public class AirTicket { public static void ma ...

  10. stm32库函数建工程和使用Keil自带库建工程有没有区别?发现了同样的程序在两种情况下keil自带库可以运行的情况,不知是什么原因

    我使用库函数建的工程(非Keil自带库),为了实现SPI对Si24r1芯片数据的读写,以验证stm32是否可以和si24r1能够正常通信,发现使用库函数建的工程程序不能通过,读出来的数据和写的数据不一 ...