import numpy as np
 import matplotlib.pyplot as plt
 import os
 import time

 import tensorflow as tf
 tf.enable_eager_execution()

 # create data
 X = np.linspace(-1, 1, 5000)
 np.random.shuffle(X)
 y = 0.5 * X + 2 + np.random.normal(0, 0.05, (5000,))

 # plot data
 plt.scatter(X, y)
 plt.show()

 # split data
 X_train, y_train = X[:4000], y[:4000]
 X_test, y_test = X[4000:], y[4000:]

 # tf.data
 BATCH_SIZE = 32
 BUFFER_SIZE = 512
 dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(BATCH_SIZE).shuffle(BUFFER_SIZE)

 # subclassed model
 UNITS = 1

 class Model(tf.keras.Model):
     def __init__(self):
         super(Model, self).__init__()
         self.fc = tf.keras.layers.Dense(units=UNITS)

     def call(self, inputs):
         return self.fc(inputs)

 model = Model()

 optimizer = tf.train.AdamOptimizer()

 # loss function
 def loss_function(real, pred):
     return tf.losses.mean_squared_error(labels=real, predictions=pred)

 EPOCHS = 30
 checkpoint_dir = './save_subclassed_keras_model_training_checkpoints'
 if not os.path.exists(checkpoint_dir):
     os.makedirs(checkpoint_dir)

 # training loop
 for epoch in range(EPOCHS):
     start = time.time()
     epoch_loss = 0

     for (batch, (x, y)) in enumerate(dataset):
         x = tf.cast(x, tf.float32)
         y = tf.cast(y, tf.float32)
         x = tf.expand_dims(x, axis=1)
         y = tf.expand_dims(y, axis=1)
         # print(x)    # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
         # print(y)    # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
         with tf.GradientTape() as tape:
             predictions = model(x)
             # print(predictions)  # tf.Tensor([...], shape=(BATCH_SIZE, 1), dtype=float32)
             batch_loss = loss_function(real=y, pred=predictions)

         grads = tape.gradient(batch_loss, model.variables)
         optimizer.apply_gradients(zip(grads, model.variables),
                                   global_step=tf.train.get_or_create_global_step())
         epoch_loss += batch_loss

         if (batch + 1) % 10 == 0:
             print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                          batch + 1,
                                                          batch_loss/int(x.shape[0])))

     print('Epoch {} Loss {:.4f}'.format(epoch + 1, epoch_loss/len(X_train)))
     print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

     # save checkpoint
     checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
     if (epoch + 1) % 10 == 0:
         model.save_weights(checkpoint_prefix.format(epoch=epoch), overwrite=True)

 _model = Model()
 _model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
 _model.build(input_shape=tf.TensorShape([BATCH_SIZE, 1]))
 _model.summary()

 test_dataset = tf.data.Dataset.from_tensor_slices(X_test).batch(1)

 for (batch, x) in enumerate(test_dataset):
     x = tf.cast(x, tf.float32)
     x = tf.expand_dims(x, axis=1)
     print(x)
     predictions = _model(x)
     print(predictions)
     exit()

[Tensorflow] 使用 model.save_weights() 保存 Keras Subclassed Model的更多相关文章

  1. [Tensorflow] 使用 model.save_weights() 保存 / 加载 Keras Subclassed Model

    在 parameters.py 中,定义了各类参数. # training data directory TRAINING_DATA_DIR = './data/' # checkpoint dire ...

  2. [Tensorflow] 使用 tf.train.Checkpoint() 保存 / 加载 keras subclassed model

    在 subclassed_model.py 中,通过对 tf.keras.Model 进行子类化,设计了两个自定义模型. import tensorflow as tf tf.enable_eager ...

  3. keras系列︱Sequential与Model模型、keras基本结构功能(一)

    引自:http://blog.csdn.net/sinat_26917383/article/details/72857454 中文文档:http://keras-cn.readthedocs.io/ ...

  4. Keras(一)Sequential与Model模型、Keras基本结构功能

    keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...

  5. [Model] LeNet-5 by Keras

    典型的卷积神经网络. 数据的预处理 Keras傻瓜式读取数据:自动下载,自动解压,自动加载. # X_train: array([[[[ 0., 0., 0., ..., 0., 0., 0.], [ ...

  6. xen 保存快照的实现之 —— device model 状态保存

    xen 保存快照的实现之 —— device model 状态保存 实现要点: 设备状态保存在 /var/lib/xen/qemu-save.x 文件这个文件由 qemu-dm 产生,也由 qemu- ...

  7. AI - TensorFlow - 示例05:保存和恢复模型

    保存和恢复模型(Save and restore models) 官网示例:https://www.tensorflow.org/tutorials/keras/save_and_restore_mo ...

  8. 如何保存Keras模型

    我们不推荐使用pickle或cPickle来保存Keras模型 你可以使用model.save(filepath)将Keras模型和权重保存在一个HDF5文件中,该文件将包含: 模型的结构,以便重构该 ...

  9. Python之TensorFlow的模型训练保存与加载-3

    一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...

随机推荐

  1. SSH-struts2的异常处理

    在学习j2se的时候学习过利用java的exception类去处理异常.在struts2框架中也提供了对于异常的处理.简单说就是当Action发生异常时.能够在struts2.xml文件里进行配置,将 ...

  2. The method getJspApplicationContext(ServletContext) is undefined for the type JspFactory的解决方法

    An error occurred at line: [31] in the generated java file: [/data/tmisnt/work/Catalina/localhost/_/ ...

  3. MySQL-修改数据(REPLACE)

    MySQL-REPLACE语句 功能介绍:用于向数据库表插入或更新数据. REPLACE语句的工作原理: 如果给定行数据不存在,那么MySQL REPLACE语句会插入新行. 如果给定行数据存在,则R ...

  4. lydsy1013: [JSOI2008]球形空间产生器sphere 高斯消元

    题链:http://www.lydsy.com/JudgeOnline/problem.php?id=1013 1013: [JSOI2008]球形空间产生器sphere 时间限制: 1 Sec  内 ...

  5. Apple Swift学习教程

    翻译自苹果的官方文档:The Swift Programming Language. 简单介绍 今天凌晨Apple刚刚公布了Swift编程语言,本文从其公布的书籍<The Swift Progr ...

  6. 仰视源代码,实现strcmp

    //这是系统库的实现 int strcmp(const char* src, const char* dest) { int rtn = 0; while(!(rtn = *(unsigned cha ...

  7. Noip模拟 Day6.12

    第一题:贪吃蛇(snake) 本题其实就是判断一个有向图中有没有环,做一次拓扑排序就可以了,如果所有点都入队了,就表示没有环,否则就有环.或者就是dfs一次,每个点只需要被访问一次,这样也是O(n)的 ...

  8. Vijos 1565 多边形 【区间DP】

    描述 zgx给了你一个n边的多边形,这个多边形每个顶点赋予一个值,每条边都被标上运算符号+或*,对于这个多边形有一个游戏,游戏的步骤如下:(1)第一步,删掉一条边:(2)接下来n-1步,每步对剩下的边 ...

  9. C语言8大经典排序算法(1)

    算法一直是编程的基础,而排序算法是学习算法的开始,排序也是数据处理的重要内容.所谓排序是指将一个无序列整理成按非递减顺序排列的有序序列.排列的方法有很多,根据待排序序列的规模以及对数据的处理的要求,可 ...

  10. ubuntu 12.04.5 LTS版本 更新 source.list

    更新后一定要:apt-get update # # deb cdrom:[Ubuntu-Server LTS _Precise Pangolin_ - Release amd64 (20140806. ...