前文:

https://www.cnblogs.com/devilmaycry812839668/p/14975860.html

前文中我们使用自己编写的损失函数和单步梯度求导来实现算法,这里是作为扩展,我们这里使用系统提供的损失函数和优化器,代码如下:

import mindspore
import numpy as np #引入numpy科学计算库
import matplotlib.pyplot as plt #引入绘图库
np.random.seed(123) #随机数生成种子 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor # 数据集
class DatasetGenerator:
def __init__(self):
self.input_data = 2*np.random.rand(500, 1).astype(np.float32)
self.output_data = 5+3*self.input_data+np.random.randn(500, 1).astype(np.float32) def __getitem__(self, index):
return self.input_data[index], self.output_data[index] def __len__(self):
return len(self.input_data) def create_dataset(batch_size=500):
dataset_generator = DatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["input", "output"], shuffle=False) #buffer_size = 10000
#dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=True)
#dataset = dataset.repeat(1)
return dataset class Net(nn.Cell):
def __init__(self, input_dims, output_dims):
super(Net, self).__init__()
self.matmul = ops.MatMul() self.weight_1 = Parameter(Tensor(np.random.randn(input_dims, 128), dtype=mstype.float32), name='weight_1')
self.bias_1 = Parameter(Tensor(np.zeros(128), dtype=mstype.float32), name='bias_1')
self.weight_2 = Parameter(Tensor(np.random.randn(128, 64), dtype=mstype.float32), name='weight_2')
self.bias_2 = Parameter(Tensor(np.zeros(64), dtype=mstype.float32), name='bias_2')
self.weight_3 = Parameter(Tensor(np.random.randn(64, output_dims), dtype=mstype.float32), name='weight_3')
self.bias_3 = Parameter(Tensor(np.zeros(output_dims), dtype=mstype.float32), name='bias_3') def construct(self, x):
x1 = self.matmul(x, self.weight_1)+self.bias_1
x2 = self.matmul(x1, self.weight_2)+self.bias_2
x3 = self.matmul(x2, self.weight_3)+self.bias_3
return x3 def main():
epochs = 10000
dataset = create_dataset()
net = Net(1, 1)
# loss function
loss = nn.MSELoss()
# optimizer
optim = nn.SGD(params=net.trainable_params(), learning_rate=0.000001)
model = Model(net, loss, optim, metrics={'loss': nn.Loss()}) config_ck = CheckpointConfig(save_checkpoint_steps=10000, keep_checkpoint_max=10)
ck_point = ModelCheckpoint(prefix="checkpoint_mlp", config=config_ck)
model.train(epochs, dataset, callbacks=[ck_point, LossMonitor(1000)], dataset_sink_mode=False) np.random.seed(123) # 随机数生成种子
dataset = create_dataset()
data = next(dataset.create_dict_iterator())
x = data['input']
y = data['output']
y_hat = model.predict(x) eval_loss = model.eval(dataset, dataset_sink_mode=False)
print("{}".format(eval_loss)) fig=plt.figure(figsize=(8,6))#确定画布大小
plt.title("Dataset")#标题名
plt.xlabel("First feature")#x轴的标题
plt.ylabel("Second feature")#y轴的标题
plt.scatter(x.asnumpy(), y.asnumpy())#设置为散点图
plt.scatter(x.asnumpy(), y_hat.asnumpy())#设置为散点图
plt.show()#绘制出来 if __name__ == '__main__':
""" 设置运行的背景context """
from mindspore import context
# 为mindspore设置运行背景context
# context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') import time
a = time.time()
main()
b = time.time()
print(b-a)

输入结果:

epoch: 1000 step: 1, loss is 1.0764071
epoch: 2000 step: 1, loss is 1.0253074
epoch: 3000 step: 1, loss is 1.0251387
epoch: 4000 step: 1, loss is 1.0251383
epoch: 5000 step: 1, loss is 1.0251379
epoch: 6000 step: 1, loss is 1.0251367
epoch: 7000 step: 1, loss is 1.0251377
epoch: 8000 step: 1, loss is 1.0251374
epoch: 9000 step: 1, loss is 1.0251373
epoch: 10000 step: 1, loss is 1.025136
{'loss': 1.0251359939575195}
62.462594985961914

可以看到使用系统提供的损失函数和优化器同样可以实现算法,但是很神奇的地方是算法的整体运算时间没有降低反而增加了,由前文中的41秒左右提高到了62秒左右,具体原因这里就不知道了,毕竟mindspore框架内部有很多其他的实现,如自动并行之类的,因此它的表现我们这里难以分析,这里只是作为功能尝试而已。

