关键代码:
tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
我的demo:
def get_model(width, height, classes=40):
# TODO, modify model
network = input_data(shape=[None, width, height, 3]) # if RGB, 224,224,3
# Residual blocks
# 32 layers: n=5, 56 layers: n=9, 110 layers: n=18
n = 2
net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, n, 16)
net = tflearn.residual_block(net, 1, 32, downsample=True)
net = tflearn.residual_block(net, n-1, 32)
net = tflearn.residual_block(net, 1, 64, downsample=True)
net = tflearn.residual_block(net, n-1, 64)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, classes, activation='softmax')
#mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
net = tflearn.regression(net, optimizer=mom,
loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
max_checkpoints=10, tensorboard_verbose=0,
clip_gradients=0.)
return model def main():
trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
#trainX = trainX.reshape([-1, width, height, 1])
#testX = testX.reshape([-1, width, height, 1])
print("sample data:")
print(trainX[0])
print(trainY[0])
print(testX[-1])
print(testY[-1]) model = get_model(width, height, classes=3755) filename = 'tflearn_resnet/model.tflearn'
# try to load model and resume training
try:
#model.load(filename)
model.load("model_resnet_cifar10-195804")
print("Model loaded OK. Resume training!")
except:
pass early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
try:
model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')
except StopIteration as e:
print("OK, stop iterate!Good!") model.save(filename) del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
filename = 'tflearn_resnet/model-infer.tflearn'
model.save(filename)

tflearn 在每一个epoch完毕保存模型的更多相关文章

  1. pytorch加载和保存模型

    在模型完成训练后,我们需要将训练好的模型保存为一个文件供测试使用,或者因为一些原因我们需要继续之前的状态训练之前保存的模型,那么如何在PyTorch中保存和恢复模型呢? 方法一(推荐): 第一种方法也 ...

  2. pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件

    本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取. 假设网络为model = Net(), optimizer = optim.Adam(model.parameter ...

  3. Socket编程模型之完毕port模型

    转载请注明来源:viewmode=contents">http://blog.csdn.net/caoshiying?viewmode=contents 一.回想重叠IO模型 用完毕例 ...

  4. ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档]

    ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档] 简介 简单地说就是该有的都有了,但是总体跑起来效果还不好. 还在开发中,它工作的效果还不好.但是你可以直 ...

  5. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  6. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  7. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  8. (原)tensorflow保存模型及载入保存的模型

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/7198773.html 参考网址: http://stackoverflow.com/questions ...

  9. 转sklearn保存模型

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

随机推荐

  1. iOS 振动反馈

    代码地址如下:http://www.demodashi.com/demo/12461.html 1. 常用场景 继 iPhone7/7P 实体 home 键出现后,home 键再也无法通过真实的物理按 ...

  2. linux 静态库使用经验

    在编写程序的过程中,对于一些接口往往抽象成lib库的形式,甚至有些程序只有一个主程序,其他接口的调用都是库的形式存在.较多的使用库会比较利于程序的维护,因为我们的程序都可以被其他的人使用,但是往往库的 ...

  3. 强大易用的日期和时间库 Joda Time

    Joda-Time提供了一组Java类包用于处理包括ISO8601标准在内的date和time.可以利用它把JDK Date和Calendar类完全替换掉,而且仍然能够提供很好的集成,并且它是线程安全 ...

  4. ASP.NET机制详细的管道事件流程(转)

    ASP.NET机制详细的管道事件流程 第一:浏览器向服务器发送请求. 1)浏览器向iis服务器发送请求网址的域名,根据http协议封装成请求报文,通过dns解析请求的ip地址,接着通过socket与i ...

  5. 【selenium+python】之Python Flask 开发环境搭建(Windows)

    一.先安装python以及pip 二.其次, Python的虚拟环境安装: 在github上下载https://github.com/pypa/virtualenv/tree/master  zip文 ...

  6. [转]XMPP基本概念--节(stanza)

    本文介绍在XMPP通信中最核心的三个XML节(stanza).这些节(stanza)有自己的作用和目标,通过组织不同的节(stanza),就能达到我们各种各样的通信目的. 首先我们来看一段XMPP流. ...

  7. C#泛型<T>说明

    泛型:即通过参数化类型来实现在同一份代码上操作多种数据类型.泛型编程是一种编程范式,它利用“参数化类型”将类型抽象化,从而实现更为灵活的复用. C#泛型的作用概述 C#泛型赋予了代码更强的类型安全,更 ...

  8. Touch ID和Passcode框架,Apple Watch风格的应用布局

    本文转载至 http://www.cocoachina.com/ios/20141031/10110.html 水平滚动条(artwalk) 测试环境:Xcode 6.0,iOS 8.0     VE ...

  9. tp框架知识 之(链接数据库和操作数据内容)

    框架有时会用到数据库的内容,在"ThinkPhp框架知识"的那篇随笔中提到过,现在这篇随笔详细的描述下. 一.链接数据库 (1)找到模块文件夹中的Conf文件夹,然后进行编写con ...

  10. 九度OJ 1148:Financial Management(财务管理) (平均数)

    与1141题相同. 时间限制:1 秒 内存限制:32 兆 特殊判题:否 提交:843 解决:502 题目描述: Larry graduated this year and finally has a ...