采用Tensorflow内部函数直接对模型进行冻结
# enhance_raw.py
# transform from single frame into multi-frame enhanced single raw
from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import numpy as np
import rawpy
import glob
from model_sid_latest import network_enhance_raw
import platform
import os from tensorflow.python.tools import freeze_graph os.environ["CUDA_VISIBLE_DEVICES"] = "" if platform.system() == 'Windows':
data_dir = 'D:/data/LightOnOff/'
elif platform.system() == 'Linux':
data_dir = './dataset/LightOnOff/'
else:
print('platform not supported!')
assert False checkpoint_dir = './model_light_on_off/'
result_dir = './out_light_on_off/'
log_dir = './log_light_on_off/'
learning_rate = 1e-4
save_model_every_n_epoch = 10
max_epoch = 20000
if platform.system() == 'Windows':
save_output_every_n_steps = 1
else:
save_output_every_n_steps = 100 # BBF100-2
bbf_w = 4032
bbf_h = 3024 patch_h = 512
patch_w = 512 patch_h = 800
patch_w = 1024 max_level = 1023
black_level = 64 tf.reset_default_graph() # set up dataset
train_ids = os.listdir(data_dir)
train_ids.sort() def preprocess(raw, bl, wl):
im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - bl, 0)
return im / (wl - bl) def pack_raw_bbf(path):
raw = rawpy.imread(path)
bl = 64
wl = 1023
im = preprocess(raw, bl, wl)
im = np.expand_dims(im, axis=2)
H = im.shape[0]
W = im.shape[1]
if raw.raw_pattern[0, 0] == 0: # CFA=RGGB
out = np.concatenate((im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 2: # BGGR
out = np.concatenate((im[1:H:2, 1:W:2, :],
im[0:H:2, 1:W:2, :],
im[0:H:2, 0:W:2, :],
im[1:H:2, 0:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 0: # GRBG
out = np.concatenate((im[0:H:2, 1:W:2, :],
im[0:H:2, 0:W:2, :],
im[1:H:2, 0:W:2, :],
im[1:H:2, 1:W:2, :]), axis=2)
elif raw.raw_pattern[0,0] == 1 and raw.raw_pattern[0,1] == 2: # GBRG
out = np.concatenate((im[1:H:2, 0:W:2, :],
im[0:H:2, 0:W:2, :],
im[0:H:2, 1:W:2, :],
im[1:H:2, 1:W:2, :]), axis=2)
else:
assert False
wb = np.array(raw.camera_whitebalance)
wb[3] = wb[1]
wb = wb / wb[1]
out = np.minimum(out * wb, 1.0) # normalize the brightness
# out = np.minimum(out * 0.2 / np.maximum(1e-6, np.mean(out[:, :, 1])), 1.0) h_, w_ = im.shape[0]//2, im.shape[1]//2
out_16bit_ = np.zeros([h_, w_, 4], dtype=np.uint16)
out_16bit_[:, :, :] = np.uint16(out[:, :, :] * (wl - bl))
del out
return out_16bit_ def raw2rgb(raw): # GRBG
assert len(raw.shape)==3
h, w = raw.shape[0]<<1, raw.shape[1]<<1
rgb = np.zeros([h, w, 3])
rgb[0:h:2, 0:w:2, 1] = raw[:, :, 1]
rgb[0:h:2, 1:w:2, 0] = raw[:, :, 0]
rgb[1:h:2, 0:w:2, 2] = raw[:, :, 2]
rgb[1:h:2, 1:w:2, 1] = raw[:, :, 3]
return rgb def max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center):
return np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(
np.maximum(left, left_top),
top),
top_right),
right),
right_bottom),
bottom),
bottom_left),
center) def demosaic(rgb):
for chn_id in range(3):
left = rgb[0:-2, 1:-1, chn_id]
left_top = rgb[0:-2, 0:-2, chn_id]
top = rgb[0:-2, 1:-1, chn_id]
top_right = rgb[0:-2, 2:, chn_id]
right = rgb[1:-1, 2:, chn_id]
right_bottom = rgb[2:, 2:, chn_id]
bottom = rgb[2:, 1:-1, chn_id]
bottom_left = rgb[2:, 0:-2, chn_id]
center = rgb[1:-1, 1:-1, chn_id]
rgb[1:-1, 1:-1, chn_id] = max_in_all(left, left_top, top, top_right, right, right_bottom, bottom, bottom_left, center)
return rgb def gray_ps(rgb):
return np.power(np.power(rgb[:, :, 0], 2.2) * 0.2973 + np.power(rgb[:,:,1], 2.2) * 0.6274 + np.power(rgb[:,:,2], 2.2) * 0.0753, 1/2.2) + 1e-7 def gamma_correction(x, curve_ratio):
gray_scale = np.expand_dims(gray_ps(x), axis=-1)
gray_scale_new = np.power(gray_scale, curve_ratio)
return np.minimum(x * gray_scale_new / gray_scale, 1.0) # setting the ratio of GPU global memory usage
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')
gt_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4])
out_im = network_enhance_raw(in_im, patch_h, patch_w)
norm_im = tf.minimum(tf.maximum(out_im, 0.0), 1.0) ssim_loss = 1 - tf.image.ssim_multiscale(norm_im[0], gt_im[0], 1.0)
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(norm_im - gt_im), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(norm_im - gt_im), axis=-1))
# G_loss = ssim_loss
G_loss = l1_loss + l2_loss tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('MS-SSIM Loss', ssim_loss)
tf.summary.scalar('L1 Loss', l1_loss)
tf.summary.scalar('L2 Loss', l2_loss) t_vars = tf.trainable_variables()
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt:
print('loaded ' + ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path) # save the images for tracking training states
if not os.path.isdir(result_dir):
os.mkdir(result_dir) g_loss = np.zeros((500, 1)) merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph) gt_files = [None] * len(train_ids)
input_files = [None] * len(train_ids) input_images = [None] * len(train_ids)
gt_images = [None] * len(train_ids) for i in range(0, len(train_ids)):
gt_files[i] = glob.glob(os.path.join(data_dir, train_ids[i]) + '/*on*.dng')[0]
input_files[i] = glob.glob(os.path.join(data_dir, train_ids[i]) + '/*off*.dng')
input_images[i] = [None] * len(input_files[i]) steps = 0
st = time.time() for epoch in range(0, max_epoch):
for ind in np.random.permutation(len(train_ids)):
steps += 1
sid = np.random.randint(0, len(input_files[ind]))
if input_images[ind][sid] is None:
input_images[ind][sid] = np.expand_dims(pack_raw_bbf(input_files[ind][sid]), axis=0)
if gt_images[ind] is None:
gt_images[ind] = np.expand_dims(np.maximum(pack_raw_bbf(gt_files[ind]), 0), axis=0) # random cropping
xx = np.random.randint(0, bbf_w//2 - patch_w)
yy = np.random.randint(0, bbf_h//2 - patch_h)
input_patch = np.float32(input_images[ind][sid][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level)
gt_patch = np.float32(gt_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (max_level - black_level) # random flipping
if np.random.randint(2, size=1)[0] == 1: # random flip
input_patch = np.flip(input_patch, axis=1)
gt_patch = np.flip(gt_patch, axis=1)
if np.random.randint(2, size=1)[0] == 1:
input_patch = np.flip(input_patch, axis=0)
gt_patch = np.flip(gt_patch, axis=0)
# if np.random.randint(2, size=1)[0] == 1: # random transpose
# input_patch = np.transpose(input_patch, (0, 2, 1, 3))
# gt_patch = np.transpose(gt_patch, (0, 2, 1, 3)) # summary, _, G_current, output = sess.run(
# [merged, G_opt, G_loss, out_im],
# feed_dict={
# in_im: input_patch,
# gt_im: gt_patch,
# lr: learning_rate})
# g_loss[ind] = G_current summary, output = sess.run(
[merged, out_im],
feed_dict={
in_im: input_patch,
gt_im: gt_patch,
lr: learning_rate
}) # saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
# print('model saved.')
# exit(0) tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model_raw2raw.pb')
freeze_graph.freeze_graph(
'output_model/pb_model/model_raw2raw.pb',
'',
False,
'./model_light_on_off/0.ckpt',
'gen/output',
'save/restore_all',
'save/Const:0',
'output_model/pb_model/frozen_model.pb',
True,
"")
exit(0) if steps % save_output_every_n_steps == 0:
loss_ = np.mean(g_loss[np.where(g_loss)])
cost_ = (time.time() - st)/save_output_every_n_steps
st = time.time()
print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
writer.add_summary(summary, global_step=steps)
# save the current output image for network inspection
out_ = np.minimum(np.maximum(output, 0), 1)
in_rgb = gamma_correction(demosaic(raw2rgb(input_patch[0])), 0.35)
gt_rgb = gamma_correction(demosaic(raw2rgb(gt_patch[0])), 0.35)
out_rgb = gamma_correction(demosaic(raw2rgb(out_[0])), 0.35)
temp = np.concatenate((in_rgb, gt_rgb, out_rgb), axis=1)
scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255)\
.save(result_dir + '/%d_%s_00.jpg' % (epoch, train_ids[ind])) # clean up the memory if necessary
if platform.system() == 'Windows':
input_images[ind][sid] = None
gt_images[ind] = None if epoch % save_model_every_n_epoch == 0:
saver.save(sess, checkpoint_dir + '%d.ckpt' % epoch)
print('model saved.')
采用Tensorflow内部函数直接对模型进行冻结的更多相关文章
- tensorflow加载embedding模型进行可视化
1.功能 采用python的gensim模块训练的word2vec模型,然后采用tensorflow读取模型可视化embedding向量 ps:采用C++版本训练的w2v模型,python的gensi ...
- TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model
TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model Checkmate is designed to be a simple drop-i ...
- tensorflow训练验证码识别模型
tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...
- 开园第一篇---有关tensorflow加载不同模型的问题
写在前面 今天刚刚开通博客,主要想法跟之前某位博主说的一样,希望通过博客园把每天努力的点滴记录下来,也算一种坚持的动力.我是小白一枚,有啥问题欢迎各位大神指教,鞠躬~~ 换了新工作,目前手头是OCR项 ...
- 【6】TensorFlow光速入门-python模型转换为tfjs模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862365.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 【4】TensorFlow光速入门-保存模型及加载模型并使用
本文地址:https://www.cnblogs.com/tujia/p/13862360.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- 【TensorFlow】基于ssd_mobilenet模型实现目标检测
最近工作的项目使用了TensorFlow中的目标检测技术,通过训练自己的样本集得到模型来识别游戏中的物体,在这里总结下. 本文介绍在Windows系统下,使用TensorFlow的object det ...
- TensorFlow学习笔记12-word2vec模型
为什么学习word2word2vec模型? 该模型用来学习文字的向量表示.图像和音频可以直接处理原始像素点和音频中功率谱密度的强度值, 把它们直接编码成向量数据集.但在"自然语言处理&quo ...
- tensorflow之逻辑回归模型实现
前面一篇介绍了用tensorflow实现线性回归模型预测sklearn内置的波士顿房价,现在这一篇就记一下用逻辑回归分类sklearn提供的乳腺癌数据集,该数据集有569个样本,每个样本有30维,为二 ...
随机推荐
- 关于linux系统CPU篇--->不容易发现的占用CPU较高进程
1.系统的CPU使用率,不仅包括进程用户态和内核态的运行,还包括中断处理,等待IO以及内核线程等等.所以,当你发现系统的CPU使用率很高的时候,不一定能找到相对应的高CPU使用率的进程 2.案例分析, ...
- springboot+thymeleaf+pageHelper带条件分页查询
html层 <div> <a class="num"><b th:text="'共 '+ ${result.resultMap['pages ...
- Yii1操作phpexcel
Yii::import('application.vendors.phpexcel.*'); Yii::import('application.vendors.phpexcel.PHPExcel.*' ...
- 全网搜歌神器Listen1 Mac中文版
listen1 for mac中文版是mac上一款强大的全网搜歌音乐播放器,支持网易云音乐.QQ音乐.虾米音乐.酷狗音乐以及酷我音乐等网站的歌曲搜索播放功能,拥有创建歌单.随心播放.歌曲收藏.快速搜索 ...
- 配置IPV6地址
题:在考试系统上设定接口eth0使用下列IPV6地址: system1上的地址应该是2003:ac18::305/64 system2上的地址应该是2003:ac18::30a/64 两个系统必须能与 ...
- mime类型的解析与应用
MIME类型解析 MIME(Multipurpose Internet Mail Extensions)多用途网络邮件扩展类型,可被称为Media type或Content type, 它设定某种 ...
- 【Alpha】Scrum Meeting 10
目录 前言 任务分配 燃尽图 会议照片 签入记录 困难 前言 第10次会议于4月14日19:00在教一316召开. 交流确认了任务进度,对下一阶段任务进行分配.时长40min. 任务分配 姓名 当前阶 ...
- 【Alpha】Scrum Meeting 7
目录 前言 任务分配 燃尽图 会议照片 签入记录 困难 前言 第7次会议在4月11日19:00由PM在教一317召开. 交流确认了任务进度,对下一阶段任务进行分配.时长60min. 任务分配 姓名 当 ...
- 【读书笔记】使用JMeter创建数据库(Mysql)测试
读书笔记:<零成本实现Web性能测试>第4章 记得某天按照虫师博客的写的,折腾后成功了.今天又忘记了... 折腾后又成功了,赶紧记录下... 原文:http://www.cnblogs.c ...
- 在ubuntu的bash中循环执行脚本,并在内存不足时重启
#!/bin/bash date ma=`grep MemAvailable /proc/meminfo | awk '{print $2}'` echo MemAvailable = $ma run ...