内容接前文:

https://www.cnblogs.com/devilmaycry812839668/p/14988686.html

https://www.cnblogs.com/devilmaycry812839668/p/14990021.html

前面是我们自己按照个人理解实现的单步计算,随着对这个计算框架MindSpore的深入了解我们了解到其实官方是提供了单步计算函数的。

具体函数:

from mindspore.nn import TrainOneStepCell, WithLossCell

根据官方资料:

https://www.mindspore.cn/doc/programming_guide/zh-CN/master/network_component.html?highlight=%E5%8D%95%E6%AD%A5%E8%AE%AD%E7%BB%83

根据官方提供的函数,给出如下代码:

import mindspore
import numpy as np # 引入numpy科学计算库
import matplotlib.pyplot as plt # 引入绘图库 np.random.seed(123) # 随机数生成种子 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn import TrainOneStepCell, WithLossCell class Net(nn.Cell):
def __init__(self, input_dims, output_dims):
super(Net, self).__init__()
self.matmul = ops.MatMul() self.weight_1 = Parameter(Tensor(np.random.randn(input_dims, 128), dtype=mstype.float32), name='weight_1')
self.bias_1 = Parameter(Tensor(np.zeros(128), dtype=mstype.float32), name='bias_1')
self.weight_2 = Parameter(Tensor(np.random.randn(128, 64), dtype=mstype.float32), name='weight_2')
self.bias_2 = Parameter(Tensor(np.zeros(64), dtype=mstype.float32), name='bias_2')
self.weight_3 = Parameter(Tensor(np.random.randn(64, output_dims), dtype=mstype.float32), name='weight_3')
self.bias_3 = Parameter(Tensor(np.zeros(output_dims), dtype=mstype.float32), name='bias_3') def construct(self, x):
x1 = self.matmul(x, self.weight_1) + self.bias_1
x2 = self.matmul(x1, self.weight_2) + self.bias_2
x3 = self.matmul(x2, self.weight_3) + self.bias_3
return x3 def main():
net = Net(1, 1)
# loss function
loss = nn.MSELoss()
# optimizer
optim = nn.SGD(params=net.trainable_params(), learning_rate=0.000001)
# make net model
# model = Model(net, loss, optim, metrics={'loss': nn.Loss()})
net_with_criterion = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_criterion, optim) # 数据集
x, y = np.array([[0.1]], dtype=np.float32), np.array([[0.1]], dtype=np.float32)
x = Tensor(x)
y = Tensor(y) for i in range(20000*100):
#print(i, '\t', '*' * 100)
train_network.set_train()
res = train_network(x, y) # right
# False, False
# False, True
# True, True xxx # not right
# True, False if __name__ == '__main__':
""" 设置运行的背景context """
from mindspore import context # 为mindspore设置运行背景context
#context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') import time a = time.time()
main()
b = time.time()
print(b-a)

运行时间:

1158.24s

1154.29s

1152.69s

=====================================================

前文我们给出的单步计算 model.train  的代码修改如下:

import mindspore
import numpy as np # 引入numpy科学计算库
import matplotlib.pyplot as plt # 引入绘图库 np.random.seed(123) # 随机数生成种子 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor class Net(nn.Cell):
def __init__(self, input_dims, output_dims):
super(Net, self).__init__()
self.matmul = ops.MatMul() self.weight_1 = Parameter(Tensor(np.random.randn(input_dims, 128), dtype=mstype.float32), name='weight_1')
self.bias_1 = Parameter(Tensor(np.zeros(128), dtype=mstype.float32), name='bias_1')
self.weight_2 = Parameter(Tensor(np.random.randn(128, 64), dtype=mstype.float32), name='weight_2')
self.bias_2 = Parameter(Tensor(np.zeros(64), dtype=mstype.float32), name='bias_2')
self.weight_3 = Parameter(Tensor(np.random.randn(64, output_dims), dtype=mstype.float32), name='weight_3')
self.bias_3 = Parameter(Tensor(np.zeros(output_dims), dtype=mstype.float32), name='bias_3') def construct(self, x):
x1 = self.matmul(x, self.weight_1) + self.bias_1
x2 = self.matmul(x1, self.weight_2) + self.bias_2
x3 = self.matmul(x2, self.weight_3) + self.bias_3
return x3 def main():
net = Net(1, 1)
# loss function
loss = nn.MSELoss()
# optimizer
optim = nn.SGD(params=net.trainable_params(), learning_rate=0.000001)
# make net model
model = Model(net, loss, optim, metrics={'loss': nn.Loss()}) # 数据集
x, y = np.array([[0.1]], dtype=np.float32), np.array([[0.1]], dtype=np.float32) def generator_multidimensional():
for i in range(1):
a = x*i
b = y*i
#print(a, b)
yield (a, b) dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["input", "output"]) for i in range(20000*100):
#print(i, '\t', '*' * 100)
model.train(1, dataset, dataset_sink_mode=False) # right
# False, False
# False, True
# True, True xxx # not right
# True, False if __name__ == '__main__':
""" 设置运行的背景context """
from mindspore import context # 为mindspore设置运行背景context
#context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') import time a = time.time()
main()
b = time.time()
print(b-a)

运行时间:

2173.19s

2181.61s

==================================================================

可以看到,在单步计算时,如果使用框架提供的单步训练函数会更好的提升算法运算效率,运算效率提升的幅度也很大,所有在进行单步训练或者非持续数据量训练时使用框架提供的单步训练函数是首选。

单步训练函数:

from mindspore.nn import TrainOneStepCell, WithLossCell

=====================================================================

