image_test.py

import argparse
import numpy as np
import sys
import os
import csv
from imagenet_test_base import TestKit
import torch class TestTorch(TestKit): def __init__(self):
super(TestTorch, self).__init__() self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)] self.model = self.MainModel.KitModel(self.args.w)
self.model.eval() def preprocess(self, image_path):
x = super(TestTorch, self).preprocess(image_path)
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, 0).copy()
self.data = torch.from_numpy(x)
self.data = torch.autograd.Variable(self.data, requires_grad = False) def print_result(self, image_name, top1, top5):
predict = self.model(self.data)
predict = predict.data.numpy()
return super(TestTorch, self).print_result(predict, image_name, top1, top5) def print_intermediate_result(self, layer_name, if_transpose=False):
intermediate_output = self.model.test.data.numpy()
super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose) def inference(self, images): with open(images) as fp_images:
images_file = csv.reader(fp_images, delimiter='\n')
top1 = 0.0
top5 = 0.0
image_count = 0
for image_name in images_file:
image_count += 1
image_path = "../data/imagenet/small_imagenet/"+image_name[0]
self.preprocess(image_path)
temp1, temp5 = self.print_result(image_name[0], top1, top5)
top1 = temp1
top5 = temp5
print("top1's accuracy : %f"%(top1/image_count))
print("top5's accuracy : %f"%(top5/image_count))
# self.print_intermediate_result(None, False)
# self.test_truth() def dump(self, path=None):
if path is None: path = self.args.dump
torch.save(self.model, path)
print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
path, self.args.n, self.args.w)) if __name__=='__main__':
tester = TestTorch()
if tester.args.dump:
tester.dump()
else:
tester.inference(tester.args.image)

image_test_base.py:

  请见上传的代码。 下载地址:https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz

执行py文件时,需要终端输入参数:

 parser = argparse.ArgumentParser()

        parser.add_argument('-p', '--preprocess', type=_text_type, help='Model Preprocess Type')   # pytorch的测试程序, 这里为image_test.py

        parser.add_argument('-n', type=_text_type, default='kit_imagenet',
help='Network structure file name.') # 模型结构测试文件 parser.add_argument('-s', type=_text_type, help='Source Framework Type',
choices=self.truth.keys()) # 框架类型:pytorch,tensorflow... parser.add_argument('-w', type=_text_type, required=True,
help='Network weights file name') #模型结构文件 parser.add_argument('--image', '-i',
type=_text_type, help='Test image path.',
default="../data/file_list.txt" #图像路径
) parser.add_argument('-l', '--label',
type=_text_type,
default='../data/val.txt',
help='Path of label.') #测试集类别 parser.add_argument('--dump',
type=_text_type,
default=None,
help='Target model path.') # 转化的目标模型文件的保存路径 parser.add_argument('--detect',
type=_text_type,
default=None,
help='Model detection result path.') # tensorflow dump tag
parser.add_argument('--dump_tag',
type=_text_type,
default=None,
help='Tensorflow model dump type',
choices=['SERVING', 'TRAINING'])

