mxnet 线性模型

li {list-style-type:decimal;}ol.wiz-list-level2 > li {list-style-type:lower-latin;}ol.wiz-list-level3 > li {list-style-type:lower-roman;}blockquote {padding:0 12px;padding:0 0.75rem;}blockquote > :first-child {margin-top:0;}blockquote > :last-child {margin-bottom:0;}img {border:0;max-width:100%;height:auto !important;margin:2px 0;}table {border-collapse:collapse;border:1px solid #bbbbbb;}td, th {padding:4px 8px;border-collapse:collapse;border:1px solid #bbbbbb;min-height:28px;word-break:break-all;box-sizing: border-box;}.wiz-hide {display:none !important;}
-->
span::selection, .CodeMirror-line > span > span::selection { background: #d7d4f0; }.CodeMirror-line::-moz-selection, .CodeMirror-line > span::-moz-selection, .CodeMirror-line > span > span::-moz-selection { background: #d7d4f0; }.cm-searching {background: #ffa; background: rgba(255, 255, 0, .4);}.cm-force-border { padding-right: .1px; }@media print { .CodeMirror div.CodeMirror-cursors {visibility: hidden;}}.cm-tab-wrap-hack:after { content: ""; }span.CodeMirror-selectedtext { background: none; }.CodeMirror-activeline-background, .CodeMirror-selected {transition: visibility 0ms 100ms;}.CodeMirror-blur .CodeMirror-activeline-background, .CodeMirror-blur .CodeMirror-selected {visibility:hidden;}.CodeMirror-blur .CodeMirror-matchingbracket {color:inherit !important;outline:none !important;text-decoration:none !important;}
-->
span::selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span > span::selection { background: rgba(45, 45, 45, 0.99); }.cm-s-tomorrow-night-eighties .CodeMirror-line::-moz-selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span::-moz-selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span > span::-moz-selection { background: rgba(45, 45, 45, 0.99); }.cm-s-tomorrow-night-eighties .CodeMirror-gutters { background: #000000; border-right: 0px; }.cm-s-tomorrow-night-eighties .CodeMirror-guttermarker { color: #f2777a; }.cm-s-tomorrow-night-eighties .CodeMirror-guttermarker-subtle { color: #777; }.cm-s-tomorrow-night-eighties .CodeMirror-linenumber { color: #515151; }.cm-s-tomorrow-night-eighties .CodeMirror-cursor { border-left: 1px solid #6A6A6A; }.cm-s-tomorrow-night-eighties span.cm-comment { color: #d27b53; }.cm-s-tomorrow-night-eighties span.cm-atom { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-number { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-property, .cm-s-tomorrow-night-eighties span.cm-attribute { color: #99cc99; }.cm-s-tomorrow-night-eighties span.cm-keyword { color: #f2777a; }.cm-s-tomorrow-night-eighties span.cm-string { color: #ffcc66; }.cm-s-tomorrow-night-eighties span.cm-variable { color: #99cc99; }.cm-s-tomorrow-night-eighties span.cm-variable-2 { color: #6699cc; }.cm-s-tomorrow-night-eighties span.cm-def { color: #f99157; }.cm-s-tomorrow-night-eighties span.cm-bracket { color: #CCCCCC; }.cm-s-tomorrow-night-eighties span.cm-tag { color: #f2777a; }.cm-s-tomorrow-night-eighties span.cm-link { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-error { background: #f2777a; color: #6A6A6A; }.cm-s-tomorrow-night-eighties .CodeMirror-activeline-background { background: #343600; }.cm-s-tomorrow-night-eighties .CodeMirror-matchingbracket { text-decoration: underline; color: white !important; }
-->

import mxnet
import mxnet.ndarray as nd
from mxnet import gluon
from mxnet import autograd

# create data

def set_data(true_w, true_b, num_examples, *args, noise_std=0.1, **kwargs):
    """Generate a synthetic linear-regression dataset.

    Each label is ``X @ true_w + true_b`` plus Gaussian noise scaled by
    ``noise_std``.

    Args:
        true_w: Sequence of ground-truth weights, one per input feature.
        true_b: Ground-truth bias (scalar).
        num_examples: Number of rows to generate.
        *args: Accepted for backward compatibility; ignored.
        noise_std: Multiplier applied to standard-normal noise added to the
            labels (default 0.1, the value that used to be hard-coded).
        **kwargs: Accepted for backward compatibility; ignored.

    Returns:
        Tuple ``(X, y)``: ``X`` has shape ``(num_examples, len(true_w))`` and
        ``y`` has shape ``(num_examples,)``.

    Raises:
        ValueError: If ``true_w`` is empty.
    """
    num_inputs = len(true_w)
    if num_inputs == 0:
        # Previously ``y`` stayed the Python int 0 here and the function
        # crashed later on ``y.shape``; fail early with a clear message.
        raise ValueError("true_w must contain at least one weight")
    X = nd.random_normal(shape=(num_examples, num_inputs))
    # Vectorized X @ w replaces the per-column Python loop; the dtype is
    # matched to X so e.g. a float64 numpy ``true_w`` does not clash.
    y = nd.dot(X, nd.array(true_w, dtype=X.dtype)) + true_b
    y += noise_std * nd.random_normal(shape=y.shape)
    return X, y

# create data loader
def data_loader(batch_size, X, y, shuffle=False):
    """Build a Gluon ``DataLoader`` over the arrays *X* and *y*.

    Args:
        batch_size: Number of examples per batch.
        X: Feature array.
        y: Label array, combined with ``X`` via ``gluon.data.ArrayDataset``.
        shuffle: Forwarded to the ``DataLoader``'s ``shuffle`` option.

    Returns:
        A ``gluon.data.DataLoader`` over the combined dataset.
    """
    paired = gluon.data.ArrayDataset(X, y)
    return gluon.data.DataLoader(dataset=paired, batch_size=batch_size, shuffle=shuffle)

# create net
def set_net(node_num):
    """Create and initialize a ``Sequential`` net holding one ``Dense`` layer.

    Args:
        node_num: Number of units passed to ``gluon.nn.Dense``.

    Returns:
        The initialized ``gluon.nn.Sequential`` network.
    """
    model = gluon.nn.Sequential()
    output_layer = gluon.nn.Dense(node_num)
    model.add(output_layer)
    model.initialize()
    return model

# create trainer
def trainer(net, loss_method, learning_rate):
    """Build a ``gluon.Trainer`` over every parameter of *net*.

    Args:
        net: Gluon block whose ``collect_params()`` are optimized.
        loss_method: Despite its name, this is the optimizer handed to
            ``gluon.Trainer`` (the script passes ``'sgd'``), not a loss.
        learning_rate: Value stored under the ``'learning_rate'`` optimizer
            parameter.

    Returns:
        The configured ``gluon.Trainer``.
    """
    optimizer_params = {'learning_rate': learning_rate}
    return gluon.Trainer(net.collect_params(), loss_method, optimizer_params)

# Module-level loss object; the script below passes it to start_train as
# ``loss_method``.
square_loss = gluon.loss.L2Loss()

# start train
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
    """Train *net* with mini-batch gradient steps and report the learned layer.

    Args:
        epochs: Number of passes over ``data_iter``.
        batch_size: Nominal batch size of ``data_iter``. Kept for interface
            compatibility; each update is normalized by the actual size of
            the current batch, so a shorter final batch is scaled correctly.
        data_iter: Iterable yielding ``(data, label)`` batches.
        net: Gluon network; its first child ``net[0]`` must expose
            ``weight`` and ``bias`` parameters.
        loss_method: Callable ``loss_method(output, label)`` returning an
            array of losses that is summed per batch.
        tariner: The ``gluon.Trainer`` that updates ``net``. (Misspelled
            name kept so existing keyword callers keep working.)
        num_examples: Total number of examples in ``data_iter``; divides the
            summed epoch loss to print a per-example average.

    Returns:
        Tuple ``(weight, bias)`` holding the data of ``net[0]``'s parameters.
    """
    for e in range(epochs):
        total_loss = 0
        for data, label in data_iter:
            with autograd.record():
                output = net(data)
                loss = loss_method(output, label)
            loss.backward()
            # Use the trainer passed in (the original read the global
            # ``trainer``) and the real batch length for gradient scaling.
            tariner.step(data.shape[0])
            total_loss += nd.sum(loss).asscalar()
        # Average over the dataset size rather than a hard-coded 1000.
        print("第 %d次训练, 平均损失: %f" % (e, total_loss / num_examples))
    dense = net[0]
    weight, bias = dense.weight.data(), dense.bias.data()
    print(weight)
    print(bias)
    return weight, bias

# Ground-truth parameters used to synthesize the dataset.
true_w = [5, 8, 6]
true_b = 6
# Shared by data generation, the loader and training so the dataset size and
# batch size passed to each step cannot drift apart.
NUM_EXAMPLES = 1000
BATCH_SIZE = 10

X, y = set_data(true_w=true_w, true_b=true_b, num_examples=NUM_EXAMPLES)
data_iter = data_loader(batch_size=BATCH_SIZE, X=X, y=y, shuffle=True)
net = set_net(1)  # a single Dense layer with one output unit
# NOTE: this rebinds the module-level name ``trainer`` from the factory
# function to the Trainer it returns, so the factory cannot be called again
# afterwards. The name is kept because code in this file may read the global
# ``trainer`` directly.
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
start_train(epochs=5, batch_size=BATCH_SIZE, data_iter=data_iter, net=net,
            loss_method=square_loss, tariner=trainer, num_examples=NUM_EXAMPLES)
<wiz_code_mirror>

 
 
 
74
def data_loader(batch_size, X, y, shuffle=False):
 
 
 
 
1
import mxnet
2
import mxnet.ndarray as nd
3
from mxnet import gluon
4
from mxnet import autograd
5

6

7
# create data
8

9
def set_data(true_w, true_b, num_examples, *args, noise_std=0.1, **kwargs):
    """Generate a synthetic linear-regression dataset.

    Each label is ``X @ true_w + true_b`` plus Gaussian noise scaled by
    ``noise_std``.

    Args:
        true_w: Sequence of ground-truth weights, one per input feature.
        true_b: Ground-truth bias (scalar).
        num_examples: Number of rows to generate.
        *args: Accepted for backward compatibility; ignored.
        noise_std: Multiplier applied to standard-normal noise added to the
            labels (default 0.1, the value that used to be hard-coded).
        **kwargs: Accepted for backward compatibility; ignored.

    Returns:
        Tuple ``(X, y)``: ``X`` has shape ``(num_examples, len(true_w))`` and
        ``y`` has shape ``(num_examples,)``.

    Raises:
        ValueError: If ``true_w`` is empty.
    """
    num_inputs = len(true_w)
    if num_inputs == 0:
        # Previously ``y`` stayed the Python int 0 here and the function
        # crashed later on ``y.shape``; fail early with a clear message.
        raise ValueError("true_w must contain at least one weight")
    X = nd.random_normal(shape=(num_examples, num_inputs))
    # Vectorized X @ w replaces the per-column Python loop; the dtype is
    # matched to X so e.g. a float64 numpy ``true_w`` does not clash.
    y = nd.dot(X, nd.array(true_w, dtype=X.dtype)) + true_b
    y += noise_std * nd.random_normal(shape=y.shape)
    return X, y
19

20

21
# create data loader
22
def data_loader(batch_size, X, y, shuffle=False):
    """Build a Gluon ``DataLoader`` over the arrays *X* and *y*.

    Args:
        batch_size: Number of examples per batch.
        X: Feature array.
        y: Label array, combined with ``X`` via ``gluon.data.ArrayDataset``.
        shuffle: Forwarded to the ``DataLoader``'s ``shuffle`` option.

    Returns:
        A ``gluon.data.DataLoader`` over the combined dataset.
    """
    paired = gluon.data.ArrayDataset(X, y)
    return gluon.data.DataLoader(dataset=paired, batch_size=batch_size, shuffle=shuffle)
26

27

28
# create net
29
def set_net(node_num):
    """Create and initialize a ``Sequential`` net holding one ``Dense`` layer.

    Args:
        node_num: Number of units passed to ``gluon.nn.Dense``.

    Returns:
        The initialized ``gluon.nn.Sequential`` network.
    """
    model = gluon.nn.Sequential()
    output_layer = gluon.nn.Dense(node_num)
    model.add(output_layer)
    model.initialize()
    return model
34

35

36
# create trainer
37
def trainer(net, loss_method, learning_rate):
    """Build a ``gluon.Trainer`` over every parameter of *net*.

    Args:
        net: Gluon block whose ``collect_params()`` are optimized.
        loss_method: Despite its name, this is the optimizer handed to
            ``gluon.Trainer`` (the script passes ``'sgd'``), not a loss.
        learning_rate: Value stored under the ``'learning_rate'`` optimizer
            parameter.

    Returns:
        The configured ``gluon.Trainer``.
    """
    optimizer_params = {'learning_rate': learning_rate}
    return gluon.Trainer(net.collect_params(), loss_method, optimizer_params)
42

43

44
# Module-level loss object; the script below passes it to start_train as
# ``loss_method``.
square_loss = gluon.loss.L2Loss()
45

46

47
# start train
48
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
    """Train *net* with mini-batch gradient steps and report the learned layer.

    Args:
        epochs: Number of passes over ``data_iter``.
        batch_size: Nominal batch size of ``data_iter``. Kept for interface
            compatibility; each update is normalized by the actual size of
            the current batch, so a shorter final batch is scaled correctly.
        data_iter: Iterable yielding ``(data, label)`` batches.
        net: Gluon network; its first child ``net[0]`` must expose
            ``weight`` and ``bias`` parameters.
        loss_method: Callable ``loss_method(output, label)`` returning an
            array of losses that is summed per batch.
        tariner: The ``gluon.Trainer`` that updates ``net``. (Misspelled
            name kept so existing keyword callers keep working.)
        num_examples: Total number of examples in ``data_iter``; divides the
            summed epoch loss to print a per-example average.

    Returns:
        Tuple ``(weight, bias)`` holding the data of ``net[0]``'s parameters.
    """
    for e in range(epochs):
        total_loss = 0
        for data, label in data_iter:
            with autograd.record():
                output = net(data)
                loss = loss_method(output, label)
            loss.backward()
            # Use the trainer passed in (the original read the global
            # ``trainer``) and the real batch length for gradient scaling.
            tariner.step(data.shape[0])
            total_loss += nd.sum(loss).asscalar()
        # Average over the dataset size rather than a hard-coded 1000.
        print("第 %d次训练, 平均损失: %f" % (e, total_loss / num_examples))
    dense = net[0]
    weight, bias = dense.weight.data(), dense.bias.data()
    print(weight)
    print(bias)
    return weight, bias
64

65

66
# Ground-truth parameters used to synthesize the dataset.
true_w = [5, 8, 6]
true_b = 6
# Shared by data generation, the loader and training so the dataset size and
# batch size passed to each step cannot drift apart.
NUM_EXAMPLES = 1000
BATCH_SIZE = 10

X, y = set_data(true_w=true_w, true_b=true_b, num_examples=NUM_EXAMPLES)
data_iter = data_loader(batch_size=BATCH_SIZE, X=X, y=y, shuffle=True)
net = set_net(1)  # a single Dense layer with one output unit
# NOTE: this rebinds the module-level name ``trainer`` from the factory
# function to the Trainer it returns, so the factory cannot be called again
# afterwards. The name is kept because code in this file may read the global
# ``trainer`` directly.
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
start_train(epochs=5, batch_size=BATCH_SIZE, data_iter=data_iter, net=net,
            loss_method=square_loss, tariner=trainer, num_examples=NUM_EXAMPLES)
74

 
 

mxnet 线性模型的更多相关文章

  1. MXNET:监督学习

    线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...

  2. 分布式机器学习框架:MxNet 前言

           原文连接:MxNet和Caffe之间有什么优缺点一.前言: Minerva: 高效灵活的并行深度学习引擎 不同于cxxnet追求极致速度和易用性,Minerva则提供了一个高效灵活的平台 ...

  3. 广义线性模型(Generalized Linear Models)

    前面的文章已经介绍了一个回归和一个分类的例子.在逻辑回归模型中我们假设: 在分类问题中我们假设: 他们都是广义线性模型中的一个例子,在理解广义线性模型之前需要先理解指数分布族. 指数分布族(The E ...

  4. Ubuntu 16.04 + MXNet + OpenCV + CUDA 8.0 环境搭建

    Ubuntu 16.04 + MXNet + OpenCV + CUDA 8.0 环境搭建 建议:环境搭建完成之后,不要更新系统(内核) 转载请注明出处: 微微苏荷 一 我的安装环境 系统:ubuntu16.04 ...

  5. MXNet设计和实现简介

    原文:https://github.com/dmlc/mxnet/issues/797 神经网络本质上是一种语言,我们通过它来表达对应用问题的理解.例如我们用卷积层来表达空间相关性,RNN来表达时间连 ...

  6. MXNET手写体识别的例子

    安装完MXNet之后,运行了官网的手写体识别的例子,这个相当于深度学习的Hello world了吧.. http://mxnet.io/tutorials/python/mnist.html 运行的过 ...

  7. MXNET安装过程中遇到libinfo导入不了的问题解决

    今天尝试安装windows版本的MXNET,在按照官网的运行了python的setup之后,import mxnet时出现如下错误:cannot import name libinfo,在网上查找发现 ...

  8. MXNet学习~试用卷积~跑CIFAR-10

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  9. MXNet学习~第一个例子~跑MNIST

    反正基本上是给自己看的,直接贴写过注释后的代码,可能有的地方理解不对,你多担待,看到了也提出来(基本上对未来的自己说的),三层跑到了97%,毕竟是第一个例子,主要就是用来理解MXNet怎么使用. #导 ...

随机推荐

  1. Eclipse 中 No java virtual machine was found... 解决方法

    这个链接说的不错,http://www.mafutian.net/123.html,,但是还有一种可能是64位和32位的问题,也就是eclipse32位只能用32位的jdk,eclipse64位的只能 ...

  2. FPGA论剑(续)

    25年之后,第二次华山论剑之时,天下第一的王重阳已然仙逝,郭靖少年英杰刚过二十岁,接东邪黄药师.北丐洪七公300招不败,二人默认郭靖天下第一.南帝段智兴因为出家,法号“一灯”,早已看破名利,故没有参加 ...

  3. linux下软件的种类和对应的安装及卸载的方式

    转: 一个Linux应用程序的软件包中可以包含两种不同的内容: 1)一种就是可执行文件,也就是解开包后就可以直接运行的.在Windows中所 有的软件包都是这种类型.安装完这个程序后,你就可以使用,但 ...

  4. JCE无限制权限策略文件

    JCE无限制权限策略文件,里面是对应jdk6和jdk7的文件 官网下载地址是 JDK6:http://www.oracle.com/technetwork/java/javase/downloads/ ...

  5. WCF服务端返回:(413) Request Entity Too Large

    出现这个原因我们应该都能猜测到,文件传出过大,超出了WCF默认范围,那么我们需要进行修改. 服务端和客户端都需要修改. 第一.客户端: <system.serviceModel> < ...

  6. 免费SSL证书 - Let's Encrypt申请(WINDOWS + IIS版)

    Let’s Encrypt 项目是由互联网安全研究小组ISRG,Internet Security Research Group主导并开发的一个新型数字证书认证机构CA,Certificate Aut ...

  7. pcs与crmsh命令比较

    一.概念 1.crmsh This project is not part of the GNU Project. Pacemaker command line interface for manag ...

  8. krpano之字幕添加

    字幕是指介绍语音的字幕,字幕随着语音的播放而滚动,随语音暂停而暂停.字幕添加的前提是用之前的方法添加过介绍语音. 原理: 字幕层在溢出隐藏的父元素中向右滑动,当点击声音控制按钮时,字幕位置被固定,再次 ...

  9. html收藏

    全屏显示<input type="button" name="fullscreen" value="全屏显示" onclick=&qu ...

  10. node.js开发指南读书笔记(1)

    3.1 开始使用Node.js编程 3.1.1 Hello World 将以下源代码保存到helloworld.js文件中 console.log('Hello World!'); console.l ...