前文:

https://www.cnblogs.com/devilmaycry812839668/p/14975860.html

前文中我们使用自己编写的损失函数和单步梯度求导来实现算法,这里是作为扩展,我们这里使用系统提供的损失函数和优化器,代码如下:

import mindspore
import numpy as np #引入numpy科学计算库
import matplotlib.pyplot as plt #引入绘图库
np.random.seed(123) #随机数生成种子 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor # 数据集
class DatasetGenerator:
def __init__(self):
self.input_data = 2*np.random.rand(500, 1).astype(np.float32)
self.output_data = 5+3*self.input_data+np.random.randn(500, 1).astype(np.float32) def __getitem__(self, index):
return self.input_data[index], self.output_data[index] def __len__(self):
return len(self.input_data) def create_dataset(batch_size=500):
dataset_generator = DatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["input", "output"], shuffle=False) #buffer_size = 10000
#dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=True)
#dataset = dataset.repeat(1)
return dataset class Net(nn.Cell):
def __init__(self, input_dims, output_dims):
super(Net, self).__init__()
self.matmul = ops.MatMul() self.weight_1 = Parameter(Tensor(np.random.randn(input_dims, 128), dtype=mstype.float32), name='weight_1')
self.bias_1 = Parameter(Tensor(np.zeros(128), dtype=mstype.float32), name='bias_1')
self.weight_2 = Parameter(Tensor(np.random.randn(128, 64), dtype=mstype.float32), name='weight_2')
self.bias_2 = Parameter(Tensor(np.zeros(64), dtype=mstype.float32), name='bias_2')
self.weight_3 = Parameter(Tensor(np.random.randn(64, output_dims), dtype=mstype.float32), name='weight_3')
self.bias_3 = Parameter(Tensor(np.zeros(output_dims), dtype=mstype.float32), name='bias_3') def construct(self, x):
x1 = self.matmul(x, self.weight_1)+self.bias_1
x2 = self.matmul(x1, self.weight_2)+self.bias_2
x3 = self.matmul(x2, self.weight_3)+self.bias_3
return x3 def main():
epochs = 10000
dataset = create_dataset()
net = Net(1, 1)
# loss function
loss = nn.MSELoss()
# optimizer
optim = nn.SGD(params=net.trainable_params(), learning_rate=0.000001)
model = Model(net, loss, optim, metrics={'loss': nn.Loss()}) config_ck = CheckpointConfig(save_checkpoint_steps=10000, keep_checkpoint_max=10)
ck_point = ModelCheckpoint(prefix="checkpoint_mlp", config=config_ck)
model.train(epochs, dataset, callbacks=[ck_point, LossMonitor(1000)], dataset_sink_mode=False) np.random.seed(123) # 随机数生成种子
dataset = create_dataset()
data = next(dataset.create_dict_iterator())
x = data['input']
y = data['output']
y_hat = model.predict(x) eval_loss = model.eval(dataset, dataset_sink_mode=False)
print("{}".format(eval_loss)) fig=plt.figure(figsize=(8,6))#确定画布大小
plt.title("Dataset")#标题名
plt.xlabel("First feature")#x轴的标题
plt.ylabel("Second feature")#y轴的标题
plt.scatter(x.asnumpy(), y.asnumpy())#设置为散点图
plt.scatter(x.asnumpy(), y_hat.asnumpy())#设置为散点图
plt.show()#绘制出来 if __name__ == '__main__':
""" 设置运行的背景context """
from mindspore import context
# 为mindspore设置运行背景context
# context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') import time
a = time.time()
main()
b = time.time()
print(b-a)

输入结果:

epoch: 1000 step: 1, loss is 1.0764071
epoch: 2000 step: 1, loss is 1.0253074
epoch: 3000 step: 1, loss is 1.0251387
epoch: 4000 step: 1, loss is 1.0251383
epoch: 5000 step: 1, loss is 1.0251379
epoch: 6000 step: 1, loss is 1.0251367
epoch: 7000 step: 1, loss is 1.0251377
epoch: 8000 step: 1, loss is 1.0251374
epoch: 9000 step: 1, loss is 1.0251373
epoch: 10000 step: 1, loss is 1.025136
{'loss': 1.0251359939575195}
62.462594985961914

可以看到使用系统提供的损失函数和优化器同样可以实现算法,但是很神奇的地方是算法的整体运算时间没有降低反而增加了,由前文中的41秒左右提高到了62秒左右,具体原因这里就不知道了,毕竟mindspore框架内部有很多其他的实现,如自动并行之类的,因此它的表现我们这里难以分析,这里只是作为功能尝试而已。

