1. 为何保存神经网络

保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件中,保存神经网络只须保存权重W和偏置b,神经网络的结构下次再次定义为一样的,只需从文件中加载权重W和偏置b参数即可,无需重新训练神经网络

2. 代码实现:

from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T
import pickle

def compute_accuracy(y_target, y_predict):
    correct_prediction = np.equal(y_predict, y_target)
    accuracy = np.sum(correct_prediction)/len(correct_prediction)
    return accuracy

rng = np.random

# set random seed
np.random.seed(100)

N = 400
feats = 784

# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weights and biases
w = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")

# Construct Theano expression graph
p_1 = 1 / (1 + T.exp(-T.dot(x, w) - b))
prediction = p_1 > 0.5
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1)
cost = xent.mean() + 0.01 * (w ** 2).sum()
gw, gb = T.grad(cost, [w, b])

# Compile
learning_rate = 0.1
train = theano.function(
          inputs=[x, y],
          updates=((w, w - learning_rate * gw), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Training
for i in range(500):
    train(D[0], D[1])

# save model
with open('save/model.pickle', 'wb') as file:
    model = [w.get_value(), b.get_value()]
    pickle.dump(model, file)
    print(model[0][:10])
    print("accuracy:", compute_accuracy(D[1], predict(D[0])))

# load model
with open('save/model.pickle', 'rb') as file:
    model = pickle.load(file)
    w.set_value(model[0])
    b.set_value(model[1])
    print(w.get_value()[:10])
    print("accuracy:", compute_accuracy(D[1], predict(D[0])))

莫烦theano学习自修第十天【保存神经网络及加载神经网络】的更多相关文章

  1. 莫烦theano学习自修第九天【过拟合问题与正规化】

    如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...

  2. 莫烦theano学习自修第八天【分类问题】

    1. 代码实现 from __future__ import print_function import numpy as np import theano import theano.tensor ...

  3. 莫烦theano学习自修第七天【回归结果可视化】

    1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...

  4. 莫烦theano学习自修第六天【回归】

    1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...

  5. 莫烦theano学习自修第五天【定义神经层】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

  6. 莫烦theano学习自修第三天【共享变量】

    1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...

  7. 莫烦theano学习自修第二天【激励函数】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

  8. 莫烦theano学习自修第一天【常量和矩阵的运算】

    1. 代码实现如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ # 导入numpy模块,因为numpy是常用的计算模块 import numpy as ...

  9. 莫烦theano学习自修第四天【激励函数】

    1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...

随机推荐

  1. Python:Day25 成员修饰符、特殊成员、反射、单例

    一.成员修饰符 共有成员 私有成员,__字段名,__方法 - 无法直接访问,只能间接访问 class Foo: def __init__(self,name,age): self.name = nam ...

  2. 【vue】在移动端使用better-scroll 实现滚动效果

    安装依赖:(c)npm install better-scroll --save 引入: import BScroll from 'better-scroll' 格式: var obj = new B ...

  3. oracle 乘积的实现方法

    with abc(col1) as ( ' from dual union all ' from dual union all ' from dual ) select col1,ln(col1),e ...

  4. Vue.js项目详解

    还是以Blog项目来讲解,最近我本人利用闲暇时间,以博客作为参考学习一些新的技术并尝试之前没有尝试过的思路来玩玩. 技术看似枯燥,但是带有一个目的来学,你会发现还是蛮有趣的. 主要实践的就是前后端分离 ...

  5. TextFormField数据处理

    重点:TextFormField这个Widget是由TextField封装而来,继承了TextField的特性:数据传递依靠:GlobalKey<FormState>(),Register ...

  6. Git中删除指定文件

    之前的博客Git基础使用教程介绍了Git这个开源分布式管理系统的一些基础操作,这篇博客,介绍下如何利用Git删除远程仓库的文件... 1.拉取远程仓库的文件到本地 git clone xxx 如果还未 ...

  7. C# - Span 全面介绍:探索 .NET 新增的重要组成部分

    假设要公开特殊化排序例程,以就地对内存数据执行操作.可能要公开需要使用数组的方法,并提供对相应 T[] 执行操作的实现.如果方法的调用方有数组,且希望对整个数组进行排序,这样做就非常合适.但如果调用方 ...

  8. vi十六进制编辑

    指定行:n 光标行之前或之后的n个字符nl 之后 2l 光标位置两个字符后nh 之前 2h 光标位置两个字符前 光标行之上或之下的n个字符nk 之上 1k 光标位置1个字符之上nj 之下 1j 光标位 ...

  9. JVM参数配置 java内存区域

    java内存区域 一些基本概念 http://www.importnew.com/18694.html https://www.cnblogs.com/wangyayun/p/6557851.html ...

  10. 朱晔的互联网架构实践心得S2E5:浅谈四种API设计风格(RPC、REST、GraphQL、服务端驱动)

    Web API设计其实是一个挺重要的设计话题,许多公司都会有公司层面的Web API设计规范,几乎所有的项目在详细设计阶段都会进行API设计,项目开发后都会有一份API文档供测试和联调.本文尝试根据自 ...