mxnet 线性模型

li {list-style-type:decimal;}ol.wiz-list-level2 > li {list-style-type:lower-latin;}ol.wiz-list-level3 > li {list-style-type:lower-roman;}blockquote {padding:0 12px;padding:0 0.75rem;}blockquote > :first-child {margin-top:0;}blockquote > :last-child {margin-bottom:0;}img {border:0;max-width:100%;height:auto !important;margin:2px 0;}table {border-collapse:collapse;border:1px solid #bbbbbb;}td, th {padding:4px 8px;border-collapse:collapse;border:1px solid #bbbbbb;min-height:28px;word-break:break-all;box-sizing: border-box;}.wiz-hide {display:none !important;}
-->
span::selection, .CodeMirror-line > span > span::selection { background: #d7d4f0; }.CodeMirror-line::-moz-selection, .CodeMirror-line > span::-moz-selection, .CodeMirror-line > span > span::-moz-selection { background: #d7d4f0; }.cm-searching {background: #ffa; background: rgba(255, 255, 0, .4);}.cm-force-border { padding-right: .1px; }@media print { .CodeMirror div.CodeMirror-cursors {visibility: hidden;}}.cm-tab-wrap-hack:after { content: ""; }span.CodeMirror-selectedtext { background: none; }.CodeMirror-activeline-background, .CodeMirror-selected {transition: visibility 0ms 100ms;}.CodeMirror-blur .CodeMirror-activeline-background, .CodeMirror-blur .CodeMirror-selected {visibility:hidden;}.CodeMirror-blur .CodeMirror-matchingbracket {color:inherit !important;outline:none !important;text-decoration:none !important;}
-->
span::selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span > span::selection { background: rgba(45, 45, 45, 0.99); }.cm-s-tomorrow-night-eighties .CodeMirror-line::-moz-selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span::-moz-selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span > span::-moz-selection { background: rgba(45, 45, 45, 0.99); }.cm-s-tomorrow-night-eighties .CodeMirror-gutters { background: #000000; border-right: 0px; }.cm-s-tomorrow-night-eighties .CodeMirror-guttermarker { color: #f2777a; }.cm-s-tomorrow-night-eighties .CodeMirror-guttermarker-subtle { color: #777; }.cm-s-tomorrow-night-eighties .CodeMirror-linenumber { color: #515151; }.cm-s-tomorrow-night-eighties .CodeMirror-cursor { border-left: 1px solid #6A6A6A; }.cm-s-tomorrow-night-eighties span.cm-comment { color: #d27b53; }.cm-s-tomorrow-night-eighties span.cm-atom { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-number { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-property, .cm-s-tomorrow-night-eighties span.cm-attribute { color: #99cc99; }.cm-s-tomorrow-night-eighties span.cm-keyword { color: #f2777a; }.cm-s-tomorrow-night-eighties span.cm-string { color: #ffcc66; }.cm-s-tomorrow-night-eighties span.cm-variable { color: #99cc99; }.cm-s-tomorrow-night-eighties span.cm-variable-2 { color: #6699cc; }.cm-s-tomorrow-night-eighties span.cm-def { color: #f99157; }.cm-s-tomorrow-night-eighties span.cm-bracket { color: #CCCCCC; }.cm-s-tomorrow-night-eighties span.cm-tag { color: #f2777a; }.cm-s-tomorrow-night-eighties span.cm-link { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-error { background: #f2777a; color: #6A6A6A; }.cm-s-tomorrow-night-eighties .CodeMirror-activeline-background { background: #343600; }.cm-s-tomorrow-night-eighties .CodeMirror-matchingbracket { text-decoration: underline; color: white !important; }
-->

import mxnet
import mxnet.ndarray as nd
from mxnet import gluon
from mxnet import autograd

# create data

def set_data(true_w, true_b, num_examples, *args, **kwargs):
num_inputs = len(true_w)
X = nd.random_normal(shape=(num_examples, num_inputs))
y = 0
for num in range(num_inputs):
# print(num)
y += true_w[num] * X[:, num]
y += true_b
y += 0.1 * nd.random_normal(shape=y.shape)
return X, y

# create data loader
def data_loader(batch_size, X, y, shuffle=False):
data_set = gluon.data.ArrayDataset(X, y)
data_iter = gluon.data.DataLoader(dataset=data_set, batch_size=batch_size, shuffle=shuffle)
return data_iter

# create net
def set_net(node_num):
net = gluon.nn.Sequential()
net.add(gluon.nn.Dense(node_num))
net.initialize()
return net

# create trainer
def trainer(net, loss_method, learning_rate):
trainer = gluon.Trainer(
net.collect_params(), loss_method, {'learning_rate': learning_rate}
)
return trainer

square_loss = gluon.loss.L2Loss()

# start train
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
for e in range(epochs):
total_loss = 0
for data, label in data_iter:
with autograd.record():
output = net(data)
loss = loss_method(output, label)
loss.backward()
trainer.step(batch_size)
total_loss += nd.sum(loss).asscalar()
print("第 %d次训练, 平均损失: %f" % (e, total_loss / 1000))
dense = net[0]

print(dense.weight.data())
print(dense.bias.data())
return dense.weight.data(), dense.bias.data()

true_w = [5, 8, 6]
true_b = 6
X, y = set_data(true_w=true_w, true_b=true_b, num_examples=1000)
data_iter = data_loader(batch_size=10, X=X, y=y, shuffle=True)
net = set_net(1)
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
start_train(epochs=5, batch_size=10, data_iter=data_iter, net=net, loss_method=square_loss, tariner=trainer,
num_examples=1000)
<wiz_code_mirror>

 
 
 
74
def data_loader(batch_size, X, y, shuffle=False):
 
 
 
 
1
import mxnet
2
import mxnet.ndarray as nd
3
from mxnet import gluon
4
from mxnet import autograd
5

6

7
# create data
8

9
def set_data(true_w, true_b, num_examples, *args, **kwargs):
10
    num_inputs = len(true_w)
11
    X = nd.random_normal(shape=(num_examples, num_inputs))
12
    y = 0
13
    for num in range(num_inputs):
14
        # print(num)
15
        y += true_w[num] * X[:, num]
16
    y += true_b
17
    y += 0.1 * nd.random_normal(shape=y.shape)
18
    return X, y
19

20

21
# create data loader
22
def data_loader(batch_size, X, y, shuffle=False):
23
    data_set = gluon.data.ArrayDataset(X, y)
24
    data_iter = gluon.data.DataLoader(dataset=data_set, batch_size=batch_size, shuffle=shuffle)
25
    return data_iter
26

27

28
# create net
29
def set_net(node_num):
30
    net = gluon.nn.Sequential()
31
    net.add(gluon.nn.Dense(node_num))
32
    net.initialize()
33
    return net
34

35

36
# create trainer
37
def trainer(net, loss_method, learning_rate):
38
    trainer = gluon.Trainer(
39
        net.collect_params(), loss_method, {'learning_rate': learning_rate}
40
    )
41
    return trainer
42

43

44
square_loss = gluon.loss.L2Loss()
45

46

47
# start train
48
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
49
    for e in range(epochs):
50
        total_loss = 0
51
        for data, label in data_iter:
52
            with autograd.record():
53
                output = net(data)
54
                loss = loss_method(output, label)
55
            loss.backward()
56
            trainer.step(batch_size)
57
            total_loss += nd.sum(loss).asscalar()
58
        print("第 %d次训练, 平均损失: %f" % (e, total_loss / 1000))
59
    dense = net[0]
60

61
    print(dense.weight.data())
62
    print(dense.bias.data())
63
    return dense.weight.data(), dense.bias.data()
64

65

66
true_w = [5, 8, 6]
67
true_b = 6
68
X, y = set_data(true_w=true_w, true_b=true_b, num_examples=1000)
69
data_iter = data_loader(batch_size=10, X=X, y=y, shuffle=True)
70
net = set_net(1)
71
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
72
start_train(epochs=5, batch_size=10, data_iter=data_iter, net=net, loss_method=square_loss, tariner=trainer,
73
            num_examples=1000)
74

 
 

mxnet 线性模型的更多相关文章

  1. MXNET:监督学习

    线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...

  2. 分布式机器学习框架:MxNet 前言

           原文连接:MxNet和Caffe之间有什么优缺点一.前言: Minerva: 高效灵活的并行深度学习引擎 不同于cxxnet追求极致速度和易用性,Minerva则提供了一个高效灵活的平台 ...

  3. 广义线性模型(Generalized Linear Models)

    前面的文章已经介绍了一个回归和一个分类的例子.在逻辑回归模型中我们假设: 在分类问题中我们假设: 他们都是广义线性模型中的一个例子,在理解广义线性模型之前需要先理解指数分布族. 指数分布族(The E ...

  4. ubantu16.04+mxnet +opencv+cuda8.0 环境搭建

    ubantu16.04+mxnet +opencv+cuda8.0 环境搭建 建议:环境搭建完成之后,不要更新系统(内核) 转载请注明出处: 微微苏荷 一 我的安装环境 系统:ubuntu16.04 ...

  5. MXNet设计和实现简介

    原文:https://github.com/dmlc/mxnet/issues/797 神经网络本质上是一种语言,我们通过它来表达对应用问题的理解.例如我们用卷积层来表达空间相关性,RNN来表达时间连 ...

  6. MXNET手写体识别的例子

    安装完MXNet之后,运行了官网的手写体识别的例子,这个相当于深度学习的Hello world了吧.. http://mxnet.io/tutorials/python/mnist.html 运行的过 ...

  7. MXNET安装过程中遇到libinfo导入不了的问题解决

    今天尝试安装windows版本的MXNET,在按照官网的运行了python的setup之后,import mxnet时出现如下错误:cannot import name libinfo,在网上查找发现 ...

  8. MXNet学习~试用卷积~跑CIFAR-10

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  9. MXNet学习~第一个例子~跑MNIST

    反正基本上是给自己看的,直接贴写过注释后的代码,可能有的地方理解不对,你多担待,看到了也提出来(基本上对未来的自己说的),三层跑到了97%,毕竟是第一个例子,主要就是用来理解MXNet怎么使用. #导 ...

随机推荐

  1. 基于SQL调用Com组件来发送邮件

    这个需求是公司有个文控中心,如果有用增删改了文件信息希望可以发邮件通知到有权限的人.当然方式很多. 这里是用数据库作业来完成 JOB+Com,这里用的com组件是Jmail 当然你需要把com组件放到 ...

  2. 苹果手机 iTunes 资料备份到另一手机

    百度教程 https://jingyan.baidu.com/article/d621e8da332e602865913f8e.html 直接使用iTunes将老手机的资料备份, (可能需要关闭手机定 ...

  3. nginx upstream配置

    upstream *.com { server 127.0.0.1:5000 weight=10 max_fails=2 fail_timeout=30s;} server { listen 80; ...

  4. ffmpeg超详细综合教程——摄像头直播

    本文的示例将实现:读取PC摄像头视频数据并以RTMP协议发送为直播流.示例包含了1.ffmpeg的libavdevice的使用2.视频解码.编码.推流的基本流程具有较强的综合性.要使用libavdev ...

  5. Java-Maven-Runoob:Maven 引入外部依赖

    ylbtech-Java-Maven-Runoob:Maven 引入外部依赖 1.返回顶部 1. Maven 引入外部依赖 如果我们需要引入第三库文件到项目,该怎么操作呢? pom.xml 的 dep ...

  6. AngularJS:Http

    ylbtech-AngularJS:Http 1.返回顶部 1. AngularJS XMLHttpRequest $http 是 AngularJS 中的一个核心服务,用于读取远程服务器的数据. 使 ...

  7. python开发面向对象基础:封装

    一,封装 [封装] 隐藏对象的属性和实现细节,仅对外提供公共访问方式. [好处] 1. 将变化隔离: 2. 便于使用: 3. 提高复用性: 4. 提高安全性: [封装原则] 1. 将不需要对外提供的内 ...

  8. 编译openwrt失败 “Please install theopenssl library”

    make menuconfig出现了错误 Build dependency: Please install theopenssl library(with development headers) P ...

  9. MVC 公共类App_Code不识别

    .Net MVC需要写公共类的时候 右击添加 App_Code 文件夹,新建类—>右击类—>属性,生成操作 —>选择 —>编译 .net MVC项目本身是个应用程序,所以其实不 ...

  10. javascript好文分享

    JavaScript精华 http://www.cnblogs.com/jesse2013/p/the-part-of-javascript-you-must-know.html JavaScript ...