deeplabv3+ demo测试图像分割
#直接复制本代码,存为.py文件, 在大概204行左右更换模型地址,在223左右更换图片路径,直接执行即可得出简单的分割效果
#!--*-- coding:utf-8 --*-- # Deeplab Demo import os
import tarfile from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import tempfile
from six.moves import urllib import tensorflow as tf class DeepLabModel(object):
"""
加载 DeepLab 模型;
推断 Inference.
"""
INPUT_TENSOR_NAME = 'ImageTensor:0'
OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
INPUT_SIZE = 513
FROZEN_GRAPH_NAME = 'frozen_inference_graph' def __init__(self, tarball_path):
"""
加载预训练模型
"""
self.graph = tf.Graph() graph_def = None
# Extract frozen graph from tar archive.
tar_file = tarfile.open(tarball_path)
for tar_info in tar_file.getmembers():
if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
file_handle = tar_file.extractfile(tar_info)
graph_def = tf.GraphDef.FromString(file_handle.read())
break tar_file.close() if graph_def is None:
raise RuntimeError('Cannot find inference graph in tar archive.') with self.graph.as_default():
tf.import_graph_def(graph_def, name='') self.sess = tf.Session(graph=self.graph) def run(self, image):
""" Args:
image: 转换为PIL.Image 类,不能直接用图片,原始图片 Returns:
resized_image: RGB image resized from original input image.
seg_map: Segmentation map of `resized_image`.
"""
width, height = image.size
resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
target_size = (int(resize_ratio * width), int(resize_ratio * height))
resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,
feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
seg_map = batch_seg_map[0]
return resized_image, seg_map def create_pascal_label_colormap():
"""
Creates a label colormap used in PASCAL VOC segmentation benchmark. Returns:
A Colormap for visualizing segmentation results.
"""
colormap = np.zeros((256, 3), dtype=int)
ind = np.arange(256, dtype=int) for shift in reversed(range(8)):
for channel in range(3):
colormap[:, channel] |= ((ind >> channel) & 1) << shift
ind >>= 3 return colormap def label_to_color_image(label):
"""
Adds color defined by the dataset colormap to the label. Args:
label: A 2D array with integer type, storing the segmentation label. Returns:
result: A 2D array with floating type. The element of the array
is the color indexed by the corresponding element in the input label
to the PASCAL color map. Raises:
ValueError: If label is not of rank 2 or its value is larger than color
map maximum entry.
"""
if label.ndim != 2:
raise ValueError('Expect 2-D input label') colormap = create_pascal_label_colormap() if np.max(label) >= len(colormap):
raise ValueError('label value too large.') return colormap[label] def vis_segmentation(image, seg_map, imagefile):
"""可视化三种图像."""
plt.figure(figsize=(15, 5))
grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1]) plt.subplot(grid_spec[0])
plt.imshow(image)
plt.axis('off')
plt.title('input image') plt.subplot(grid_spec[1])
seg_image = label_to_color_image(seg_map).astype(np.uint8)
# seg_image = label_to_color_image(seg_map)
# seg_image.save('/str(ss)+imagefile')
plt.imshow(seg_image)
plt.savefig('./'+imagefile+'.png') plt.axis('off')
plt.title('segmentation map') plt.subplot(grid_spec[2])
plt.imshow(image)
plt.imshow(seg_image, alpha=0.7)
plt.axis('off')
plt.title('segmentation overlay') unique_labels = np.unique(seg_map)
ax = plt.subplot(grid_spec[3])
plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
ax.yaxis.tick_right()
plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
plt.xticks([], [])
ax.tick_params(width=0.0)
plt.grid('off')
plt.show() ##
LABEL_NAMES = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv' ]) FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP) ## Tensorflow 提供的模型下载
MODEL_NAME = 'xception_coco_voctrainval'
# ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval'] _DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
_MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', } _TARBALL_NAME = 'deeplab_model.tar.gz' # model_dir = tempfile.mkdtemp()
model_dir = './'
# tf.gfile.MakeDirs(model_dir) #
download_path = os.path.join(model_dir, _TARBALL_NAME)
print('downloading model, this might take a while...')
# urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
print('download completed! loading DeepLab model...') # model_dir = '/‘ # download_path = os.path.join(model_dir, _MODEL_URLS[MODEL_NAME])
MODEL = DeepLabModel('./deeplab_model.tar.gz')
# MODEL = './deeplab_model.tar.gz'
print('model loaded successfully!') ##
def run_visualization(imagefile):
"""
DeepLab 语义分割,并可视化结果.
"""
# orignal_im = Image.open(imagefile)
# print(type(orignal_im))
# orignal_im.show()
print('running deeplab on image %s...' % imagefile)
resized_im, seg_map = MODEL.run(Image.open(imagefile)) vis_segmentation(resized_im, seg_map,imagefile) images_dir = './pictures'
images = sorted(os.listdir(images_dir))
print(images)
# img='205729y9fodss9ao6ol5921-150x150.jpg'
# img.show()
for imgfile in images:
# img.show()
run_visualization(os.path.join(images_dir, imgfile)) print('Done.')
所使用的是deeplab_model.tar.gz,也可以修改代码使用在标准数据集上预训练过的模型;代码在182行附近。
1.修改模型保存路径
2.修改图片路径
3.运行即可
参考自:https://www.aiuai.cn/aifarm252.html
deeplabv3+ demo测试图像分割的更多相关文章
- 中标麒麟6.0_ICE3.4.2编译+demo测试(CPP)
(菜鸟版)确保 gcc版本4.4.6(其他版本未测试),4.8不行 一.降级GCC到4.4.6 注意:gcc g++ c++命令都为4.4.6(可用gcc -v; g++ -v; c++ -v 命令查 ...
- VS2017 + QT5 + C++开发环境搭建和计算器Demo测试
非常有帮助的参考资料: https://blog.csdn.net/gaojixu/article/details/82185694 该参考文献的主要流程: (1)QT下载安装:从官网下载QT,并记 ...
- Java 银联支付官网demo测试及项目整合代码
注:原文来源与 < Java 银联支付官网demo测试及项目整合代码 > 银联支付(网关支付B2C) 一.测试官网demo a)下载官网开发包,导入eclipse等待修改(下载的开发包没 ...
- Dom捕捉事件和冒泡事件-原理与demo测试
先参考一下百度百科对冒泡事件流的解释: ----------不喜欢读文字的同学,可以直接看下面demo,传递顺序简单明了! http://baike.baidu.com/link?url=kaeJHT ...
- RocketMQ初探(二)之RocketMQ3.26版本搭建(含简单Demo测试案例)
作为一名程序猿,要敢于直面各种现实,脾气要好,心态要棒,纵使Bug虐我千百遍,我待它如初恋,方法也有千万种,一条路不行,换条路走走,方向对了,只要前行,总会上了罗马的道. Apache4.x最新版本既 ...
- 【转载】Scrapy安装及demo测试笔记
Scrapy安装及demo测试笔记 原创 2016年09月01日 16:34:00 标签: scrapy / python Scrapy安装及demo测试笔记 一.环境搭建 1. 安装scrapy ...
- Zookeeper+Dubbo环境搭建与Demo测试
环境准备: 1. zookeeper-3.4.14 (下载地址:http://archive.apache.org/dist/zookeeper/) 2. dubbo-0.2.0 (下载地址 ...
- red5研究(一):下载,工程建立、oflaDemo安装、demo测试
一.red5下载.添加工程到myeclipse 1,从官网上下载red51.01版本(我下载的是red51.0的版本),下载链接http://www.red5.org/downloads/red5/1 ...
- Axis2创建WebService服务端接口+SoupUI以及Client端demo测试调用
第一步:引入axis2相关jar包,如果是pom项目,直接在pom文件中引入依赖就好 <dependency> <groupId>org.apache.axis2</gr ...
随机推荐
- ExtJS6 根据Value设置单元格颜色
renderer : function(value, meta) { if(parseInt(value) > 0) { meta.style = ""; } else { ...
- [mvc] 简单的forms认证
1.在web.config的system.web节点增加authentication节点,定义如下: <system.web> <compilation debug="tr ...
- EditText小记
今天在编写样式的时候,需要设置数据输入为单行,但是 Android:singleLine=”true” 显示为已过期,提示使用 android:maxLines=“1” 代替,但是设置后却发现并没有效 ...
- 01Hadoop二次排序
我的目的: 示例: 2012,01,01,352011,12,23,-42012,01,01,432012,01,01,232011,12,23,52011,4,1,22011,4,1,56 结果: ...
- python中MetaClass的一些用法
元类在很多编程语言中都有这样的概念,我们都知道,类可以创建对象,类本身也是对象,既然是对象,那么它肯定也是被创造出来的,元类就专门用来创造类对象,于是,这就给我们提供了一种操纵或者监听类的能力. 平时 ...
- 解决vscode无法提示golang的问题
https://github.com/Microsoft/vscode-go/wiki/Go-with-VS-Code-FAQ-and-Troubleshooting Q: Auto-completi ...
- 微软消息队列-MicroSoft Message Queue(MSMQ)队列的C#使用
目录 定义的接口 接口实现 建立队列工厂 写入队列 获取消息 什么是MSMQ Message Queuing(MSMQ) 是微软开发的消息中间件,可应用于程序内部或程序之间的异步通信.主要的机制是:消 ...
- python 中 __init__方法
注意1,__init__并不相当于C#中的构造函数,执行它的时候,实例已构造出来了. class A(object): def __init__(self,name): self.name=name ...
- PHP 数组转XML 格式
function buildXml( $data, $wrap= 'xml' ){ $str = "<{$wrap}>"; if( is_array( $data ) ...
- zeppelin 一直报这个警告 也是醉了
用./zeppelin-daemon.sh start 启动zeppelin 一直报这个警告.. WARN [2017-03-23 19:11:34,461] ({qtp483422889-45} N ...