pytorch imagenet测试代码的更多相关文章

  1. .NET单元测试的艺术-3.测试代码

    开篇:上一篇我们学习单元测试和核心技术:存根.模拟对象和隔离框架,它们是我们进行高质量单元测试的技术基础.本篇会集中在管理和组织单元测试的技术,以及如何确保在真实项目中进行高质量的单元测试. 系列目录 ...

  2. mysql锁 实战测试代码

    存储引擎 支持的锁定 MyISAM 表级锁 MEMORY 表级锁 InnoDB 行级锁 BDB 页面锁 表级锁:开销小,加锁快:不会出现死锁:锁定粒度大,发生锁冲突的概率最高,并发度最低.行级锁:开销 ...

  3. 使用Microsoft Fakes隔离测试代码

    在单元测试(Unit Test)中我们遇到的问题之一是:假如被测试组件(类或项目)为A,组件A依赖于组件B,那么在组件A的单元测试ATest中测试A时,也需要依赖于B,在B发生改动后,就可能影响到A的 ...

  4. iOS开发:XCTest单元测试(附上一个单例的测试代码)

    测试驱动开发并不是一个很新鲜的概念了.在我最开始学习程序编写时,最喜欢干的事情就是编写一段代码,然后运行观察结果是否正确.我所学习第一门语言是c语言,用的最多的是在算法设计上,那时候最常做的事情就是编 ...

  5. 在内核中异步请求设备固件firmware的测试代码

    在内核中异步请求设备固件firmware的测试代码 static void ghost_load_firmware_callback(const struct firmware *fw, void * ...

  6. x264测试代码

    建立一个工程,将头文件,库文件加载到工程,测试代码如下:#include <iostream>#include <string>#include "stdint.h& ...

  7. Android网络传输中必用的两个加密算法:MD5 和 RSA (附java完成测试代码)

    MD5和RSA是网络传输中最常用的两个算法,了解这两个算法原理后就能大致知道加密是怎么一回事了.但这两种算法使用环境有差异,刚好互补. 一.MD5算法 首先MD5是不可逆的,只能加密而不能解密.比如明 ...

  8. Git合并开发代码分支到测试代码分支

    ——转载请注明出自天外归云的博客园 用TortoiseGit下载代码到本地 首先需要在本机安装好TortoiseGit.然后在随便哪个路径下比如D盘,右键“Git Clone”: 然后URL处选择项目 ...

  9. mvn编写主代码与测试代码

    maven编写主代码与测试代码 3.2 编写主代码 项目主代码和测试代码不同,项目的主代码会被打包到最终的构件中(比如jar),而测试代码只在运行测试时用到,不会被打包.默认情况下,Maven假设项目 ...

随机推荐

  1. [高清] Spring揭秘完整高清版

      ------ 郑重声明 --------- 资源来自网络,纯粹共享交流, 如果喜欢,请您务必支持正版!! --------------------------------------------- ...

  2. 分享大麦UWP版本开发历程-03.GridView或ListView 滚动底部自动加载后续数据

    今天跟大家分享的是大麦UWP客户端,在分类.订单或是搜索时都用到的一个小技巧,技术粗糙大神勿喷. 以大麦分类举例,默认打开的时候,会为用户展示20条数据,当用户滚动鼠标或者使用手势将列表滑动到倒数第二 ...

  3. Windows中的消息与消息队列

    消息 在Windows中,消自由MSG结构体表示 typedef struct tagMSG { HWND hwnd; UINT message; WPARAM wParam; LPARAM lPar ...

  4. Python之TensorFlow的基本介绍-1

    一.TensorFlow™是一个基于数据流编程(dataflow programming)的符号数学系统,被广泛应用于各类机器学习(machine learning)算法的编程实现,其前身是谷歌的神经 ...

  5. iOS-右滑返回,利用Runtime添加全屏Pop手势

    项目中经常会遇到类似需求,需要在某控制器增加全屏右滑返回功能. 在我们不隐藏 NavigationBar 的前提下,系统会自动替我增加此功能,只是它作用的范围仅仅在屏幕左边有限区域. 我们需要在整个界 ...

  6. lxterminal命令打开新窗口并执行python脚本

    lxterminal -e python3 -i test.py 注意,路径要写对,用绝对路径

  7. 版本管理工具Git三种工作流

    Git是分布式版本管理控制的工具.学习Git一般都是先去学习Git的命令. 但是学习完Git的基本命令之后还是不知道怎样使用Git.首先,我们要清楚的 一点是Git的使用方法其实有很多种,也就是说Gi ...

  8. 二十三、mysql索引管理详解

    一.索引分类 分为聚集索引和非聚集索引. 聚集索引 每个表有且一定会有一个聚集索引,整个表的数据存储在聚集索引中,mysql索引是采用B+树结构保存在文件中,叶子节点存储主键的值以及对应记录的数据,非 ...

  9. ViewBag---MVC3中 ViewBag、ViewData和TempData的使用和差别-------与ViewBag+Hashtable应用例子

    ViewBag 在MVC3開始.视图数据能够通过ViewBag属性訪问.在MVC2中则是使用ViewData.MVC3中保留了ViewData的使用.ViewBag 是动态类型(dynamic),Vi ...

  10. 将Maven项目部署云服务器流程

    1.数据库分离,存入项目: 2.将分离出的数据库导入云端服务器 将sql文件上传到服务器中 进去云端数据库输入命令:source  云服务器中sql文件地址 3.设置两种配置,修改匹配: 4.mave ...