mxnet 线性模型

li {list-style-type:decimal;}ol.wiz-list-level2 > li {list-style-type:lower-latin;}ol.wiz-list-level3 > li {list-style-type:lower-roman;}blockquote {padding:0 12px;padding:0 0.75rem;}blockquote > :first-child {margin-top:0;}blockquote > :last-child {margin-bottom:0;}img {border:0;max-width:100%;height:auto !important;margin:2px 0;}table {border-collapse:collapse;border:1px solid #bbbbbb;}td, th {padding:4px 8px;border-collapse:collapse;border:1px solid #bbbbbb;min-height:28px;word-break:break-all;box-sizing: border-box;}.wiz-hide {display:none !important;}
-->
span::selection, .CodeMirror-line > span > span::selection { background: #d7d4f0; }.CodeMirror-line::-moz-selection, .CodeMirror-line > span::-moz-selection, .CodeMirror-line > span > span::-moz-selection { background: #d7d4f0; }.cm-searching {background: #ffa; background: rgba(255, 255, 0, .4);}.cm-force-border { padding-right: .1px; }@media print { .CodeMirror div.CodeMirror-cursors {visibility: hidden;}}.cm-tab-wrap-hack:after { content: ""; }span.CodeMirror-selectedtext { background: none; }.CodeMirror-activeline-background, .CodeMirror-selected {transition: visibility 0ms 100ms;}.CodeMirror-blur .CodeMirror-activeline-background, .CodeMirror-blur .CodeMirror-selected {visibility:hidden;}.CodeMirror-blur .CodeMirror-matchingbracket {color:inherit !important;outline:none !important;text-decoration:none !important;}
-->
span::selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span > span::selection { background: rgba(45, 45, 45, 0.99); }.cm-s-tomorrow-night-eighties .CodeMirror-line::-moz-selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span::-moz-selection, .cm-s-tomorrow-night-eighties .CodeMirror-line > span > span::-moz-selection { background: rgba(45, 45, 45, 0.99); }.cm-s-tomorrow-night-eighties .CodeMirror-gutters { background: #000000; border-right: 0px; }.cm-s-tomorrow-night-eighties .CodeMirror-guttermarker { color: #f2777a; }.cm-s-tomorrow-night-eighties .CodeMirror-guttermarker-subtle { color: #777; }.cm-s-tomorrow-night-eighties .CodeMirror-linenumber { color: #515151; }.cm-s-tomorrow-night-eighties .CodeMirror-cursor { border-left: 1px solid #6A6A6A; }.cm-s-tomorrow-night-eighties span.cm-comment { color: #d27b53; }.cm-s-tomorrow-night-eighties span.cm-atom { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-number { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-property, .cm-s-tomorrow-night-eighties span.cm-attribute { color: #99cc99; }.cm-s-tomorrow-night-eighties span.cm-keyword { color: #f2777a; }.cm-s-tomorrow-night-eighties span.cm-string { color: #ffcc66; }.cm-s-tomorrow-night-eighties span.cm-variable { color: #99cc99; }.cm-s-tomorrow-night-eighties span.cm-variable-2 { color: #6699cc; }.cm-s-tomorrow-night-eighties span.cm-def { color: #f99157; }.cm-s-tomorrow-night-eighties span.cm-bracket { color: #CCCCCC; }.cm-s-tomorrow-night-eighties span.cm-tag { color: #f2777a; }.cm-s-tomorrow-night-eighties span.cm-link { color: #a16a94; }.cm-s-tomorrow-night-eighties span.cm-error { background: #f2777a; color: #6A6A6A; }.cm-s-tomorrow-night-eighties .CodeMirror-activeline-background { background: #343600; }.cm-s-tomorrow-night-eighties .CodeMirror-matchingbracket { text-decoration: underline; color: white !important; }
-->

import mxnet
import mxnet.ndarray as nd
from mxnet import gluon
from mxnet import autograd

# create data

def set_data(true_w, true_b, num_examples, *args, noise=0.1, **kwargs):
    """Generate a synthetic linear-regression dataset.

    Produces ``y = X @ true_w + true_b + noise * N(0, 1)``.

    Parameters
    ----------
    true_w : sequence of float
        Ground-truth weights; its length fixes the number of input features.
    true_b : float
        Ground-truth bias.
    num_examples : int
        Number of samples (rows) to generate.
    noise : float, optional
        Scale of the additive gaussian noise (default 0.1, matching the
        original hard-coded value).
    *args, **kwargs
        Ignored; kept so existing call sites remain valid.

    Returns
    -------
    (X, y)
        ``X`` of shape ``(num_examples, len(true_w))`` and the noisy
        1-D target vector ``y``.
    """
    num_inputs = len(true_w)
    X = nd.random_normal(shape=(num_examples, num_inputs))
    # Vectorized X @ w replaces the original per-column Python loop;
    # numerically equivalent, one C-level call instead of num_inputs adds.
    y = nd.dot(X, nd.array(true_w, dtype=X.dtype))
    y = y + true_b
    # Additive noise so the regression problem is not exactly solvable.
    y += noise * nd.random_normal(shape=y.shape)
    return X, y

# create data loader
def data_loader(batch_size, X, y, shuffle=False):
    """Wrap (X, y) in a gluon DataLoader that yields mini-batches.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    X, y : NDArray
        Features and targets; paired index-wise by ArrayDataset.
    shuffle : bool, optional
        Whether the loader reshuffles the data each epoch.
    """
    return gluon.data.DataLoader(
        dataset=gluon.data.ArrayDataset(X, y),
        batch_size=batch_size,
        shuffle=shuffle,
    )

# create net
def set_net(node_num):
    """Build and initialize a one-layer network.

    Returns a Sequential container holding a single Dense layer with
    ``node_num`` output units, with its parameters already initialized.
    """
    model = gluon.nn.Sequential()
    model.add(gluon.nn.Dense(node_num))
    model.initialize()
    return model

# create trainer
def trainer(net, loss_method, learning_rate):
    """Create a gluon Trainer over the network's parameters.

    Parameters
    ----------
    net : gluon.nn.Block
        Network whose collected parameters will be optimized.
    loss_method : str
        NOTE(review): despite the name, this is the *optimizer* name passed
        to ``gluon.Trainer`` (the caller passes ``'sgd'``), not a loss
        function. The parameter name is kept unchanged because callers use
        it as a keyword argument.
    learning_rate : float
        Optimizer learning rate.

    Returns
    -------
    gluon.Trainer
    """
    # Local renamed from `trainer` so it no longer shadows this
    # function's own name.
    optimizer = gluon.Trainer(
        net.collect_params(), loss_method, {'learning_rate': learning_rate}
    )
    return optimizer

# Squared (L2) loss instance used as the training objective below.
square_loss = gluon.loss.L2Loss()

# start train
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
    """Run the training loop and return the fitted weights and bias.

    Parameters
    ----------
    epochs : int
        Number of passes over the dataset.
    batch_size : int
        Batch size used to rescale gradients in ``Trainer.step``.
    data_iter : iterable
        Yields ``(data, label)`` mini-batches.
    net : gluon.nn.Block
        Network to train; its first layer's weight/bias are returned.
    loss_method : callable
        Loss function, e.g. ``gluon.loss.L2Loss()``.
    tariner : gluon.Trainer
        The optimizer trainer. (Parameter name keeps the original typo
        because callers pass it by keyword as ``tariner=...``.)
    num_examples : int
        Dataset size, used to report the per-sample average loss.

    Returns
    -------
    (weight, bias) : (NDArray, NDArray)
        Learned parameters of the first (Dense) layer.
    """
    for e in range(epochs):
        total_loss = 0
        for data, label in data_iter:
            with autograd.record():
                output = net(data)
                loss = loss_method(output, label)
            loss.backward()
            # BUG FIX: use the trainer passed in via `tariner`; the original
            # ignored the argument and silently relied on a global `trainer`.
            tariner.step(batch_size)
            total_loss += nd.sum(loss).asscalar()
        # BUG FIX: average over num_examples (previously hard-coded 1000,
        # leaving the num_examples parameter unused).
        print("第 %d次训练, 平均损失: %f" % (e, total_loss / num_examples))
    dense = net[0]

    print(dense.weight.data())
    print(dense.bias.data())
    return dense.weight.data(), dense.bias.data()

# Ground-truth parameters for the synthetic dataset: y = 5*x0 + 8*x1 + 6*x2 + 6.
true_w = [5, 8, 6]
true_b = 6
X, y = set_data(true_w=true_w, true_b=true_b, num_examples=1000)
data_iter = data_loader(batch_size=10, X=X, y=y, shuffle=True)
# Single output node: scalar regression.
net = set_net(1)
# NOTE(review): this rebinds the name `trainer` from the factory function to
# the gluon.Trainer instance it returns. The body of start_train reads the
# global `trainer` (its `tariner` parameter is unused), so this rebinding is
# what actually drives the optimizer steps — fragile; confirm before refactor.
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
start_train(epochs=5, batch_size=10, data_iter=data_iter, net=net, loss_method=square_loss, tariner=trainer,
num_examples=1000)
<wiz_code_mirror>

 
 
 
74
def data_loader(batch_size, X, y, shuffle=False):
 
 
 
 
1
import mxnet
2
import mxnet.ndarray as nd
3
from mxnet import gluon
4
from mxnet import autograd
5

6

7
# create data
8

9
def set_data(true_w, true_b, num_examples, *args, **kwargs):
10
    num_inputs = len(true_w)
11
    X = nd.random_normal(shape=(num_examples, num_inputs))
12
    y = 0
13
    for num in range(num_inputs):
14
        # print(num)
15
        y += true_w[num] * X[:, num]
16
    y += true_b
17
    y += 0.1 * nd.random_normal(shape=y.shape)
18
    return X, y
19

20

21
# create data loader
22
def data_loader(batch_size, X, y, shuffle=False):
23
    data_set = gluon.data.ArrayDataset(X, y)
24
    data_iter = gluon.data.DataLoader(dataset=data_set, batch_size=batch_size, shuffle=shuffle)
25
    return data_iter
26

27

28
# create net
29
def set_net(node_num):
30
    net = gluon.nn.Sequential()
31
    net.add(gluon.nn.Dense(node_num))
32
    net.initialize()
33
    return net
34

35

36
# create trainer
37
def trainer(net, loss_method, learning_rate):
38
    trainer = gluon.Trainer(
39
        net.collect_params(), loss_method, {'learning_rate': learning_rate}
40
    )
41
    return trainer
42

43

44
square_loss = gluon.loss.L2Loss()
45

46

47
# start train
48
def start_train(epochs, batch_size, data_iter, net, loss_method, tariner, num_examples):
49
    for e in range(epochs):
50
        total_loss = 0
51
        for data, label in data_iter:
52
            with autograd.record():
53
                output = net(data)
54
                loss = loss_method(output, label)
55
            loss.backward()
56
            trainer.step(batch_size)
57
            total_loss += nd.sum(loss).asscalar()
58
        print("第 %d次训练, 平均损失: %f" % (e, total_loss / 1000))
59
    dense = net[0]
60

61
    print(dense.weight.data())
62
    print(dense.bias.data())
63
    return dense.weight.data(), dense.bias.data()
64

65

66
true_w = [5, 8, 6]
67
true_b = 6
68
X, y = set_data(true_w=true_w, true_b=true_b, num_examples=1000)
69
data_iter = data_loader(batch_size=10, X=X, y=y, shuffle=True)
70
net = set_net(1)
71
trainer = trainer(net=net, loss_method='sgd', learning_rate=0.1)
72
start_train(epochs=5, batch_size=10, data_iter=data_iter, net=net, loss_method=square_loss, tariner=trainer,
73
            num_examples=1000)
74

 
 

mxnet 线性模型的更多相关文章

  1. MXNET:监督学习

    线性回归 给定一个数据点集合 X 和对应的目标值 y,线性模型的目标就是找到一条使用向量 w 和位移 b 描述的线,来尽可能地近似每个样本X[i] 和 y[i]. 数学公式表示为\(\hat{y}=X ...

  2. 分布式机器学习框架:MxNet 前言

           原文连接:MxNet和Caffe之间有什么优缺点一.前言: Minerva: 高效灵活的并行深度学习引擎 不同于cxxnet追求极致速度和易用性,Minerva则提供了一个高效灵活的平台 ...

  3. 广义线性模型(Generalized Linear Models)

    前面的文章已经介绍了一个回归和一个分类的例子.在逻辑回归模型中我们假设: 在分类问题中我们假设: 他们都是广义线性模型中的一个例子,在理解广义线性模型之前需要先理解指数分布族. 指数分布族(The E ...

  4. ubantu16.04+mxnet +opencv+cuda8.0 环境搭建

    ubantu16.04+mxnet +opencv+cuda8.0 环境搭建 建议:环境搭建完成之后,不要更新系统(内核) 转载请注明出处: 微微苏荷 一 我的安装环境 系统:ubuntu16.04 ...

  5. MXNet设计和实现简介

    原文:https://github.com/dmlc/mxnet/issues/797 神经网络本质上是一种语言,我们通过它来表达对应用问题的理解.例如我们用卷积层来表达空间相关性,RNN来表达时间连 ...

  6. MXNET手写体识别的例子

    安装完MXNet之后,运行了官网的手写体识别的例子,这个相当于深度学习的Hello world了吧.. http://mxnet.io/tutorials/python/mnist.html 运行的过 ...

  7. MXNET安装过程中遇到libinfo导入不了的问题解决

    今天尝试安装windows版本的MXNET,在按照官网的运行了python的setup之后,import mxnet时出现如下错误:cannot import name libinfo,在网上查找发现 ...

  8. MXNet学习~试用卷积~跑CIFAR-10

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  9. MXNet学习~第一个例子~跑MNIST

    反正基本上是给自己看的,直接贴写过注释后的代码,可能有的地方理解不对,你多担待,看到了也提出来(基本上对未来的自己说的),三层跑到了97%,毕竟是第一个例子,主要就是用来理解MXNet怎么使用. #导 ...

随机推荐

  1. Linux IO 监控与深入分析

    https://jaminzhang.github.io/os/Linux-IO-Monitoring-and-Deep-Analysis/ Linux IO 监控与深入分析 引言 接昨天电话面试,面 ...

  2. Windows2008 R2上完全卸载Oracle操作步骤(转)

    最近现场项目,碰到了好几次oracle数据库被损坏,而且无法恢复的问题,没办法,只好卸载重装了.oracle卸载确实麻烦,都是从网上查的方法, 为了方便以后查询,在此就做一下记录. Windows20 ...

  3. ubuntu 源更新(sources.list)

    首先备份源列表: sudo cp /etc/apt/sources.list /etc/apt/sources.list_backup 而后用gedit或其他编辑器打开(也可以复制到Windows下打 ...

  4. erlang的tcp服务器模板

    改来改去,最后放github了,贴的也累,蛋疼 还有一个tcp批量客户端的,也一起了 大概思路是 混合模式 使用erlang:send_after添加recv的超时处理 send在socket的opt ...

  5. php排序集合

    如果你已经使用了一段时间PHP的话,那么,你应该已经对它的数组比较熟悉了——这种数据结构允许你在单个变量中存储多个值,并且可以把它们作为一个集合进行操作. 经常,开发人员发现在PHP中使用这种数据结构 ...

  6. QQ市场总监分享:黏住90后的独门攻略

    转自:http://www.gameres.com/476003.html 90后的关键词 1. 品质生活 90后是怎么样的一群人?他们注重生活的品质. 他们比我们更爱享受,或者说他们不像我们一样认为 ...

  7. [置顶] sscanf() - 从一个字符串中读进与指定格式相符的数据

    在做一道九度上机题时,突然发现sscanf()函数非常有用,就顺便从网上搜集资料整理一下. sscanf() 的作用:从一个字符串中读进与指定格式相符的数据. 原型: int sscanf (cons ...

  8. php SqlServer 中文汉字乱码

    php SqlServer 中文汉字乱码,用iconv函数转换 查询显示的时候,从GB转换为UTF8 <?php echo iconv('GB2312','UTF-8',$row['Name'] ...

  9. Android独立交叉编译环境搭建

    我们经常需将一些C/C++源码编译成本地二进制,直接在android的linux内核上运行,这是就需要进行交叉编译.由于Android的运行环境核普通Linux又区别,所以常规方式针对ARM进行交叉编 ...

  10. Centos 6.5 python 2.6.6 升级到 2.7

    1.查看python的版本 [root@localhost ~]# python -V Python 2.6.6 2.安装python 2.7.3 [root@localhost ~]# yum in ...