Training a Gesture Detection Model with MMDetection
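This post wraps MMDetection's training and inference pipeline into a single `MMDetection` class for gesture detection: `train()` fine-tunes a detector (FasterRCNN, Yolov3, or SSD_Lite) on a COCO-format dataset, `load_checkpoint()` plus `fast_inference()` and the one-shot `inference()` run detection, and `convert()` exports the result to ONNX together with a small onnxruntime demo script. The full implementation follows.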
```python
import os
import json
import time
import warnings

import cv2
import mmcv
import numpy as np
import torch
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmdet.apis import inference_detector, init_detector, show_result_pyplot, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from tqdm import tqdm

warnings.filterwarnings("ignore")


class MMDetection:
    @staticmethod
    def sota():
        # List the available backbones by scanning the bundled 'models' directory.
        pypath = os.path.abspath(__file__)
        father = os.path.dirname(pypath)
        models = os.path.join(father, 'models')
        sota_model = []
        for i in os.listdir(models):
            if i[0] != '_':
                sota_model.append(i)
        return sota_model

    def __init__(self,
                 backbone='FasterRCNN',
                 num_classes=-1,
                 dataset_path=None):
        # Absolute path of the directory the caller's script runs from.
        self.cwd = os.path.dirname(os.getcwd())
        # Absolute path of this file.
        self.file_dirname = os.path.dirname(os.path.abspath(__file__))
        self.save_fold = None
        self.is_sample = False
        self.config = os.path.join(
            self.file_dirname, 'models', 'FasterRCNN/FasterRCNN.py')
        self.checkpoint = os.path.join(
            self.file_dirname, 'models', 'FasterRCNN/FasterRCNN.pth')
        self.backbone = backbone
        backbone_path = os.path.join(
            self.file_dirname, 'models', self.backbone)
        ckpt_cfg_list = list(os.listdir(backbone_path))
        for item in ckpt_cfg_list:
            if item.endswith('.py') and item[0] != '_':
                self.config = os.path.join(backbone_path, item)
            elif item.endswith('.pth'):
                self.checkpoint = os.path.join(backbone_path, item)
            else:
                # Ignore unrecognized files in the backbone folder.
                pass
        self.cfg = Config.fromfile(self.config)
        self.dataset_path = dataset_path
        self.lr = None
        self.backbonedict = {
            "FasterRCNN": os.path.join(self.file_dirname, 'models', 'FasterRCNN/FasterRCNN.py'),
            "Yolov3": os.path.join(self.file_dirname, 'models', 'Yolov3/Yolov3.py'),
            "SSD_Lite": os.path.join(self.file_dirname, 'models', 'SSD_Lite/SSD_Lite.py'),
            # "Mask_RCNN": os.path.join(self.file_dirname, 'models', 'Mask_RCNN/Mask_RCNN.py'),
            # further entries omitted
        }
        self.num_classes = num_classes
        self.chinese_res = None
```
The `train()` method (continuing the `MMDetection` class) loads the backbone config, patches in the class count and dataset paths, then hands off to MMDetection's `train_detector`:

```python
    def train(self, random_seed=0, save_fold=None, distributed=False, validate=True, device='cpu',
              metric='bbox', save_best='bbox_mAP', optimizer="SGD", epochs=100, lr=0.001, weight_decay=0.001,
              Frozen_stages=1,
              checkpoint=None, batch_size=None):
        # Load the config file for the chosen backbone.
        self.cfg = Config.fromfile(self.backbonedict[self.backbone])
        # If no save_fold has been set on the instance...
        if not self.save_fold:
            # ...and none was passed in, fall back to the default path;
            if not save_fold:
                self.save_fold = os.path.join(
                    self.cwd, 'checkpoints/det_model')
            # otherwise use the value passed in.
            else:
                self.save_fold = save_fold
        self.cfg.model.backbone.frozen_stages = Frozen_stages
        if self.num_classes != -1:
            if "RCNN" not in self.backbone:  # single-stage detectors
                self.cfg.model.bbox_head.num_classes = self.num_classes
            elif self.backbone == "FasterRCNN":  # two-stage R-CNN family
                self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
            elif self.backbone == "Mask_RCNN":
                self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
                self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
        self.load_dataset(self.dataset_path)
        # Read the class names to detect from the annotation file.
        if self.backbone in ["Yolov3"]:
            self.cfg.classes = self.get_classes(self.cfg.data.train.dataset.ann_file)
        else:
            self.cfg.classes = self.get_classes(self.cfg.data.train.ann_file)
        # Propagate the class names to the train, test and val splits.
        if self.backbone in ["Yolov3"]:
            self.cfg.data.train.dataset.classes = self.cfg.classes
        else:
            self.cfg.data.train.classes = self.cfg.classes
        self.cfg.data.test.classes = self.cfg.classes
        self.cfg.data.val.classes = self.cfg.classes
        self.cfg.work_dir = self.save_fold
        # Create the working directory.
        mmcv.mkdir_or_exist(os.path.abspath(self.cfg.work_dir))
        # Build the dataset and the detector.
        datasets = [build_dataset(self.cfg.data.train)]
        model = build_detector(self.cfg.model, train_cfg=self.cfg.get(
            'train_cfg'), test_cfg=self.cfg.get('test_cfg'))
        if not checkpoint:
            model.init_weights()
        else:
            checkpoint = os.path.abspath(checkpoint)
            load_checkpoint(model, checkpoint, map_location=torch.device(device))
        model.CLASSES = self.cfg.classes
        if optimizer == 'Adam':
            self.cfg.optimizer = dict(type='Adam', lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0001)
        elif optimizer == 'Adagrad':
            self.cfg.optimizer = dict(type='Adagrad', lr=lr, lr_decay=0)
        # Update the config from the arguments.
        self.cfg.optimizer.lr = lr                      # learning rate
        self.cfg.optimizer.type = optimizer             # optimizer
        self.cfg.optimizer.weight_decay = weight_decay  # optimizer weight decay
        self.cfg.evaluation.metric = metric             # validation metric
        # self.cfg.evaluation.save_best = save_best
        self.cfg.runner.max_epochs = epochs             # maximum number of training epochs
        # self.cfg.log_config.interval = 1              # log every N training iterations
        self.cfg.gpu_ids = range(1)
        self.cfg.seed = random_seed
        if batch_size is not None:
            self.cfg.data.samples_per_gpu = batch_size
        train_detector(
            model,
            datasets,
            self.cfg,
            distributed=distributed,
            validate=validate,
            timestamp=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
            meta=dict()
        )
```
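With the class in place, a minimal training run looks like this (a sketch: the dataset path, class count, and output folder are placeholders, assuming a COCO-format gesture dataset laid out as described under `load_dataset()` below):

```python
# Hypothetical usage sketch: paths and num_classes are placeholders.
model = MMDetection(backbone='FasterRCNN', num_classes=3,
                    dataset_path='det_dataset')
model.train(epochs=50, lr=0.001, validate=True, device='cpu',
            save_fold='checkpoints/gesture_model')
```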
`load_checkpoint()` restores trained weights once so that `fast_inference()` can be called repeatedly; `print_result()` pretty-prints the most recent detections. Note that `inference_detector` returns boxes as `(x1, y1, x2, y2, score)`, so the coordinate keys below use that format:

```python
    def print_result(self, res=None):
        if self.is_sample:
            print("Sample detection result:")
            sample_result = r"[{'类别标签': 0, '置信度': 1.0, '坐标': {'x1': 26, 'y1': 81, 'x2': 497, 'y2': 414}},{'类别标签': 1, '置信度': 1.0, '坐标': {'x1': 250, 'y1': 103, 'x2': 494, 'y2': 341}}]"
            print(sample_result)
        else:
            print("Detection result:")
            print(self.chinese_res)
        return self.chinese_res

    def load_checkpoint(self, checkpoint=None, device='cpu',
                        rpn_threshold=0.7, rcnn_threshold=0.7):
        print("========= begin inference ==========")
        if self.num_classes != -1 and self.backbone not in ["Yolov3", "SSD_Lite"]:
            self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
        if checkpoint:
            # Recover the class names stored in the checkpoint's metadata.
            self.cfg.classes = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
            self.cfg.data.train.classes = self.cfg.classes
            self.cfg.data.test.classes = self.cfg.classes
            self.cfg.data.val.classes = self.cfg.classes
            if "RCNN" not in self.backbone:  # single-stage detectors
                self.cfg.model.bbox_head.num_classes = len(self.cfg.classes)
            else:  # two-stage R-CNN family
                self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
            self.infer_model = init_detector(self.cfg, checkpoint, device=device)
            self.infer_model.CLASSES = self.cfg.classes
        else:
            self.infer_model = init_detector(self.cfg, self.checkpoint, device=device)
        if self.backbone not in ["Yolov3", "SSD_Lite"]:
            self.infer_model.test_cfg.rpn.nms.iou_threshold = 1 - rpn_threshold
            self.infer_model.test_cfg.rcnn.nms.iou_threshold = 1 - rcnn_threshold

    def fast_inference(self, image, show=False, save_fold='det_result'):
        img_array = mmcv.imread(image)
        try:
            self.infer_model
        except AttributeError:
            print("Please call load_checkpoint() to load weights first!")
            return
        # Reuse the model prepared by load_checkpoint().
        result = inference_detector(self.infer_model, img_array)
        self.infer_model.show_result(image, result, show=show,
                                     out_file=os.path.join(save_fold, os.path.split(image)[1]))
        chinese_res = []
        for i in range(len(result)):
            for j in range(result[i].shape[0]):
                tmp = {}
                tmp['类别标签'] = i
                tmp['置信度'] = result[i][j][4]
                tmp['坐标'] = {"x1": int(result[i][j][0]), "y1": int(
                    result[i][j][1]), 'x2': int(result[i][j][2]), 'y2': int(result[i][j][3])}
                chinese_res.append(tmp)
        self.chinese_res = chinese_res
        return result
```
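For repeated predictions, load the weights once and reuse them (a sketch; the checkpoint and image paths are placeholders):

```python
# Hypothetical usage sketch: checkpoint and image paths are placeholders.
model = MMDetection(backbone='FasterRCNN')
model.load_checkpoint(checkpoint='checkpoints/gesture_model/latest.pth', device='cpu')
model.fast_inference(image='demo/gesture.jpg', save_fold='det_result')
model.print_result()  # prints and returns the structured detections
```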
`inference()` is the one-shot entry point: with no image it returns a canned sample, with a file path it runs single-image detection, and with a directory path (trailing '/') it sweeps the whole folder:

```python
    def inference(self, device='cpu',
                  checkpoint=None,
                  image=None,
                  show=True,
                  rpn_threshold=0.7,
                  rcnn_threshold=0.7,
                  save_fold='det_result'):
        if image is None:
            # With no image, return a canned sample so the API can be demonstrated offline.
            self.is_sample = True
            sample_return = """
[array([[ 26.547777 ,  81.55447  , 497.37015  , 414.4934   ,
          1.0]], dtype=float32),
 array([[2.5098564e+02, 1.0334784e+02, 4.9422855e+02, 3.4187744e+02,
         1.0]], dtype=float32)]
"""
            return sample_return
        self.is_sample = False
        print("========= begin inference ==========")
        if self.num_classes != -1 and self.backbone not in ["Yolov3", "SSD_Lite"]:
            self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
        if checkpoint:
            # Recover the class names stored in the checkpoint's metadata.
            self.cfg.classes = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
            self.num_classes = len(self.cfg.classes)
            if self.num_classes != -1:
                if "RCNN" not in self.backbone:  # single-stage detectors
                    self.cfg.model.bbox_head.num_classes = self.num_classes
                elif self.backbone == "FasterRCNN":  # two-stage R-CNN family
                    self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
                elif self.backbone == "Mask_RCNN":
                    self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
                    self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
            self.cfg.data.train.classes = self.cfg.classes
            self.cfg.data.test.classes = self.cfg.classes
            self.cfg.data.val.classes = self.cfg.classes
            if self.backbone not in ["Yolov3", "SSD_Lite"]:
                self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
            model = init_detector(self.cfg, checkpoint, device=device)
            model.CLASSES = self.cfg.classes
        else:
            model = init_detector(self.cfg, self.checkpoint, device=device)
        if self.backbone not in ["Yolov3", "SSD_Lite"]:
            model.test_cfg.rpn.nms.iou_threshold = 1 - rpn_threshold
            model.test_cfg.rcnn.nms.iou_threshold = 1 - rcnn_threshold
        results = []
        if image[-1] != '/':
            # A path that does not end in '/' is treated as a single image.
            img_array = mmcv.imread(image)
            result = inference_detector(model, img_array)  # `model` here is local, not self.infer_model
            if show:
                show_result_pyplot(model, image, result)
            model.show_result(image, result, show=show, out_file=os.path.join(save_fold, os.path.split(image)[1]))
            chinese_res = []
            for i in range(len(result)):
                for j in range(result[i].shape[0]):
                    tmp = {}
                    tmp['类别标签'] = i
                    tmp['置信度'] = result[i][j][4]
                    tmp['坐标'] = {"x1": int(result[i][j][0]), "y1": int(
                        result[i][j][1]), 'x2': int(result[i][j][2]), 'y2': int(result[i][j][3])}
                    chinese_res.append(tmp)
            self.chinese_res = chinese_res
            print("========= finish inference ==========")
            return result
        else:
            # A trailing '/' means a directory: run inference over every image in it.
            img_dir = image
            mmcv.mkdir_or_exist(os.path.abspath(save_fold))
            chinese_results = []
            for img in tqdm(os.listdir(img_dir)):
                result = inference_detector(model, img_dir + img)
                model.show_result(img_dir + img, result,
                                  out_file=os.path.join(save_fold, img))
                chinese_res = []
                for i in range(len(result)):
                    for j in range(result[i].shape[0]):
                        tmp = {}
                        tmp['类别标签'] = i
                        tmp['置信度'] = result[i][j][4]
                        tmp['坐标'] = {"x1": int(result[i][j][0]), "y1": int(
                            result[i][j][1]), 'x2': int(result[i][j][2]), 'y2': int(result[i][j][3])}
                        chinese_res.append(tmp)
                chinese_results.append(chinese_res)
                results.append(result)
            self.chinese_res = chinese_results
            print("========= finish inference ==========")
            return results
```
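Note the trailing-slash convention: a directory must end with '/' to hit the batch branch (a sketch; paths are placeholders):

```python
# Hypothetical usage sketch: the trailing '/' selects the directory branch.
model = MMDetection(backbone='FasterRCNN')
results = model.inference(checkpoint='checkpoints/gesture_model/latest.pth',
                          image='det_dataset/images/test/',  # trailing '/' required
                          show=False, save_fold='det_result')
```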
Three helpers wire up the data: `load_dataset()` points the config at a COCO-format dataset, while `get_class()` and `get_classes()` read class names from a text file or a COCO annotation file:

```python
    def load_dataset(self, path):
        self.dataset_path = path
        # The dataset is expected in COCO format (see the layout below).
        if self.backbone in ["Yolov3"]:
            self.cfg.data.train.dataset.img_prefix = os.path.join(self.dataset_path, 'images/train/')
            self.cfg.data.train.dataset.ann_file = os.path.join(self.dataset_path, 'annotations/train.json')
        else:
            self.cfg.data.train.img_prefix = os.path.join(self.dataset_path, 'images/train/')
            self.cfg.data.train.ann_file = os.path.join(self.dataset_path, 'annotations/train.json')
        self.cfg.data.val.img_prefix = os.path.join(self.dataset_path, 'images/test/')
        self.cfg.data.val.ann_file = os.path.join(self.dataset_path, 'annotations/valid.json')
        self.cfg.data.test.img_prefix = os.path.join(self.dataset_path, 'images/test/')
        self.cfg.data.test.ann_file = os.path.join(self.dataset_path, 'annotations/valid.json')

    def get_class(self, class_path):
        # Read one class name per line from a text file.
        classes = []
        with open(class_path, 'r') as f:
            for name in f:
                classes.append(name.strip('\n'))
        return classes

    def get_classes(self, annotation_file):
        # Collect the category names from a COCO-style annotation file.
        classes = ()
        with open(annotation_file, 'r') as f:
            dataset = json.load(f)
        if 'categories' in dataset:
            for cat in dataset['categories']:
                classes = classes + (cat['name'],)
        return classes
```
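Taken together, `load_dataset()` implies this layout under `dataset_path` (a reconstruction from the paths above; the folder and file names are fixed by the code):

```
det_dataset/
├── images/
│   ├── train/        # training images
│   └── test/         # used for both the val and test splits
└── annotations/
    ├── train.json    # COCO-format training annotations
    └── valid.json    # COCO-format val/test annotations
```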
`convert()` exports a checkpoint to ONNX and also writes a small onnxruntime demo script next to the ONNX file:

```python
    def convert(self, checkpoint=None, backend="ONNX", out_file="convert_model.onnx", device='cpu'):
        import os.path as osp
        ashape = self.cfg.test_pipeline[1].img_scale
        if len(ashape) == 1:
            input_shape = (1, 3, ashape[0], ashape[0])
        elif len(ashape) == 2:
            input_shape = (1, 3) + tuple(ashape)
        else:
            raise ValueError('invalid input shape')
        self.cfg.model.pretrained = None
        if self.backbone not in ["Yolov3", "SSD_Lite"]:
            self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
        else:
            self.cfg.model.bbox_head.num_classes = self.num_classes
        # Build the model and load the checkpoint.
        if checkpoint:
            # Recover the class names stored in the checkpoint's metadata.
            self.cfg.classes = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
            self.num_classes = len(self.cfg.classes)
            if self.num_classes != -1:
                if "RCNN" not in self.backbone:  # single-stage detectors
                    self.cfg.model.bbox_head.num_classes = self.num_classes
                elif self.backbone == "FasterRCNN":  # two-stage R-CNN family
                    self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
                elif self.backbone == "Mask_RCNN":
                    self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
                    self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
            self.cfg.data.train.classes = self.cfg.classes
            self.cfg.data.test.classes = self.cfg.classes
            self.cfg.data.val.classes = self.cfg.classes
            if self.backbone not in ["Yolov3", "SSD_Lite"]:
                self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
            model = init_detector(self.cfg, checkpoint, device=device)
            model.CLASSES = self.cfg.classes
        else:
            model = init_detector(self.cfg, self.checkpoint, device=device)
        if self.backbone not in ["Yolov3", "SSD_Lite"]:
            model.test_cfg.rpn.nms.iou_threshold = 0.3   # 1 - rpn_threshold
            model.test_cfg.rcnn.nms.iou_threshold = 0.3  # 1 - rcnn_threshold
        normalize_cfg = parse_normalize_cfg(self.cfg.test_pipeline)
        input_img = osp.join(osp.dirname(__file__), './demo/demo.jpg')
        if backend == "ONNX" or backend == 'onnx':
            pytorch2onnx(
                model,
                input_img,
                input_shape,
                normalize_cfg,
                show=False,
                output_file=out_file,
                verify=False,
                test_img=None,
                do_simplify=False)
        else:
            print("Sorry, we only support ONNX for now.")
        # Generate an onnxruntime demo script next to the exported model.
        with open(out_file.replace(".onnx", ".py"), "w+") as f:
            classes_list = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
            gen0 = """
import onnxruntime as rt
import BaseData
import numpy as np
import cv2
cap = cv2.VideoCapture(0)
ret_flag, image = cap.read()
cap.release()
"""
            gen_sz = """
image = cv2.resize(image, (sz_h, sz_w))
tag = 
"""
            gen1 = """
sess = rt.InferenceSession('
"""
            gen2 = """', None)
input_name = sess.get_inputs()[0].name
output_names = [o.name for o in sess.get_outputs()]
dt = BaseData.ImageData(image, backbone="
"""
            gen3 = """")
input_data = dt.to_tensor()
outputs = sess.run(output_names, {input_name: input_data})
boxes = outputs[0]
labels = outputs[1][0]
img_height, img_width = image.shape[:2]
size = min([img_height, img_width]) * 0.001
text_thickness = int(min([img_height, img_width]) * 0.001)
idx = 0
for box in zip(boxes[0]):
    x1, y1, x2, y2, score = box[0]
    label = tag[labels[idx]]
    idx = idx + 1
    caption = f'{label}{int(score * 100)}%'
    if score >= 0.15:
        (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                      fontScale=size, thickness=text_thickness)
        th = int(th * 1.2)
        cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
        cv2.putText(image, caption, (int(x1), int(y1)),
                    cv2.FONT_HERSHEY_SIMPLEX, size, (255, 255, 255), text_thickness, cv2.LINE_AA)
cv2.imwrite("result.jpg", image)
"""
            # Stitch the fragments into a runnable demo script: input size and
            # class names are baked in, then the session/postprocess code follows.
            gen = (gen0.strip("\n") + '\n'
                   + gen_sz.replace('sz_h', str(ashape[0])).replace('sz_w', str(ashape[1])).strip('\n')
                   + str(classes_list) + "\n"
                   + gen1.strip("\n") + out_file + gen2.strip("\n") + str(self.backbone) + gen3)
            f.write(gen)
```
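Export then looks like this (a sketch; the checkpoint path is a placeholder, and the generated demo script depends on the package's own `BaseData` module):

```python
# Hypothetical usage sketch: the checkpoint path is a placeholder.
model = MMDetection(backbone='FasterRCNN')
model.convert(checkpoint='checkpoints/gesture_model/latest.pth',
              backend='ONNX', out_file='gesture.onnx')
# Writes gesture.onnx plus gesture.py, an onnxruntime demo script.
```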
Two module-level helpers back `convert()`: `parse_normalize_cfg()` extracts the `Normalize` step from the test pipeline, and `pytorch2onnx()` is adapted from MMDetection's deployment tooling:

```python
def parse_normalize_cfg(test_pipeline):
    # Find the Normalize transform inside the test pipeline.
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) == 1, '`norm_config` should only have one'
    norm_config = norm_config_li[0]
    return norm_config


def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=None,
                 skip_postprocess=False):
    from functools import partial

    import onnx
    from mmdet.core.export import preprocess_example_input
    from mmdet.core.export.model_wrappers import ONNXRuntimeDetector

    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    # Prepare the example input.
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[one_meta]]
    if skip_postprocess:
        warnings.warn('Not all models support export onnx without post '
                      'process, especially two stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(
            model,
            one_img,
            output_file,
            input_names=['input'],
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version)
        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return
    # Replace the original forward function.
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=False)
    output_names = ['dets', 'labels']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}
    torch.onnx.export(
        model,
        img_list,
        output_file,
        input_names=[input_name],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes)
    model.forward = origin_forward
    if do_simplify:
        import onnxsim
        from mmdet import digit_version

        min_required_version = '0.4.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnxsim>={min_required_version}'
        model_opt, check_ok = onnxsim.simplify(output_file)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        # Check the exported graph with onnx.
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        # Wrap the ONNX model for runtime inference.
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # Scale up to test dynamic shapes.
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)
        if test_img is None:
            input_config['input_path'] = input_img
        # Prepare the input once again.
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]
        # Get the PyTorch output.
        with torch.no_grad():
            pytorch_results = model(
                img_list,
                img_metas=img_meta_list,
                return_loss=False,
                rescale=True)[0]
        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # Get the ONNX output.
        onnx_results = onnx_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]
        # Visualize the predictions.
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'
        show_img = one_meta['show_img']
        model.show_result(
            show_img,
            pytorch_results,
            score_thr=score_thr,
            show=True,
            win_name='PyTorch',
            out_file=out_file_pt)
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)
        # Compare part of the results.
        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # Check the numerical values.
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')
```