1. import os
  2. import json
  3. import mmcv
  4. import time
  5. from mmcv import Config
  6. from mmdet.apis import inference_detector, init_detector, show_result_pyplot, train_detector
  7. from mmdet.models import build_detector
  8. from mmdet.datasets import build_dataset
  9. from mmcv.runner import load_checkpoint
  10. from tqdm import tqdm
  11. import warnings
  12. import torch
  13. import numpy as np
  14. import cv2
  15.  
# Install an "ignore" filter so that warnings issued after this point are
# suppressed for the whole process.
warnings.filterwarnings("ignore")
  17.  
  18. class MMDetection:
  19. def sota():
  20. pypath = os.path.abspath(__file__)
  21. father = os.path.dirname(pypath)
  22. models = os.path.join(father, 'models')
  23. sota_model = []
  24. for i in os.listdir(models):
  25. if i[0] != '_':
  26. sota_model.append(i)
  27. return sota_model
  28.  
  29. def __init__(self,
  30. backbone='FasterRCNN',
  31. num_classes=-1,
  32. dataset_path=None
  33. ):
  34.  
  35. # 获取外部运行py的绝对路径
  36. self.cwd = os.path.dirname(os.getcwd())
  37. # 获取当前文件的绝对路径
  38. self.file_dirname = os.path.dirname(os.path.abspath(__file__))
  39. self.save_fold = None
  40. self.is_sample = False
  41. self.config = os.path.join(
  42. self.file_dirname, 'models', 'FasterRCNN/FasterRCNN.py')
  43. self.checkpoint = os.path.join(
  44. self.file_dirname, 'models', '/FasterRCNN/FasterRCNN.pth')
  45.  
  46. self.backbone = backbone
  47. backbone_path = os.path.join(
  48. self.file_dirname, 'models', self.backbone)
  49. ckpt_cfg_list = list(os.listdir(backbone_path))
  50. for item in ckpt_cfg_list:
  51. if item[-1] == 'y' and item[0] != '_': # pip包修改1
  52. self.config = os.path.join(backbone_path, item)
  53. elif item[-1] == 'h':
  54. self.checkpoint = os.path.join(backbone_path, item)
  55. else:
  56. # print("Warning!!! There is an unrecognized file in the backbone folder.")
  57. pass
  58.  
  59. self.cfg = Config.fromfile(self.config)
  60.  
  61. self.dataset_path = dataset_path
  62. self.lr = None
  63. self.backbonedict = {
  64. "FasterRCNN": os.path.join(self.file_dirname, 'models', 'FasterRCNN/FasterRCNN.py'),
  65. "Yolov3": os.path.join(self.file_dirname, 'models', 'Yolov3/Yolov3.py'),
  66. "SSD_Lite":os.path.join(self.file_dirname, 'models', 'SSD_Lite/SSD_Lite.py'),
  67. # "Mask_RCNN":os.path.join(self.file_dirname, 'models', 'Mask_RCNN/Mask_RCNN.py'),
  68. # 下略
  69. }
  70. self.num_classes = num_classes
  71. self.chinese_res = None
  72. self.is_sample = False
  73.  
  74. def train(self, random_seed=0, save_fold=None, distributed=False, validate=True, device='cpu',
  75. metric='bbox', save_best='bbox_mAP', optimizer="SGD", epochs=100, lr=0.001, weight_decay=0.001,
  76. Frozen_stages=1,
  77. checkpoint=None, batch_size=None):
  78.  
  79. # 加载网络模型的配置文件
  80. self.cfg = Config.fromfile(self.backbonedict[self.backbone])
  81.  
  82. # 如果外部不指定save_fold
  83. if not self.save_fold:
  84. # 如果外部也没有传入save_fold,我们使用默认路径
  85. if not save_fold:
  86. self.save_fold = os.path.join(
  87. self.cwd, 'checkpoints/det_model')
  88. # 如果外部传入save_fold,我们使用传入值
  89. else:
  90. self.save_fold = save_fold
  91.  
  92. self.cfg.model.backbone.frozen_stages = Frozen_stages
  93.  
  94. if self.num_classes != -1:
  95. if "RCNN" not in self.backbone: # 单阶段
  96. self.cfg.model.bbox_head.num_classes =self.num_classes
  97. elif self.backbone == "FasterRCNN": # rcnn系列 双阶段
  98. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  99. elif self.backbone == "Mask_RCNN":
  100. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  101. self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
  102.  
  103. self.load_dataset(self.dataset_path)
  104. # 添加需要进行检测的类名
  105. if self.backbone in ["Yolov3"]:
  106. self.cfg.classes = self.get_classes(self.cfg.data.train.dataset.ann_file)
  107. else:
  108. self.cfg.classes = self.get_classes(self.cfg.data.train.ann_file)
  109.  
  110. # 分别为训练、测试、验证添加类名
  111. if self.backbone in ["Yolov3"]:
  112. self.cfg.data.train.dataset.classes = self.cfg.classes
  113. else:
  114. self.cfg.data.train.classes = self.cfg.classes
  115. self.cfg.data.test.classes = self.cfg.classes
  116. self.cfg.data.val.classes = self.cfg.classes
  117.  
  118. # 进行
  119. self.cfg.work_dir = self.save_fold
  120. # 创建工作目录
  121. mmcv.mkdir_or_exist(os.path.abspath(self.cfg.work_dir))
  122. # 创建分类器
  123. datasets = [build_dataset(self.cfg.data.train)]
  124. model = build_detector(self.cfg.model, train_cfg=self.cfg.get(
  125. 'train_cfg'), test_cfg=self.cfg.get('test_cfg'))
  126. # print("checkpoint", checkpoint)
  127. if not checkpoint:
  128. model.init_weights()
  129. else:
  130. checkpoint = os.path.abspath(checkpoint) # pip修改2
  131. load_checkpoint(model, checkpoint, map_location=torch.device(device))
  132.  
  133. model.CLASSES = self.cfg.classes
  134. if optimizer == 'Adam':
  135. self.cfg.optimizer = dict(type='Adam', lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0001)
  136. elif optimizer == 'Adagrad':
  137. self.cfg.optimizer = dict(type='Adagrad', lr=lr, lr_decay=0)
  138. # 根据输入参数更新config文件
  139. self.cfg.optimizer.lr = lr # 学习率
  140. self.cfg.optimizer.type = optimizer # 优化器
  141. self.cfg.optimizer.weight_decay = weight_decay # 优化器的衰减权重
  142. self.cfg.evaluation.metric = metric # 验证指标
  143. # self.cfg.evaluation.save_best = save_best
  144. self.cfg.runner.max_epochs = epochs # 最大的训练轮次
  145.  
  146. # 设置每 5 个训练批次输出一次日志
  147. # self.cfg.log_config.interval = 1
  148. self.cfg.gpu_ids = range(1)
  149.  
  150. self.cfg.seed = random_seed
  151. if batch_size is not None:
  152. self.cfg.data.samples_per_gpu = batch_size
  153. train_detector(
  154. model,
  155. datasets,
  156. self.cfg,
  157. distributed=distributed,
  158. validate=validate,
  159. timestamp=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
  160. meta=dict()
  161. )
  162.  
  163. def print_result(self, res=None):
  164. if self.is_sample == True:
  165. print("示例检测结果如下:")
  166. sample_result = r"[{'类别标签': 0, '置信度': 1.0, '坐标': {'x': 26, 'y': 81, 'w': 497, 'h': 414}},{'类别标签': 1, '置信度': 1.0, '坐标': {'x1': 250, 'y1': 103, 'x2': 494, 'y2': 341}}]"
  167. print(sample_result)
  168. else:
  169. print("检测结果如下:")
  170. print(self.chinese_res)
  171. return self.chinese_res
  172.  
  173. def load_checkpoint(self, checkpoint=None, device='cpu',
  174. rpn_threshold=0.7, rcnn_threshold=0.7):
  175. print("========= begin inference ==========")
  176. if self.num_classes != -1 and self.backbone not in ["Yolov3", "SSD_Lite"]:
  177. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  178.  
  179. if checkpoint:
  180. # 加载数据集及配置文件的路径
  181. # self.load_dataset(self.dataset_path)
  182. # 修正检测的目标
  183. # self.cfg.classes = self.get_class(class_path)
  184. self.cfg.classes = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
  185. self.cfg.data.train.classes = self.cfg.classes
  186. self.cfg.data.test.classes = self.cfg.classes
  187. self.cfg.data.val.classes = self.cfg.classes
  188. if "RCNN" not in self.backbone: # 单阶段
  189. self.cfg.model.bbox_head.num_classes = len(self.cfg.classes)
  190. else: # rcnn系列 双阶段
  191. self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
  192. # self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
  193. self.infer_model = init_detector(self.cfg, checkpoint, device=device)
  194. self.infer_model.CLASSES = self.cfg.classes
  195. else:
  196. self.infer_model = init_detector(self.cfg, self.checkpoint, device=device)
  197. if self.backbone not in ["Yolov3", "SSD_Lite"]: self.infer_model.test_cfg.rpn.nms.iou_threshold = 1 - rpn_threshold
  198. if self.backbone not in ["Yolov3", "SSD_Lite"]: self.infer_model.test_cfg.rcnn.nms.iou_threshold = 1 - rcnn_threshold
  199.  
  200. def fast_inference(self, image, show=False, save_fold='det_result'):
  201. img_array = mmcv.imread(image)
  202. try:
  203. self.infer_model
  204. except:
  205. print("请先使用load_checkpoint()方法加载权重!")
  206. return
  207. result = inference_detector(self.infer_model, img_array) # 此处的model和外面的无关,纯局部变量
  208. self.infer_model.show_result(image, result, show=show,
  209. out_file=os.path.join(save_fold, os.path.split(image)[1]))
  210. chinese_res = []
  211. for i in range(len(result)):
  212. for j in range(result[i].shape[0]):
  213. tmp = {}
  214. tmp['类别标签'] = i
  215. tmp['置信度'] = result[i][j][4]
  216. tmp['坐标'] = {"x": int(result[i][j][0]), "y": int(
  217. result[i][j][1]), 'w': int(result[i][j][2]), 'h': int(result[i][j][3])}
  218. # img.append(tmp)
  219. chinese_res.append(tmp)
  220. # print(chinese_res)
  221. self.chinese_res = chinese_res
  222. # print("========= finish inference ==========")
  223. return result
  224.  
  225. def inference(self, device='cpu',
  226. checkpoint=None,
  227. image=None,
  228. show=True,
  229. rpn_threshold=0.7,
  230. rcnn_threshold=0.7,
  231. save_fold='det_result',
  232. ):
  233. # self.cfg.classes = self.get_class(class_path)
  234. # self.num_classes = len(self.get_class(class_path))
  235. # if self.num_classes != -1:
  236. # if "RCNN" not in self.backbone: # 单阶段
  237. # self.cfg.model.bbox_head.num_classes =self.num_classes
  238. # elif self.backbone == "FasterRCNN": # rcnn系列 双阶段
  239. # self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  240. # elif self.backbone == "Mask_RCNN":
  241. # self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  242. # self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
  243.  
  244. if image == None:
  245. self.is_sample = True
  246. sample_return = """
  247. [array([[ 26.547777 , 81.55447 , 497.37015 , 414.4934 ,
  248. 1.0]], dtype=float32),
  249. array([[2.5098564e+02, 1.0334784e+02, 4.9422855e+02, 3.4187744e+02,
  250. 1.0], dtype=float32)]
  251. """
  252. return sample_return
  253. self.is_sample = False
  254. print("========= begin inference ==========")
  255.  
  256. if self.num_classes != -1 and self.backbone not in ["Yolov3", "SSD_Lite"] :
  257. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  258.  
  259. if checkpoint:
  260. # 加载数据集及配置文件的路径
  261. # self.load_dataset(self.dataset_path)
  262. # 修正检测的目标
  263. self.cfg.classes = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
  264.  
  265. self.num_classes = len(self.cfg.classes)
  266. if self.num_classes != -1:
  267. if "RCNN" not in self.backbone: # 单阶段
  268. self.cfg.model.bbox_head.num_classes =self.num_classes
  269. elif self.backbone == "FasterRCNN": # rcnn系列 双阶段
  270. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  271. elif self.backbone == "Mask_RCNN":
  272. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  273. self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
  274. self.cfg.data.train.classes = self.cfg.classes
  275. self.cfg.data.test.classes = self.cfg.classes
  276. self.cfg.data.val.classes = self.cfg.classes
  277. if self.backbone not in ["Yolov3", "SSD_Lite"]:
  278. self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
  279. model = init_detector(self.cfg, checkpoint, device=device)
  280. model.CLASSES = self.cfg.classes
  281. else:
  282. model = init_detector(self.cfg, self.checkpoint, device=device)
  283. # model = build_detector(self.cfg.model, train_cfg=self.cfg.get(
  284. # 'train_cfg'), test_cfg=self.cfg.get('test_cfg'))
  285. # if not checkpoint:
  286. # model.init_weights()
  287. # else:
  288. # checkpoint = os.path.abspath(checkpoint) # pip修改2
  289. # load_checkpoint(model, checkpoint, map_location=torch.device(device))
  290.  
  291. if self.backbone not in ["Yolov3", "SSD_Lite"]: model.test_cfg.rpn.nms.iou_threshold = 1 - rpn_threshold
  292. if self.backbone not in ["Yolov3", "SSD_Lite"]: model.test_cfg.rcnn.nms.iou_threshold = 1 - rcnn_threshold
  293.  
  294. results = []
  295. if (image[-1] != '/'):
  296. img_array = mmcv.imread(image)
  297. result = inference_detector(
  298. model, img_array) # 此处的model和外面的无关,纯局部变量
  299. if show == True:
  300. show_result_pyplot(model, image, result)
  301. model.show_result(image, result, show=show, out_file=os.path.join(save_fold, os.path.split(image)[1]))
  302. chinese_res = []
  303. for i in range(len(result)):
  304. for j in range(result[i].shape[0]):
  305. tmp = {}
  306. tmp['类别标签'] = i
  307. tmp['置信度'] = result[i][j][4]
  308. tmp['坐标'] = {"x1": int(result[i][j][0]), "y1": int(
  309. result[i][j][1]), 'x2': int(result[i][j][2]), 'y2': int(result[i][j][3])}
  310. # img.append(tmp)
  311. chinese_res.append(tmp)
  312. # print(chinese_res)
  313. self.chinese_res = chinese_res
  314. print("========= finish inference ==========")
  315. return result
  316. else:
  317. img_dir = image
  318. mmcv.mkdir_or_exist(os.path.abspath(save_fold))
  319. chinese_results = []
  320. for i, img in enumerate(tqdm(os.listdir(img_dir))):
  321. result = inference_detector(
  322. model, img_dir + img) # 此处的model和外面的无关,纯局部变量
  323. model.show_result(img_dir + img, result,
  324. out_file=os.path.join(save_fold, img))
  325. chinese_res = []
  326. for i in range(len(result)):
  327. for j in range(result[i].shape[0]):
  328. tmp = {}
  329. tmp['类别标签'] = i
  330. tmp['置信度'] = result[i][j][4]
  331. tmp['坐标'] = {"x1": int(result[i][j][0]), "y1": int(
  332. result[i][j][1]), 'x2': int(result[i][j][2]), 'y2': int(result[i][j][3])}
  333. # img.append(tmp)
  334. chinese_res.append(tmp)
  335. chinese_results.append(chinese_res)
  336. results.append(result)
  337. self.chinese_res = chinese_results
  338. print("========= finish inference ==========")
  339. return results
  340.  
  341. def load_dataset(self, path):
  342. self.dataset_path = path
  343.  
  344. # 数据集修正为coco格式
  345. if self.backbone in ["Yolov3"]:
  346. self.cfg.data.train.dataset.img_prefix = os.path.join(self.dataset_path, 'images/train/')
  347. self.cfg.data.train.dataset.ann_file = os.path.join(self.dataset_path, 'annotations/train.json')
  348. else:
  349. self.cfg.data.train.img_prefix = os.path.join(self.dataset_path, 'images/train/')
  350. self.cfg.data.train.ann_file = os.path.join(self.dataset_path, 'annotations/train.json')
  351.  
  352. self.cfg.data.val.img_prefix = os.path.join(self.dataset_path, 'images/test/')
  353. self.cfg.data.val.ann_file = os.path.join(self.dataset_path, 'annotations/valid.json')
  354.  
  355. self.cfg.data.test.img_prefix = os.path.join(self.dataset_path, 'images/test/')
  356. self.cfg.data.test.ann_file = os.path.join(self.dataset_path, 'annotations/valid.json')
  357.  
  358. def get_class(self, class_path):
  359. classes = []
  360. with open(class_path, 'r') as f:
  361. for name in f:
  362. classes.append(name.strip('\n'))
  363. return classes
  364.  
  365. def get_classes(self, annotation_file):
  366. classes = ()
  367. with open(annotation_file, 'r') as f:
  368. dataset = json.load(f)
  369. # categories = dataset["categories"]
  370. if 'categories' in dataset:
  371. for cat in dataset['categories']:
  372. classes = classes + (cat['name'],)
  373. return classes
  374.  
  375. def convert(self, checkpoint=None, backend="ONNX", out_file="convert_model.onnx",device='cpu'):
  376. import os.path as osp
  377. from mmdet.core.export import build_model_from_cfg
  378.  
  379. ashape = self.cfg.test_pipeline[1].img_scale
  380. if len(ashape) == 1:
  381. input_shape = (1, 3, ashape[0], ashape[0])
  382. elif len(ashape) == 2:
  383. input_shape = (
  384. 1,
  385. 3,
  386. ) + tuple(ashape)
  387. else:
  388. raise ValueError('invalid input shape')
  389. self.cfg.model.pretrained = None
  390. if self.backbone not in ["Yolov3", "SSD_Lite"] :
  391. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  392. else:
  393. self.cfg.model.bbox_head.num_classes = self.num_classes
  394.  
  395. # build the model and load checkpoint
  396. # detector = build_detector(self.cfg.model)
  397. # model = build_model_from_cfg(self.config, checkpoint)
  398.  
  399. if checkpoint:
  400. # 加载数据集及配置文件的路径
  401. # self.load_dataset(self.dataset_path)
  402. # 修正检测的目标
  403. self.cfg.classes = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
  404. self.num_classes = len(self.cfg.classes)
  405. if self.num_classes != -1:
  406. if "RCNN" not in self.backbone: # 单阶段
  407. self.cfg.model.bbox_head.num_classes =self.num_classes
  408. elif self.backbone == "FasterRCNN": # rcnn系列 双阶段
  409. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  410. elif self.backbone == "Mask_RCNN":
  411. self.cfg.model.roi_head.bbox_head.num_classes = self.num_classes
  412. self.cfg.model.roi_head.mask_head.num_classes = self.num_classes
  413. self.cfg.data.train.classes = self.cfg.classes
  414. self.cfg.data.test.classes = self.cfg.classes
  415. self.cfg.data.val.classes = self.cfg.classes
  416. if self.backbone not in ["Yolov3", "SSD_Lite"]:
  417. self.cfg.model.roi_head.bbox_head.num_classes = len(self.cfg.classes)
  418. model = init_detector(self.cfg, checkpoint, device=device)
  419. model.CLASSES = self.cfg.classes
  420. else:
  421. model = init_detector(self.cfg, self.checkpoint, device=device)
  422. if self.backbone not in ["Yolov3", "SSD_Lite"]: model.test_cfg.rpn.nms.iou_threshold = 0.3 # 1 - rpn_threshold
  423. if self.backbone not in ["Yolov3", "SSD_Lite"]: model.test_cfg.rcnn.nms.iou_threshold = 0.3 # 1 - rcnn_threshold
  424.  
  425. #detector = build_detector(self.cfg.model, test_cfg=self.cfg.get('test_cfg'))
  426. # detector.CLASSES = self.num_classes
  427. normalize_cfg = parse_normalize_cfg(self.cfg.test_pipeline)
  428. input_img = osp.join(osp.dirname(__file__), './demo/demo.jpg')
  429. if backend == "ONNX" or backend == 'onnx':
  430. pytorch2onnx(
  431. # detector,
  432. model,
  433. input_img,
  434. input_shape,
  435. normalize_cfg,
  436. show=False,
  437. output_file=out_file,
  438. verify=False,
  439. test_img=None,
  440. do_simplify=False)
  441. else:
  442. print("Sorry, we only suport ONNX up to now.")
  443. with open(out_file.replace(".onnx", ".py"), "w+") as f:
  444. tp = str(self.cfg.test_pipeline).replace("},","},\n\t")
  445. # if class_path != None:
  446. # classes_list = self.get_class(class_path)
  447. classes_list = torch.load(checkpoint, map_location=torch.device('cpu'))['meta']['CLASSES']
  448.  
  449. gen0 = """
  450. import onnxruntime as rt
  451. import BaseData
  452. import numpy as np
  453. import cv2
  454.  
  455. cap = cv2.VideoCapture(0)
  456. ret_flag,image = cap.read()
  457. cap.release()
  458. """
  459. gen_sz = """
  460. image = cv2.resize(image,(sz_h,sz_w))
  461. tag =
  462. """
  463. gen1 = """
  464. sess = rt.InferenceSession('
  465. """
  466. gen2 = """', None)
  467. input_name = sess.get_inputs()[0].name
  468. output_names = [o.name for o in sess.get_outputs()]
  469. dt = BaseData.ImageData(image, backbone="
  470. """
  471.  
  472. gen3 = """")
  473. input_data = dt.to_tensor()
  474. outputs = sess.run(output_names, {input_name: input_data})
  475.  
  476. boxes = outputs[0]
  477. labels = outputs[1][0]
  478. img_height, img_width = image.shape[:2]
  479. size = min([img_height, img_width]) * 0.001
  480. text_thickness = int(min([img_height, img_width]) * 0.001)
  481.  
  482. idx = 0
  483. for box in zip(boxes[0]):
  484. x1, y1, x2, y2, score = box[0]
  485. label = tag[labels[idx]]
  486. idx = idx + 1
  487. caption = f'{label}{int(score * 100)}%'
  488. if score >= 0.15:
  489. (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
  490. fontScale=size, thickness=text_thickness)
  491. th = int(th * 1.2)
  492. cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
  493. cv2.putText(image, caption, (int(x1), int(y1)),
  494. cv2.FONT_HERSHEY_SIMPLEX, size, (255, 255, 255), text_thickness, cv2.LINE_AA)
  495.  
  496. cv2.imwrite("result.jpg", image)
  497. """
  498. ashape = self.cfg.test_pipeline[1].img_scale
  499. # if class_path != None:
  500. gen = gen0.strip("\n") + '\n' + gen_sz.replace('sz_h',str(ashape[0])).replace('sz_w',str(ashape[1])).strip('\n') + str(classes_list)+ "\n" + gen1.strip("\n") + out_file + gen2.strip("\n") + str(self.backbone) + gen3
  501. # else:
  502. # gen = gen0.strip("tag = \n") + "\n\n" + gen1.strip("\n")+out_file+ gen2.strip("\n") + str(self.backbone) + gen3.replace("tag[labels[idx]]", "labels[idx]")
  503. f.write(gen)
  504.  
  505. def parse_normalize_cfg(test_pipeline):
  506. transforms = None
  507. for pipeline in test_pipeline:
  508. if 'transforms' in pipeline:
  509. transforms = pipeline['transforms']
  510. break
  511. assert transforms is not None, 'Failed to find `transforms`'
  512. norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
  513. assert len(norm_config_li) == 1, '`norm_config` should only have one'
  514. norm_config = norm_config_li[0]
  515. return norm_config
  516.  
  517. def pytorch2onnx(model,
  518. input_img,
  519. input_shape,
  520. normalize_cfg,
  521. opset_version=11,
  522. show=False,
  523. output_file='tmp.onnx',
  524. verify=False,
  525. test_img=None,
  526. do_simplify=False,
  527. dynamic_export=None,
  528. skip_postprocess=False):
  529. from mmdet.core.export import build_model_from_cfg, preprocess_example_input
  530. from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
  531. from functools import partial
  532. from mmcv import Config, DictAction
  533. import onnx
  534. input_config = {
  535. 'input_shape': input_shape,
  536. 'input_path': input_img,
  537. 'normalize_cfg': normalize_cfg
  538. }
  539. # prepare input
  540. one_img, one_meta = preprocess_example_input(input_config)
  541. img_list, img_meta_list = [one_img], [[one_meta]]
  542.  
  543. if skip_postprocess:
  544. warnings.warn('Not all models support export onnx without post '
  545. 'process, especially two stage detectors!')
  546. model.forward = model.forward_dummy
  547. torch.onnx.export(
  548. model,
  549. one_img,
  550. output_file,
  551. input_names=['input'],
  552. export_params=True,
  553. keep_initializers_as_inputs=True,
  554. do_constant_folding=True,
  555. verbose=show,
  556. opset_version=opset_version)
  557.  
  558. print(f'Successfully exported ONNX model without '
  559. f'post process: {output_file}')
  560. return
  561.  
  562. # replace original forward function
  563. origin_forward = model.forward
  564. model.forward = partial(
  565. model.forward,
  566. img_metas=img_meta_list,
  567. return_loss=False,
  568. rescale=False)
  569.  
  570. output_names = ['dets', 'labels']
  571. if model.with_mask:
  572. output_names.append('masks')
  573. input_name = 'input'
  574. dynamic_axes = None
  575. if dynamic_export:
  576. dynamic_axes = {
  577. input_name: {
  578. 0: 'batch',
  579. 2: 'height',
  580. 3: 'width'
  581. },
  582. 'dets': {
  583. 0: 'batch',
  584. 1: 'num_dets',
  585. },
  586. 'labels': {
  587. 0: 'batch',
  588. 1: 'num_dets',
  589. },
  590. }
  591. if model.with_mask:
  592. dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}
  593.  
  594. torch.onnx.export(
  595. model,
  596. img_list,
  597. output_file,
  598. input_names=[input_name],
  599. output_names=output_names,
  600. export_params=True,
  601. keep_initializers_as_inputs=True,
  602. do_constant_folding=True,
  603. verbose=show,
  604. opset_version=opset_version,
  605. dynamic_axes=dynamic_axes)
  606.  
  607. model.forward = origin_forward
  608.  
  609. if do_simplify:
  610. import onnxsim
  611.  
  612. from mmdet import digit_version
  613.  
  614. min_required_version = '0.4.0'
  615. assert digit_version(onnxsim.__version__) >= digit_version(
  616. min_required_version
  617. ), f'Requires to install onnxsim>={min_required_version}'
  618.  
  619. model_opt, check_ok = onnxsim.simplify(output_file)
  620. if check_ok:
  621. onnx.save(model_opt, output_file)
  622. print(f'Successfully simplified ONNX model: {output_file}')
  623. else:
  624. warnings.warn('Failed to simplify ONNX model.')
  625. print(f'Successfully exported ONNX model: {output_file}')
  626.  
  627. if verify:
  628. # check by onnx
  629. onnx_model = onnx.load(output_file)
  630. onnx.checker.check_model(onnx_model)
  631. #print(model.CLASSES)
  632. # wrap onnx model
  633. onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
  634. if dynamic_export:
  635. # scale up to test dynamic shape
  636. h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
  637. h, w = min(1344, h), min(1344, w)
  638. input_config['input_shape'] = (1, 3, h, w)
  639.  
  640. if test_img is None:
  641. input_config['input_path'] = input_img
  642.  
  643. # prepare input once again
  644. one_img, one_meta = preprocess_example_input(input_config)
  645. img_list, img_meta_list = [one_img], [[one_meta]]
  646.  
  647. # get pytorch output
  648. with torch.no_grad():
  649. pytorch_results = model(
  650. img_list,
  651. img_metas=img_meta_list,
  652. return_loss=False,
  653. rescale=True)[0]
  654.  
  655. img_list = [_.cuda().contiguous() for _ in img_list]
  656. if dynamic_export:
  657. img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
  658. img_meta_list = img_meta_list * 2
  659. # get onnx output
  660. onnx_results = onnx_model(
  661. img_list, img_metas=img_meta_list, return_loss=False)[0]
  662. # visualize predictions
  663. score_thr = 0.3
  664. if show:
  665. out_file_ort, out_file_pt = None, None
  666. else:
  667. out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'
  668.  
  669. show_img = one_meta['show_img']
  670. model.show_result(
  671. show_img,
  672. pytorch_results,
  673. score_thr=score_thr,
  674. show=True,
  675. win_name='PyTorch',
  676. out_file=out_file_pt)
  677. onnx_model.show_result(
  678. show_img,
  679. onnx_results,
  680. score_thr=score_thr,
  681. show=True,
  682. win_name='ONNXRuntime',
  683. out_file=out_file_ort)
  684.  
  685. # compare a part of result
  686. if model.with_mask:
  687. compare_pairs = list(zip(onnx_results, pytorch_results))
  688. else:
  689. compare_pairs = [(onnx_results, pytorch_results)]
  690. err_msg = 'The numerical values are different between Pytorch' + \
  691. ' and ONNX, but it does not necessarily mean the' + \
  692. ' exported ONNX model is problematic.'
  693. # check the numerical value
  694. for onnx_res, pytorch_res in compare_pairs:
  695. for o_res, p_res in zip(onnx_res, pytorch_res):
  696. np.testing.assert_allclose(
  697. o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
  698. print('The numerical values are the same between Pytorch and ONNX')

用MMCls训练手势模型的更多相关文章

  1. PocketSphinx语音识别系统语言模型的训练和声学模型的改进

    PocketSphinx语音识别系统语言模型的训练和声学模型的改进 zouxy09@qq.com http://blog.csdn.net/zouxy09 关于语音识别的基础知识和sphinx的知识, ...

  2. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  3. 搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型

    原文地址:搭建 MobileNet-SSD 开发环境并使用 VOC 数据集训练 TensorFlow 模型 0x00 环境 OS: Ubuntu 1810 x64 Anaconda: 4.6.12 P ...

  4. 在Java Web中使用Spark MLlib训练的模型

    PMML是一种通用的配置文件,只要遵循标准的配置文件,就可以在Spark中训练机器学习模型,然后再web接口端去使用.目前应用最广的就是基于Jpmml来加载模型在javaweb中应用,这样就可以实现跨 ...

  5. 1.keras实现-->自己训练卷积模型实现猫狗二分类(CNN)

    原数据集:包含 25000张猫狗图像,两个类别各有12500 新数据集:猫.狗 (照片大小不一样) 训练集:各1000个样本 验证集:各500个样本 测试集:各500个样本 1= 狗,0= 猫 # 将 ...

  6. 文本主题抽取:用gensim训练LDA模型

    得知李航老师的《统计学习方法》出了第二版,我第一时间就买了.看了这本书的目录,非常高兴,好家伙,居然把主题模型都写了,还有pagerank.一路看到了马尔科夫蒙特卡罗方法和LDA主题模型这 ...

  7. tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛

    tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用. 数据目录在data,data下放了汉字识别图片: data$ ls0  1  10  11  12  13  14  15 ...

  8. java web应用调用python深度学习训练的模型

    之前参见了中国软件杯大赛,在大赛中用到了深度学习的相关算法,也训练了一些简单的模型.项目线上平台是用java编写的web应用程序,而深度学习使用的是python语言,这就涉及到了在java代码中调用p ...

  9. 基于Caffe训练AlexNet模型

    数据集 1.准备数据集 1)下载训练和验证图片 ImageNet官网地址:http://www.image-net.org/signup.php?next=download-images (需用邮箱注 ...

  10. TensorFlow 训练好模型参数的保存和恢复代码

    TensorFlow 训练好模型参数的保存和恢复代码,之前就在想模型不应该每次要个结果都要重新训练一遍吧,应该训练一次就可以一直使用吧. TensorFlow 提供了 Saver 类,可以进行保存和恢 ...