(续) MindSpore 如何实现一个线性回归 —— Demo示例的更多相关文章

  1. LeadTools Android 入门教学——运行第一个Android Demo

    LeadTools 有很多Windows平台下的Demo,非常全面,但是目前开发手机应用的趋势也越来越明显,LeadTools也给大家提供了10个Android的Demo,这篇文章将会教你如何运行第一 ...

  2. 一个数据源demo

    前言 我们重复造轮子,不是为了证明我们比那些造轮子的人牛逼,而是明白那些造轮子的人有多牛逼. JDBC介绍 在JDBC中,我们可以通过DriverManager.getConnection()创建(而 ...

  3. 快速搭建一个直播Demo

    缘由 最近帮朋友看一个直播网站的源码,发现这份直播源码借助 阿里云 .腾讯云这些大公司提供的SDK 可以非常方便的搭建一个直播网站.下面我们来给大家讲解下如何借助 腾讯云 我们搭建一个简易的 直播示例 ...

  4. Visual Studio 2017 - Windows应用程序打包成exe文件(2)- Advanced Installer 关于Newtonsoft.Json,LINQ to JSON的一个小demo mysql循环插入数据、生成随机数及CONCAT函数 .NET记录-获取外网IP以及判断该IP是属于网通还是电信 Guid的生成和数据修整(去除空格和小写字符)

    Visual Studio 2017 - Windows应用程序打包成exe文件(2)- Advanced Installer   Advanced Installer :Free for 30 da ...

  5. 【分享】Vue 资源典藏(UI组件、开发框架、服务端、辅助工具、应用实例、Demo示例)

    Vue 资源典藏,包括:UI组件 开发框架 服务端 辅助工具 应用实例 Demo示例 element ★11612 - 饿了么出品的Vue2的web UI工具套件 Vux ★7503 - 基于Vue和 ...

  6. kafka_2.11-0.8.2.1+java 生产消费程序demo示例

      Kafka学习8_kafka java 生产消费程序demo示例 kafka是吞吐量巨大的一个消息系统,它是用scala写的,和普通的消息的生产消费还有所不同,写了个demo程序供大家参考.kaf ...

  7. SpringBoot整合Swagger2(Demo示例)

    写在前面 由于公司项目采用前后端分离,维护接口文档基本上是必不可少的工作.一个理想的状态是设计好后,接口文档发给前端和后端,大伙按照既定的规则各自开发,开发好了对接上了就可以上线了.当然这是一种非常理 ...

  8. Vue UI组件 开发框架 服务端 辅助工具 应用实例 Demo示例

    Vue UI组件 开发框架 服务端 辅助工具 应用实例 Demo示例 element ★11612 - 饿了么出品的Vue2的web UI工具套件 Vux ★7503 - 基于Vue和WeUI的组件库 ...

  9. Go学习【02】:理解Gin,搭一个web demo

    Go Gin 框架 说Gin是一个框架,不如说Gin是一个类库或者工具库,其包含了可以组成框架的组件.这样会更好理解一点. 举个 下面的示例代码在这:github 利用Gin组成最基本的框架.说到框架 ...

  10. ArcGIS API for JavaScript开发环境搭建及第一个实例demo

    原文:ArcGIS API for JavaScript开发环境搭建及第一个实例demo ESRI公司截止到目前已经发布了最新的ArcGIS Server for JavaScript API v3. ...

随机推荐

  1. MYSQL 连接数据库过程中发生错误,检查服务器是否正常连接字符串是否正确,错误信息:未将对象引用设置到对象的实例。

    一: 中文提示 : 连接数据库过程中发生错误,检查服务器是否正常连接字符串是否正确,错误信息:未将对象引用设置到对象的实例.DbType="MySql";ConfigId=&quo ...

  2. 使用 Promise.withResolvers() 来简化你将函数 Promise 化的实现~~

    引言 在JavaScript编程中,Promise 是一种处理异步操作的常用机制.Promise 对象代表了一个尚未完成但预期将来会完成的操作的结果.在本文中,我们将探讨如何通过使用 ES2024 的 ...

  3. Mysql中innodb的B+tree能存储多少数据?

    引言 InnoDB一棵3层B+树可以存放多少行数据?这个问题的简单回答是:约2千万.为什么是这么多呢?因为这是可以算出来的,要搞清楚这个问题,我们先从InnoDB索引数据结构.数据组织方式说起. 在计 ...

  4. 盘点 Spring Boot 解决跨域请求的几种办法

    熟悉 web 系统开发的同学,对下面这样的错误应该不会太陌生. 之所以会出现这个错误,是因为浏览器出于安全的考虑,采用同源策略的控制,防止当前站点恶意攻击 web 服务器盗取数据. 01.什么是跨域请 ...

  5. RHEL8.1---离线升级gcc

    升级gcc到gcc9.1.0 下载离线包.放到/opt下 [root@172-18-251-35 opt]# wget http://ftp.gnu.org/gnu/gcc/gcc-9.1.0/gcc ...

  6. 09-CentOS软件包管理

    简介 CentOS7使用rpm和yum来管理软件包. CentOS 8附带YUM包管理器v4.0.4版本,该版本现在使用DNF (Dandified YUM)技术作为后端.DNF是新一代的YUM,新的 ...

  7. Android自动化无障碍服务开源库-Assists v3.0.0

    Assists v3.0.0 Android无障碍服务(AccessibilityService)开发框架,快速开发复杂自动化任务.远程协助.监听等 Android无障碍服务能做什么 利用Androi ...

  8. 2.SpringBoot快速上手

    2.SpringBoot快速上手 SpringBoot介绍 javaEE的开发经常会涉及到3个框架Spring ,SpringMVC,MyBatis.但是这三个框架配置极其繁琐,有大量的xml文件,s ...

  9. 静态 top tree 入门

    理论 我们需要一个数据结构维护树上的问题,仿照序列上的问题,我们需要一个方法快速的刻画出信息. 比如说线段树就通过分治的方式来通过将一个区间划分成 \(\log n\) 个区间并刻画出这 \(\log ...

  10. P9120 题解

    暴力容斥复活之路! \(k=1\) 这个你肯定会. \(k=2\) 大的放上去,小的放下来.简单贪心. \(k=3\) 考虑二分答案. 然后考虑判断是否合法. 令当前答案为 \(val\). 首先钦定 ...