莫烦theano学习自修第十天【保存神经网络及加载神经网络】
1. 为何保存神经网络
保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件中,保存神经网络只须保存权重W和偏置b,神经网络的结构下次再次定义为一样的,只需从文件中加载权重W和偏置b参数即可,无需重新训练神经网络
2. 代码实现:
from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T
import pickle
def compute_accuracy(y_target, y_predict):
correct_prediction = np.equal(y_predict, y_target)
accuracy = np.sum(correct_prediction)/len(correct_prediction)
return accuracy
rng = np.random
# set random seed
np.random.seed(100)
N = 400
feats = 784
# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))
# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")
# initialize the weights and biases
w = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")
# Construct Theano expression graph
p_1 = 1 / (1 + T.exp(-T.dot(x, w) - b))
prediction = p_1 > 0.5
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1)
cost = xent.mean() + 0.01 * (w ** 2).sum()
gw, gb = T.grad(cost, [w, b])
# Compile
learning_rate = 0.1
train = theano.function(
inputs=[x, y],
updates=((w, w - learning_rate * gw), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)
# Training
for i in range(500):
train(D[0], D[1])
# save model
with open('save/model.pickle', 'wb') as file:
model = [w.get_value(), b.get_value()]
pickle.dump(model, file)
print(model[0][:10])
print("accuracy:", compute_accuracy(D[1], predict(D[0])))
# load model
with open('save/model.pickle', 'rb') as file:
model = pickle.load(file)
w.set_value(model[0])
b.set_value(model[1])
print(w.get_value()[:10])
print("accuracy:", compute_accuracy(D[1], predict(D[0])))
莫烦theano学习自修第十天【保存神经网络及加载神经网络】的更多相关文章
- 莫烦theano学习自修第九天【过拟合问题与正规化】
如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...
- 莫烦theano学习自修第八天【分类问题】
1. 代码实现 from __future__ import print_function import numpy as np import theano import theano.tensor ...
- 莫烦theano学习自修第七天【回归结果可视化】
1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...
- 莫烦theano学习自修第六天【回归】
1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...
- 莫烦theano学习自修第五天【定义神经层】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第三天【共享变量】
1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...
- 莫烦theano学习自修第二天【激励函数】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第一天【常量和矩阵的运算】
1. 代码实现如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ # 导入numpy模块,因为numpy是常用的计算模块 import numpy as ...
- 莫烦theano学习自修第四天【激励函数】
1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...
随机推荐
- unique
作用: 元素去重,即“删除”序列中所有相邻的重复元素(只保留一个) (此处的删除,并不是真的删除,而是指重复元素的位置被不重复的元素占领了) (其实就是把多余的元素放到了最后面) 由于它“删除”的是相 ...
- Leetcode 1. Two Sum (Python)
refer to https://blog.csdn.net/linfeng886/article/details/79772348 Description Given an array of int ...
- nova系列一:虚拟化介绍
一 什么是虚拟化 虚拟化说白了就是本来是一个完整的资源,切分或者说虚拟成多份,让这多份资源都使用起来,物尽其用,减少了浪费,提高了利用率,省了钱. 虚拟化(Virtualization)技术最早出现在 ...
- Intellij IDEA创建的Web项目配置Tomcat并启动Maven项目
本篇博客讲解IDEA如何配置Tomcat. 大部分是直接上图哦. 点击如图所示的地方,进行添加Tomcat配置页面 弹出页面后,按照如图顺序找到,点击+号 tomcat Service -> L ...
- 设置placeholder无效解决办法
一.设置placeholder的方法 placeholder属性用来设置控件内部的提示信息 <input type="text" placeholder="请输入用 ...
- 网络拓扑自动发掘之三层设备惯用的SNMP OID的含义
原文地址:https://blog.csdn.net/maty_wang/article/details/81305070 1. ipNetToMediaIfIndex Name/OID: ipNet ...
- 第四次oo博客
论述测试与正确性论证的效果差异 单元测试利用测试者构造的测试用例来检查类或方法的正确性,一般来说所需要测试的用例是无穷多的,通过人为构造代表性的测试用例来尽量测试所有代码.测试的优点在于不易出错,只要 ...
- Jenkins-job之间依赖关系配置
使用场景: 想要在某APP打新包之后,立即执行自动化测试的job来验证该新包. 比如Job A 执行完执行Job B ,如下图所示,如何建立依赖呢? 1.配置上游依赖 构建触发器-配置如下信息: 选择 ...
- Linux进程与线程的区别
进程与线程的区别,早已经成为了经典问题.自线程概念诞生起,关于这个问题的讨论就没有停止过.无论是初级程序员,还是资深专家,都应该考虑过这个问题,只是层次角度不同罢了.一般程序员而言,搞清楚二者的概念, ...
- 使用 Travis CI 实现项目的持续测试反馈
[篇幅较长,10.15前补充完毕,如希望探索可直接移步Github仓库:https://github.com/SivilTaram/CITest] 在编程课中,我们可以使用成熟的在线评测系统来测试某个 ...