(续) MindSpore 如何实现一个线性回归 —— Demo示例的更多相关文章

  1. LeadTools Android 入门教学——运行第一个Android Demo

    LeadTools 有很多Windows平台下的Demo,非常全面,但是目前开发手机应用的趋势也越来越明显,LeadTools也给大家提供了10个Android的Demo,这篇文章将会教你如何运行第一 ...

  2. 一个数据源demo

    前言 我们重复造轮子,不是为了证明我们比那些造轮子的人牛逼,而是明白那些造轮子的人有多牛逼. JDBC介绍 在JDBC中,我们可以通过DriverManager.getConnection()创建(而 ...

  3. 快速搭建一个直播Demo

    缘由 最近帮朋友看一个直播网站的源码,发现这份直播源码借助 阿里云 .腾讯云这些大公司提供的SDK 可以非常方便的搭建一个直播网站.下面我们来给大家讲解下如何借助 腾讯云 我们搭建一个简易的 直播示例 ...

  4. Visual Studio 2017 - Windows应用程序打包成exe文件(2)- Advanced Installer 关于Newtonsoft.Json,LINQ to JSON的一个小demo mysql循环插入数据、生成随机数及CONCAT函数 .NET记录-获取外网IP以及判断该IP是属于网通还是电信 Guid的生成和数据修整(去除空格和小写字符)

    Visual Studio 2017 - Windows应用程序打包成exe文件(2)- Advanced Installer   Advanced Installer :Free for 30 da ...

  5. 【分享】Vue 资源典藏(UI组件、开发框架、服务端、辅助工具、应用实例、Demo示例)

    Vue 资源典藏,包括:UI组件 开发框架 服务端 辅助工具 应用实例 Demo示例 element ★11612 - 饿了么出品的Vue2的web UI工具套件 Vux ★7503 - 基于Vue和 ...

  6. kafka_2.11-0.8.2.1+java 生产消费程序demo示例

      Kafka学习8_kafka java 生产消费程序demo示例 kafka是吞吐量巨大的一个消息系统,它是用scala写的,和普通的消息的生产消费还有所不同,写了个demo程序供大家参考.kaf ...

  7. SpringBoot整合Swagger2(Demo示例)

    写在前面 由于公司项目采用前后端分离,维护接口文档基本上是必不可少的工作.一个理想的状态是设计好后,接口文档发给前端和后端,大伙按照既定的规则各自开发,开发好了对接上了就可以上线了.当然这是一种非常理 ...

  8. Vue UI组件 开发框架 服务端 辅助工具 应用实例 Demo示例

    Vue UI组件 开发框架 服务端 辅助工具 应用实例 Demo示例 element ★11612 - 饿了么出品的Vue2的web UI工具套件 Vux ★7503 - 基于Vue和WeUI的组件库 ...

  9. Go学习【02】:理解Gin,搭一个web demo

    Go Gin 框架 说Gin是一个框架,不如说Gin是一个类库或者工具库,其包含了可以组成框架的组件.这样会更好理解一点. 举个 下面的示例代码在这:github 利用Gin组成最基本的框架.说到框架 ...

  10. ArcGIS API for JavaScript开发环境搭建及第一个实例demo

    原文:ArcGIS API for JavaScript开发环境搭建及第一个实例demo ESRI公司截止到目前已经发布了最新的ArcGIS Server for JavaScript API v3. ...

随机推荐

  1. SQL SERVER 同一台服务器,A库正常连接,B库提示“等待的操作过时”

    SQL SERVER 同一台服务器,A库正常连接,B库提示"等待的操作过时" 解决方法: 在客户端(非SQL SERVER 服务器)用管理员身份运行CMD,输入netsh wins ...

  2. 短链接口设计&禁用Springboot执行器端点/env的安全性

    短链接口设计 //短链接服务 跳转方式,实现短链接转长链接的请求. @GetMapping("/{code}") public String redirectUrl(@PathVa ...

  3. elasticsearch6.8 ik分词器需安装

    elasticsearch6.8  ik分词器需安装order_info_es/_analyze POST{ "analyzer": "ik_max_word" ...

  4. xshell和xftp下载免费版本方法

    下载地址 https://www.xshell.com/zh/free-for-home-school/ 填写邮箱和用户名,会发送下载邮件到邮箱,然后根据邮箱中的下载地址来下载安装.

  5. C#.Net筑基-集合知识全解

    01.集合基础知识 .Net 中提供了一系列的管理对象集合的类型,数组.可变列表.字典等.从类型安全上集合分为两类,泛型集合 和 非泛型集合,传统的非泛型集合存储为Object,需要类型转.而泛型集合 ...

  6. MySQL bit类型增加索引后查询结果不正确案例浅析

    昨天同事遇到的一个案例,这里简单描述一下:一个表里面有一个bit类型的字段,同事在优化相关SQL的过程中,给这个表的bit类型的字段新增了一个索引,然后测试验证 时,居然发现SQL语句执行结果跟不加索 ...

  7. 保护您的Web应用:使用雷池(SafeLine)WAF的入门指南

    雷池(SafeLine)是长亭科技耗时近 10 年倾情打造的 WAF,核心检测能力由智能语义分析算法驱动.旨在提供卓越的安全保护.本文将带您一步步了解如何安装.配置和测试SafeLine,以及如何利用 ...

  8. .NET6 个人博客-推荐文章加载优化

    个人博客-推荐文章加载优化 前言 随着博客文章越来越多,那么推荐的文章也是越来越多,之前推荐文章是只推荐8篇,但是我感觉有点少,然后也是决定加一个加载按钮,也是类似与分页的效果,点击按钮可以继续加载8 ...

  9. 七牛云 + PicGo

    下载PicGo https://github.com/Molunerfinn/PicGo/releases/tag/v2.3.1 七牛云配置 1.AccessKey和SecretKey:可以在七牛云控 ...

  10. Vue2 整理(一):基础篇

    前言 首先说明:要直接上手简单得很,看官网熟悉大概有哪些东西.怎么用的,然后简单练一下就可以做出程序来了,最多两天,无论Vue2还是Vue3,就都完全可以了,Vue3就是比Vue2多了一些东西而已,所 ...