莫烦theano学习自修第十天【保存神经网络及加载神经网络】
1. 为何保存神经网络
保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件中,保存神经网络只须保存权重W和偏置b,神经网络的结构下次再次定义为一样的,只需从文件中加载权重W和偏置b参数即可,无需重新训练神经网络
2. 代码实现:
from __future__ import print_function import numpy as np import theano import theano.tensor as T import pickle def compute_accuracy(y_target, y_predict): correct_prediction = np.equal(y_predict, y_target) accuracy = np.sum(correct_prediction)/len(correct_prediction) return accuracy rng = np.random # set random seed np.random.seed(100) N = 400 feats = 784 # generate a dataset: D = (input_values, target_class) D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2)) # Declare Theano symbolic variables x = T.dmatrix("x") y = T.dvector("y") # initialize the weights and biases w = theano.shared(rng.randn(feats), name="w") b = theano.shared(0., name="b") # Construct Theano expression graph p_1 = 1 / (1 + T.exp(-T.dot(x, w) - b)) prediction = p_1 > 0.5 xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) cost = xent.mean() + 0.01 * (w ** 2).sum() gw, gb = T.grad(cost, [w, b]) # Compile learning_rate = 0.1 train = theano.function( inputs=[x, y], updates=((w, w - learning_rate * gw), (b, b - learning_rate * gb))) predict = theano.function(inputs=[x], outputs=prediction) # Training for i in range(500): train(D[0], D[1]) # save model with open('save/model.pickle', 'wb') as file: model = [w.get_value(), b.get_value()] pickle.dump(model, file) print(model[0][:10]) print("accuracy:", compute_accuracy(D[1], predict(D[0]))) # load model with open('save/model.pickle', 'rb') as file: model = pickle.load(file) w.set_value(model[0]) b.set_value(model[1]) print(w.get_value()[:10]) print("accuracy:", compute_accuracy(D[1], predict(D[0])))
莫烦theano学习自修第十天【保存神经网络及加载神经网络】的更多相关文章
- 莫烦theano学习自修第九天【过拟合问题与正规化】
如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...
- 莫烦theano学习自修第八天【分类问题】
1. 代码实现 from __future__ import print_function import numpy as np import theano import theano.tensor ...
- 莫烦theano学习自修第七天【回归结果可视化】
1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...
- 莫烦theano学习自修第六天【回归】
1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...
- 莫烦theano学习自修第五天【定义神经层】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第三天【共享变量】
1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...
- 莫烦theano学习自修第二天【激励函数】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第一天【常量和矩阵的运算】
1. 代码实现如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ # 导入numpy模块,因为numpy是常用的计算模块 import numpy as ...
- 莫烦theano学习自修第四天【激励函数】
1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...
随机推荐
- 第一章 mysql的体系结构与存储引擎
数据库从逻辑上可以分为两部分,一部分负责存储即文件系统,这部分有个更时髦的名字叫存储引擎,存储引擎负责如何把数据以及索引相关的内容以合适的形式组织并存储到磁盘上.另一部分为server部分,负责和用户 ...
- 008_python内置语法
一. 参考:http://www.runoob.com/python/python-built-in-functions.html (1)vars() 描述 vars() 函数返回对象object的属 ...
- Shell 文本处理三剑客之grep
grep ♦参数 -E,--extended-regexp 模式是扩展正则表达式 -i,--ignore-case 忽略大小写 -n,--line-number 打印行号 -v,--invert-ma ...
- Raid卡介绍
raid0条带卷 最少需要一块硬盘 可以把所有硬盘的容量都叠加在一起,可以拥有很高的读写速度,硬盘空间也能得到很好的利用 但是只要其中一块硬盘换了,数据就全丢失了 raid1镜像卷 最少需要两块硬盘, ...
- oracle 11G direct path read 很美也很伤人
direct path read在11g中,全表扫描可能使用direct path read方式,绕过buffer cache,这样的全表扫描就是物理读了. 在10g中,都是通过gc buffer来读 ...
- 箱线图boxplot
箱线图boxplot--展示数据的分布 图表作用: 1.反映一组数据的分布特征,如:分布是否对称,是否存在离群点 2.对多组数据的分布特征进行比较 3.如果只有一个定量变量,很少用箱线图去看数据的分布 ...
- fastJson 之 JSONPath使用
1. JSONPath介绍 官网地址: https://github.com/alibaba/fastjson/wiki/JSONPath fastjson 1.2.0之后的版本支持JSONPath. ...
- Java消息队列——JMS概述
一.什么是JMS JMS即Java消息服务(Java Message Service)应用程序接口,是一个Java平台中关于面向消息中间件(MOM)的API,用于在两个应用程序之间,或分布式系统中发送 ...
- VBS弹出来的对话框如何置顶!--果然技巧
msgbox 第二参数+4096 mshta vbscript:msgbox("提示内容6",6,"提示窗口6")(window.close)
- mongodb .explain('executionStats') 查询性能分析(转)
mongodb性能分析方法:explain() 为了演示的效果,我们先来创建一个有200万个文档的记录.(我自己的电脑耗了15分钟左右插入完成.如果你想插更多的文档也没问题,只要有耐心等就可以了.) ...