from:https://stackoverflow.com/questions/41616292/how-to-load-and-retrain-tflean-model

This is to create a graph and save it

graph1 = tf.Graph()
with graph1.as_default():
network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
network = tflearn.embedding(network, input_dim=n_words, output_dim=128)
branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
network = merge([branch1, branch2, branch3], mode='concat', axis=1)
network = tf.expand_dims(network, 2)
network = global_max_pool(network)
network = dropout(network, 0.5)
network = fully_connected(network, 2, activation='softmax')
network = regression(network, optimizer='adam', learning_rate=0.001,loss='categorical_crossentropy', name='target')
model = tflearn.DNN(network, tensorboard_verbose=0)
clf, acc, roc_auc,fpr,tpr =classify_DNN(data,clas,model)
clf.save(model_path)

To reload and retrain or use it for prediction

MODEL = None
with tf.Graph().as_default():
## Building deep neural network
network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
network = tflearn.embedding(network, input_dim=n_words, output_dim=128)
branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
network = merge([branch1, branch2, branch3], mode='concat', axis=1)
network = tf.expand_dims(network, 2)
network = global_max_pool(network)
network = dropout(network, 0.5)
network = fully_connected(network, 2, activation='softmax')
network = regression(network, optimizer='adam', learning_rate=0.001,loss='categorical_crossentropy', name='target')
new_model = tflearn.DNN(network, tensorboard_verbose=3)
new_model.load(model_path)
MODEL = new_model

Use the MODEL for prediction or retraining. The 1st line and the with loop was important. For anyone who might need help

官方例子:

""" An example showing how to save/restore models and retrieve weights. """

from __future__ import absolute_import, division, print_function

import tflearn

import tflearn.datasets.mnist as mnist

# MNIST Data
X, Y, testX, testY = mnist.load_data(one_hot=True) # Model
input_layer = tflearn.input_data(shape=[None, 784], name='input')
dense1 = tflearn.fully_connected(input_layer, 128, name='dense1')
dense2 = tflearn.fully_connected(dense1, 256, name='dense2')
softmax = tflearn.fully_connected(dense2, 10, activation='softmax')
regression = tflearn.regression(softmax, optimizer='adam',
learning_rate=0.001,
loss='categorical_crossentropy') # Define classifier, with model checkpoint (autosave)
model = tflearn.DNN(regression, checkpoint_path='model.tfl.ckpt') # Train model, with model checkpoint every epoch and every 200 training steps.
model.fit(X, Y, n_epoch=1,
validation_set=(testX, testY),
show_metric=True,
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
snapshot_step=500, # Snapshot (save & evalaute) model every 500 steps.
run_id='model_and_weights') # ---------------------
# Save and load a model
# --------------------- # Manually save model
model.save("model.tfl") # Load a model
model.load("model.tfl") # Or Load a model from auto-generated checkpoint
# >> model.load("model.tfl.ckpt-500") # Resume training
model.fit(X, Y, n_epoch=1,
validation_set=(testX, testY),
show_metric=True,
snapshot_epoch=True,
run_id='model_and_weights') # ------------------
# Retrieving weights
# ------------------ # Retrieve a layer weights, by layer name:
dense1_vars = tflearn.variables.get_layer_variables_by_name('dense1')
# Get a variable's value, using model `get_weights` method:
print("Dense1 layer weights:")
print(model.get_weights(dense1_vars[0]))
# Or using generic tflearn function:
print("Dense1 layer biases:")
with model.session.as_default():
print(tflearn.variables.get_value(dense1_vars[1])) # It is also possible to retrieve a layer weights through its attributes `W`
# and `b` (if available).
# Get variable's value, using model `get_weights` method:
print("Dense2 layer weights:")
print(model.get_weights(dense2.W))
# Or using generic tflearn function:
print("Dense2 layer biases:")
with model.session.as_default():
print(tflearn.variables.get_value(dense2.b))

tflearn 保存模型重新训练的更多相关文章

  1. tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用——模型层次太深,或者太复杂训练时候都不会收敛

    tflearn 中文汉字识别,训练后模型存为pb给TensorFlow使用. 数据目录在data,data下放了汉字识别图片: data$ ls0  1  10  11  12  13  14  15 ...

  2. tensorflow训练自己的数据集实现CNN图像分类2(保存模型&测试单张图片)

    神经网络训练的时候,我们需要将模型保存下来,方便后面继续训练或者用训练好的模型进行测试.因此,我们需要创建一个saver保存模型. def run_training(): data_dir = 'C: ...

  3. 将tflearn的模型保存为pb,给TensorFlow使用

    参考:https://github.com/tflearn/tflearn/issues/964 解决方法: """ Tensorflow graph freezer C ...

  4. Keras保存模型并载入模型继续训练

    我们以MNIST手写数字识别为例 import numpy as np from keras.datasets import mnist from keras.utils import np_util ...

  5. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  6. pytorch加载和保存模型

    在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢? 方法一(推荐): 第一种方法也 ...

  7. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  8. (原)tensorflow保存模型及载入保存的模型

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7198773.html 参考网址: http://stackoverflow.com/questions ...

  9. 转sklearn保存模型

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

随机推荐

  1. php代码中注释的含义

    最近在梳理和优化手上的项目代码,这个项目已经走过好几任了,每一任的开发人员多多少少都有一些差异和各自的习惯,所以代码逻辑和写法上都有点[乱]. 在代码中,注释是一个非常重要的信息,更何况是接手其他人的 ...

  2. double salary = wage = 9999.99错误

    在看书时,有这么一句表达式 double salary = wage = 9999.99; 在Linux中编译时,不能通过,提示是 error: 'wage' was not declared in ...

  3. Linux系统下设置vi编辑器,tab键为4

    1.cd ~ 2.vi .exrc 3.set tabstop=4(保存并退出)即可

  4. C++解决大数组问题

    今天写一个C++小程序,竟然出现:"VS 未经处理的异常: 0xC00000FD: Stack overflow" 查了一下,普通数组变量是在堆栈中保存的,而堆栈空间有限,故出此错 ...

  5. 窗口类WNDCLASSEX名词解析

    窗口类WNDCLASSEX名词解析 typedef struct tagWNDCLASSEX{ UINT cbsize; UINT style; WNDPROC lpfnWNDProc; int cb ...

  6. PCB中贴片元器件的引脚规范(allegro)

    表贴的芯片一个引脚焊盘的宽度: 当芯片引脚间的间距>=26mil时,计算公式是(脚宽度+8mil) 当芯片引脚的间距<26mil时,计算公式是(引脚间距/2+1) 表贴的芯片一个引脚焊盘的 ...

  7. STM32F407 GPIO原理 个人笔记

    datasheet(STM32F407ZGT6.pdf)中,IO structure 为FT,表示容忍5V电压 后面的uart1_TX之类,表示端口复用 共有A~G7组IO口, 每组16个IO口:0~ ...

  8. 九度oj 题目1490:字符串链接

    题目1490:字符串链接 时间限制:1 秒 内存限制:128 兆 特殊判题:否 提交:2610 解决:1321 题目描述: 不用strcat 函数,自己编写一个字符串链接函数MyStrcat(char ...

  9. 全文检索(AB-1)-相关领域

    信息检索 认知科学 管理科学

  10. 百度音乐免费API接口

    音乐分类: 1.新歌榜,2.热歌榜,11.摇滚榜,12.爵士,16.流行21.欧美金曲榜,22.经典老歌榜,23.情歌对唱榜,24.影视金曲榜,25.网络歌曲榜 说明:百度music web版全接口h ...