随机推荐

  1. 智能合约HardHat框架环境的搭建

    1.首先创建一个npm项目 PS C:\Users\lcds\blockchainprojects> mkdir hardhatcontract PS C:\Users\lcds\blockch ...

  2. JavaWeb之Servlet详解(以及浏览器调用 Servlet 流程分析图)

    Servlet 1.什么是Servlet Servlet(java 服务器小程序) 他是由服务器端调用和执行的(一句话:是Tomcat解析和执行) 他是用java语言编写的, 本质就是Java类 他是 ...

  3. Java面试题全集(一)

    JDK.JRE.JVM之间的区别 JDK(Java SE Development Kit),Java标准开发包,它提供了编译.运⾏Java程序所需的各种⼯具和资源,包括Java编译器.Java运⾏时环 ...

  4. ES 实战复杂sql查询、修改字段类型

    转载请注明出处: 1.查询索引的 mapping 与 setting get 直接查询索引名称时,会返回该索引的 mapping 和 settings 的配置,上述返回的结构如下: { " ...

  5. 如何使用C#中的Lambda表达式操作Redis Hash结构,简化缓存中对象属性的读写操作

    Redis是一个开源的.高性能的.基于内存的键值数据库,它支持多种数据结构,如字符串.列表.集合.散列.有序集合等.其中,Redis的散列(Hash)结构是一个常用的结构,今天跟大家分享一个我的日常操 ...

  6. 巧用 awk 批量杀进程

    今天遇到线上的一个问题: 我需要批量杀死某台机器的 PHP 进程,该怎么办? 注意,不是 php-fpm,是常驻任务. 如果是一个进程,那就好办了,ps -ef | grep php,找到 PID 然 ...

  7. Unity UGUI的CanvasScaler(画布缩放器)组件的介绍及使用

    Unity UGUI的CanvasScaler(画布缩放器)组件的介绍及使用 1. 什么是CanvasScaler组件? CanvasScaler是Unity中UGUI系统中的一个组件,用于控制画布的 ...

  8. 国产化之x64平台安装银河麒麟操作系统

    背景 某个项目需要实现基础软件全部国产化,其中操作系统指定银河麒麟v4,CPU使用飞腾处理器.飞腾处理器是ARMv8架构的,在之前的文章中介绍了使用QEMU模拟ARMv8架构安装银河麒麟操作系统的方式 ...

  9. python添加水印

    # coding:utf-8 from PIL import Image, ImageDraw, ImageFont def add_text_to_image(image, text): font ...

  10. Linux 压缩文件用法

    # tar 命令:可以用来压缩或解压缩文件: # 压缩 tar -czvf filename.tar.gz files # 解压缩 tar -xzvf filename.tar.gz # gzip 命 ...