import os
import sys
import random
import math
import re
import time
import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image # Root directory of the project
ROOT_DIR = os.path.abspath("../../") # Import Mask RCNN
sys.path.append(ROOT_DIR) # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log #%matplotlib inline # Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs") # Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH) iter_num=0

  

Configurations

class ShapesConfig(Config):
"""Configuration for training on the toy shapes dataset.
Derives from the base Config class and overrides values specific
to the toy shapes dataset.
"""
# Give the configuration a recognizable name
NAME = "shapes" # Train on 1 GPU and 8 images per GPU. We can put multiple images on each
# GPU because the images are small. Batch size is 8 (GPUs * images/GPU).
GPU_COUNT = 2
IMAGES_PER_GPU = 1 #这里我用了两个GPU # Number of classes (including background)
NUM_CLASSES = 1 + 1 # background + 1 shapes # Use small images for faster training. Set the limits of the small side
# the large side, and that determines the image shape.
IMAGE_MIN_DIM = 1080
IMAGE_MAX_DIM = 1920 # Use smaller anchors because our image and objects are small
RPN_ANCHOR_SCALES = (8*6, 16*6, 32*6, 64*6, 128*6) # anchor side in pixels # Reduce training ROIs per image because the images are small and have
# few objects. Aim to allow ROI sampling to pick 33% positive ROIs.
TRAIN_ROIS_PER_IMAGE = 32 # Use a small epoch since the data is simple
STEPS_PER_EPOCH = 100 # use small validation steps since the epoch is small
VALIDATION_STEPS = 5 config = ShapesConfig()
config.display()

  Notebook Preference

def get_ax(rows=1, cols=1, size=8):
"""Return a Matplotlib Axes array to be used in
all visualizations in the notebook. Provide a
central point to control graph sizes. Change the default size attribute to control the size
of rendered images
"""
_, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
return ax

  Dataset

class DrugDataset(utils.Dataset):

    #得到该图中有多少个实例(物体)
def get_obj_index(self, image):
n = np.max(image)
return n
#解析labelme中得到的yaml文件,从而得到mask每一层对应的实例标签
def from_yaml_get_class(self,image_id):
info=self.image_info[image_id]
with open(info['yaml_path']) as f:
temp=yaml.load(f.read())
labels=temp['label_names']
del labels[0]
return labels
#重新写draw_mask
def draw_mask(self, num_obj, mask, image):
info = self.image_info[image_id]
for index in range(num_obj):
for i in range(info['width']):
for j in range(info['height']):
at_pixel = image.getpixel((i, j))
if at_pixel == index + 1:
mask[j, i, index] =1
return mask
#重新写load_shapes,里面包含自己的自己的类别(我的是box、column、package、fruit四类)
#并在self.image_info信息中添加了path、mask_path 、yaml_path
def load_shapes(self, count, height, width, img_floder, mask_floder, imglist,dataset_root_path):
"""Generate the requested number of synthetic images.
count: number of images to generate.
height, width: the size of the generated images.
"""
# Add classes
self.add_class("shapes", 1, "box") for i in range(count):
filestr = imglist[i].split(".")[0]
filestr = filestr.split("_")[0]
mask_path = mask_floder + "/" + filestr + ".png"
yaml_path=dataset_root_path+filestr+"rgb_"+"_json/info.yaml"
self.add_image("shapes", image_id=i, path=img_floder + "/"+imglist[i],
width=width, height=height, mask_path=mask_path,yaml_path=yaml_path)
#重写load_mask
def load_mask(self, image_id):
"""Generate instance masks for shapes of the given image ID.
"""
global iter_num
info = self.image_info[image_id]
count = 1 # number of object
img = Image.open(info['mask_path'])
num_obj = self.get_obj_index(img)
mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
mask = self.draw_mask(num_obj, mask, img)
occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
for i in range(count - 2, -1, -1):
mask[:, :, i] = mask[:, :, i] * occlusion
occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
labels=[]
labels=self.from_yaml_get_class(image_id)
labels_form=[]
for i in range(len(labels)):
if labels[i].find("box")!=-1:
#print "box"
labels_form.append("box")
#elif labels[i].find("column")!=-1:
#print "column"
# labels_form.append("column")
#elif labels[i].find("package")!=-1:
#print "package"
# labels_form.append("package")
#elif labels[i].find("fruit")!=-1:
#print "fruit"
# labels_form.append("fruit")
class_ids = np.array([self.class_names.index(s) for s in labels_form])
return mask, class_ids.astype(np.int32)

  基础设置

#基础设置
dataset_root_path="/mnt/disk2/zhouqiang/Mask_RCNN/data/train_01_01/"
img_floder = dataset_root_path+"rgb"
mask_floder = dataset_root_path+"mask"
#yaml_floder = dataset_root_path
imglist = os.listdir(img_floder)
count = len(imglist)
width = 1920
height = 1080 #train与val数据集准备
dataset_train = DrugDataset()
dataset_train.load_shapes(count, 1080, 1920, img_floder, mask_floder, imglist,dataset_root_path)
dataset_train.prepare() dataset_val = DrugDataset()
dataset_val.load_shapes(count, 1080, 1920, img_floder, mask_floder, imglist,dataset_root_path)
dataset_val.prepare()

  Create Model

# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config,
model_dir=MODEL_DIR)

 

# Which weights to start with?
init_with = "coco" # imagenet, coco, or last if init_with == "imagenet":
model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
# Load weights trained on MS COCO, but skip layers that
# are different due to the different number of classes
# See README for instructions to download the COCO weights
model.load_weights(COCO_MODEL_PATH, by_name=True,
exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
"mrcnn_bbox", "mrcnn_mask"])
elif init_with == "last":
# Load the last model you trained and continue training
model.load_weights(model.find_last(), by_name=True)

  