本文实验环境为  MindSpore1.1  docker版本

宿主机:Ubuntu18.04系统

CPU:I7-8700

GPU:1060ti NVIDIA显卡

(续 2 )在深度计算框架MindSpore中如何对不持续的计算进行处理——对数据集进行一定epoch数量的训练后,进行其他工作处理,再返回来接着进行一定epoch数量的训练——单步计算的更多相关文章

  1. 带你学习MindSpore中算子使用方法

    摘要:本文分享下MindSpore中算子的使用和遇到问题时的解决方法. 本文分享自华为云社区<[MindSpore易点通]算子使用问题与解决方法>,作者:chengxiaoli. 简介 算 ...

  2. TensorFlow - 框架实现中的三种 Graph

    文章目录 TensorFlow - 框架实现中的三种 Graph 1. Graph 2. GraphDef 3. MetaGraph 4. Checkpoint 5. 总结 TensorFlow - ...

  3. SSH框架应用中常用Jar包用途介绍

    struts2需要的几个jar包:1)xwork-core-2.1.62)struts2-core-2.1.83)ognl-2.7.34)freemarker-2.3.155)commons-io-1 ...

  4. 如何在Crystal框架项目中内置启动MetaQ服务?

    当Crystal框架项目中需要使用消息机制,而项目规模不大.性能要求不高时,可内置启动MetaQ服务器. 分步指南 项目引入crystal-extend-metaq模块,如下: <depende ...

  5. 如何在Crystal框架项目中内置启动Zookeeper服务?

    当Crystal框架项目需要使用到Zookeeper服务时(如使用Dubbo RPC时,需要注册服务到Zookeeper),而独立部署和启动Zookeeper服务不仅繁琐,也容易出现错误. 在小型项目 ...

  6. 浅入深出之Java集合框架(中)

    Java中的集合框架(中) 由于Java中的集合框架的内容比较多,在这里分为三个部分介绍Java的集合框架,内容是从浅到深,如果已经有java基础的小伙伴可以直接跳到<浅入深出之Java集合框架 ...

  7. Javscript调用iframe框架页面中函数的方法

    Javscript调用iframe框架页面中函数的方法,可以实现iframe之间传值或修改值了, 访问iframe里面的函数: window.frames['CallCenter_iframe'].h ...

  8. 游戏框架设计中的。绑定binding。。。命令 command 和消息message 以及MVVM

    游戏框架设计中的.绑定binding...命令 command 和消息message

  9. 关于MFC框架程序中CWinApp::OnIdle

    很早之前就发现,我写的图形引擎在MFC框架程序中的刷帧率始终在60FPS左右.好在自己的程序对刷帧率的要求不是很高,所以一直没有太过纠结此事.直到今天看了别人的程序才发现应该在函数CWinApp::O ...

  10. TP框架模板中IF Else 如何使用?

    TP框架模板中IF Else 如何使用? 截个图吧 如果效果出不来,一般就是条件写错了!!!

随机推荐

  1. DoNet Core的启动过程-WebApplicationBuilder

    1.前言 在NET6开始做ASP.NETCore的开发,我们首先要看的是启动过程,而WebApplication和WebApplicationBuilder 类是启动过程好不开的类,WebApplic ...

  2. 透过 node-exporter 彻底弄懂机器监控:01. node-exporter 框架讲解

    前言 Prometheus 生态里有很多采集器负责各类监控数据的采集,其中使用最广泛的,显然是 node-exporter,负责 Linux.BSD 等系统的常规监控指标的采集,比如 CPU.内存.硬 ...

  3. nginx虚拟主机实战

    基于nginx部署网站 虚拟主机指的就是一个独立的站点,具有独立的域名,有完整的www服务,例如网站.FTP.邮件等. Nginx支持多虚拟主机,在一台机器上可以运行完全独立的多个站点. 一.为什么配 ...

  4. 《Android开发卷——ListView嵌套GridView(基础)》

      listview嵌套gridview,最主要应该解决的问题是listview跟GridView的滑动问题.这个利用GridView是自定义的,就是让GridView内容有多大就显示多大,然后禁用他 ...

  5. java并发编程——CompletableFuture

    简介 Java的java.util.concurrent包中提供了并发相关的接口和类,本文将重点介绍CompletableFuture并发操作类 JDK1.8新增CompletableFuture该类 ...

  6. java线程的park unpark方法

    标签(空格分隔): 多线程 park 和 unpark的使用 park和unpark并不是线程的方法,而是LockSupport的静态方法 暂停当前线程 LockSupport.park();//所在 ...

  7. SOP页面跳转设计 RAS AES加密算法应用跨服务免登陆接口设计

    SOP页面跳转设计 RAS AES加密算法应用跨服务免登陆接口设计 SOP,是 Standard Operating Procedure三个单词中首字母的大写 ,即标准作业程序,指将某一事件的标准操作 ...

  8. sqlUtil

    package com.cmbchina.monitor.utils;import com.alibaba.druid.sql.ast.SQLStatement;import com.alibaba. ...

  9. VUE CLI中使用Jquery无法获取到dom节点

    mounted 类型:Function 详细: 实例被挂载后调用,这时 el 被新创建的 vm.$el 替换了.如果根实例挂载到了一个文档内的元素上,当 mounted 被调用时 vm.$el 也在文 ...

  10. 虚拟 DOM 的优缺点?

    什么是虚拟dom用js模拟一颗dom树,放在浏览器内存中.当你要变更时,虚拟dom使用diff算法进行新旧虚拟dom的比较,将变更放到变更队列中, 反应到实际的dom树,减少了dom操作. 虚拟DOM ...