# Fine tune all layers
# Passing layers="all" trains all layers. You can also
# pass a regular expression to select which layers to
# train by name pattern.
model.train(dataset_train, dataset_val,
learning_rate=config.LEARNING_RATE / 10,
epochs=50,
layers="all")

  

 

MaskRCNN 奔跑自己的数据的更多相关文章

  1. labelme2coco问题:TypeError: Object of type 'int64' is not JSON serializable

    最近在做MaskRCNN 在自己的数据(labelme)转为COCOjson格式遇到问题:TypeError: Object of type 'int64' is not JSON serializa ...

  2. Android动画之仿美团加载数据等待时,小人奔跑进度动画对话框(附顺丰快递员奔跑效果)

    Android动画之仿美团加载数据等待时,小人奔跑进度动画对话框(附顺丰快递员奔跑效果) 首句依然是那句老话,你懂得! finddreams :(http://blog.csdn.net/finddr ...

  3. 奔跑的歌颂 diskgenius 找回了20G数据

    2.0同学家的电脑不慎重装系统,结果默认重新分区.其他倒没什么数据,就是几千张记录孩子成长的照片最为珍贵.为了找回数据,用U盘启动,使用Diskgenius全部找回,在此奔歌一下.

  4. mask-rcnn代码解读(四):rpn_feature_maps数据的处理

    此处模拟 rpn_feature_maps数据的处理,最终得到rpn_class_logits, rpn_class, rpn_bbox. 代码如下: import numpy as np'''层与层 ...

  5. MVC5 网站开发之三 数据存储层功能实现

    数据存储层在项目Ninesky.DataLibrary中实现,整个项目只有一个类Repository.   目录 奔跑吧,代码小哥! MVC5网站开发之一 总体概述 MVC5 网站开发之二 创建项目 ...

  6. C#工业物联网和集成系统解决方案的技术路线(数据源、数据采集、数据上传与接收、ActiveMQ、Mongodb、WebApi、手机App)

    目       录 工业物联网和集成系统解决方案的技术路线... 1 前言... 1 第一章           系统架构... 3 1.1           硬件构架图... 3 1.2      ...

  7. jquery表格动态增删改及取数据绑定数据完整方案

    一 前言 上一篇Jquery遮罩插件,想罩哪就罩哪! 结尾的预告终于来了. 近期参与了一个针对内部员工个人信息收集的系统,其中有一个需求是在填写各个相关信息时,需要能动态的增加行当时公司有自己的解决方 ...

  8. 奔跑的xiaodao

    http://acm.hrbust.edu.cn/index.php?m=ProblemSet&a=showProblem&problem_id=2086 很明显的一个二分题目.因为要 ...

  9. 奔跑在Docker上的Spark

    转自:马踏飞燕--奔跑在Docker上的Spark 目录 为什么要在Docker上搭建Spark集群 网络拓扑 Docker安装及配置 ssh安装及配置 基础环境安装 Zookeeper安装及配置 H ...

随机推荐

  1. os 模块常用方法

    os.remove()删除文件 os.rename()重命名文件 os.walk()生成目录树下的所有文件名 os.chdir()改变目录 os.mkdir/makedirs创建目录/多层目录 os. ...

  2. [luogu 5024] 保卫王国

    Problem Here Solution 这大概是一篇重复累赘的blog吧. 最小权覆盖集=全集-最大权独立集 强制取或不取,可以通过将权值修改成inf或者-inf 然后就用动态dp的套路就行了 动 ...

  3. 【洛谷】P3518 [POI2011]SEJ-Strongbox

    题目描述 有一个密码箱,0到n-1中的某些整数是它的密码. 且满足,如果a和b都是它的密码,那么(a+b)%n也是它的密码(a,b可以相等) 某人试了k次密码,前k-1次都失败了,最后一次成功了. 问 ...

  4. Spark(二)—— 标签计算、用户画像应用

    一.标签计算 数据 86913510 {"reviewPics":[],"extInfoList":null,"expenseList":n ...

  5. 利用JS-SDK微信分享接口调用(后端.NET)

    一直都想研究一下JS-SDK微信分享的接口调用,由于最近工作需要,研究了一下,目前只是实现了部分接口的调用:其他接口调用也是类似的: 在开发之前,需要提前准备一个微信公众号,并且域名JSAPI 配置接 ...

  6. Linux下用jar命令替换war包中的文件【转】

    问题背景:在Linux环境上的weblogic发布war包,有时候只是修改了几个文件,也要上传整个war包,这样很费时间,因此整理了一下Linux环境,更新单个文件的方法. 1.如果要替换的文件直接在 ...

  7. python-pptx

    python-pptx的使用首先需要了解几个基本概念: 1.引入python-pptx frompptximportpresentation    # 实例化Presentation    prs= ...

  8. .gitignore忽略多层文件夹用**

    一.写法 **/bin/Debug/ 前面的两个*号代表任意多层上级文件夹 需要 git 1.8.2 及其以上的版本才支持 如何查看当前版本并且升级(windows) 二.如何升级 git是2.14. ...

  9. RabbitMQ教程C#版 - 发布订阅

    先决条件本教程假定 RabbitMQ 已经安装,并运行在localhost标准端口(5672).如果你使用不同的主机.端口或证书,则需要调整连接设置. 从哪里获得帮助如果您在阅读本教程时遇到困难,可以 ...

  10. 在linux的用户空间操作gpio

    1. 使能linux内核选项CONFIG_GPIO_SYSFS CONFIG_GPIO_SYSFS=y 2. 测试方法 2.1 关注/sys/class/gpio下的文件 --export/unexp ...