用CapsNets做电能质量扰动分类(2019-08-05)
当下最热神经网络为CNN,2017年10月,深度学习之父Hinton发表《胶囊间的动态路由》(Capsule Networks),最近谷歌正式开源了Hinton胶囊理论代码,提出的胶囊神经网络。本文不涉及原理,只是站在巨人的肩膀人,尝试把胶囊网络应用与分类问题。
原理和代码的参考文献是:https://blog.csdn.net/weixin_40920290/article/details/82951826
其中,本文采用的数据集和以2019年3月CNN做电能质量分类的一样,可以去那个博文中下载数据集。这里只展示代码。需要提醒的是,Capsule Networks的运行速度会比较慢,耐心等待.
如果由于格式问题无法运行,可以把邮箱私戳发给我,我把Capsule.py发给你
1.代码
from __future__ import print_function
import numpy as np
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask
import keras
from pandas import read_csv
K.set_image_data_format('channels_last')
def CapsNet(input_shape, n_class, routings):
    """
    A Capsule Network on MNIST.
    :param input_shape: data shape, 3d, [width, height, channels]
    :param n_class: number of classes
    :param routings: number of routing iterations
    :return: Two Keras Models, the first one used for training, and the second one for evaluation.
            `eval_model` can also be used for training.
    """
    x = layers.Input(shape=input_shape)
    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)
    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')
    # Layer 3: Capsule layer. Routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                            name='digitcaps')(primarycaps)
    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    # If using tensorflow, this will not be necessary. :)
    out_caps = Length(name='capsnet')(digitcaps)
    # Decoder network.
    y = layers.Input(shape=(n_class,))
    masked_by_y = Mask()([digitcaps, y])  # The true label is used to mask the output of capsule layer. For training
    masked = Mask()(digitcaps)  # Mask using the capsule with maximal length. For prediction
    # Shared Decoder model in training and prediction
    decoder = models.Sequential(name='decoder')
    decoder.add(layers.Dense(512, activation='relu', input_dim=16*n_class))
    decoder.add(layers.Dense(1024, activation='relu'))
    decoder.add(layers.Dense(np.prod(input_shape), activation='sigmoid'))
    decoder.add(layers.Reshape(target_shape=input_shape, name='out_recon'))
    # Models for training and evaluation (prediction)
    train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
    eval_model = models.Model(x, [out_caps, decoder(masked)])
    # manipulate model
    noise = layers.Input(shape=(n_class, 16))
    noised_digitcaps = layers.Add()([digitcaps, noise])
    masked_noised_y = Mask()([noised_digitcaps, y])
    manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))
    return train_model, eval_model, manipulate_model
def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: a scalar loss value.
    """
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, 1))
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data
    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                              batch_size=args.batch_size, histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                          save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))
    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})
    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """
    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                          height_shift_range=shift_fraction)  # shift up to 2 pixel for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])
    # Training with data augmentation. If shift_fraction=0., also no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#
    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)
    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)
    return model
def test(model, data, args):
    x_test, y_test = data
    y_pred, x_recon = model.predict(x_test, batch_size=100)
    print('-'*30 + 'Begin: test' + '-'*30)
    print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])
    img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
    image = img * 255
    Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png")
    print()
    print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir)
    print('-' * 30 + 'End: test' + '-' * 30)
    plt.imshow(plt.imread(args.save_dir + "/real_and_recon.png"))
    plt.show()
def manipulate_latent(model, data, args):
    print('-'*30 + 'Begin: manipulate' + '-'*30)
    x_test, y_test = data
    index = np.argmax(y_test, 1) == args.digit
    number = np.random.randint(low=0, high=sum(index) - 1)
    x, y = x_test[index][number], y_test[index][number]
    x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
    noise = np.zeros([1, 10, 16])
    x_recons = []
    for dim in range(16):
        for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
            tmp = np.copy(noise)
            tmp[:,:,dim] = r
            x_recon = model.predict([x, y, tmp])
            x_recons.append(x_recon)
    x_recons = np.concatenate(x_recons)
    img = combine_images(x_recons, height=16)
    image = img*255
    Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)
    print('manipulated result saved to %s/manipulate-%d.png' % (args.save_dir, args.digit))
    print('-' * 30 + 'End: manipulate' + '-' * 30)
def load_mnist():
    # the data, shuffled and split between train and test sets
    dataset = read_csv('ZerosOnePowerQuality.csv')
    values = dataset.values
    XY= values
    num_classes = 8
    Y = XY[:,784]
    n_train_hours1 =9000
    x_train=XY[:n_train_hours1,0:784]
    trainY =Y[:n_train_hours1]
    x_test =XY[n_train_hours1:, 0:784]
    testY =Y[n_train_hours1:]
    x_train = x_train.reshape(-1,28,28,1)
    x_test = x_test.reshape(-1,28,28,1)
    y_train = keras.utils.to_categorical(trainY, num_classes)
    y_test = keras.utils.to_categorical(testY, num_classes)
    return (x_train, y_train), (x_test, y_test)
if __name__ == "__main__":
    import os
    import argparse
    from keras.preprocessing.image import ImageDataGenerator
    from keras import callbacks
    # setting the hyper parameters
    parser = argparse.ArgumentParser(description="Capsule Network on MNIST.")
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--batch_size', default=100, type=int)
    parser.add_argument('--lr', default=0.001, type=float,
                        help="Initial learning rate")
    parser.add_argument('--lr_decay', default=0.9, type=float,
                        help="The value multiplied by lr at each epoch. Set a larger value for larger epochs")
    parser.add_argument('--lam_recon', default=0.392, type=float,
                        help="The coefficient for the loss of decoder")
    parser.add_argument('-r', '--routings', default=3, type=int,
                        help="Number of iterations used in routing algorithm. should > 0")
    parser.add_argument('--shift_fraction', default=0.1, type=float,
                        help="Fraction of pixels to shift at most in each direction.")
    parser.add_argument('--debug', action='store_true',
                        help="Save weights by TensorBoard")
    parser.add_argument('--save_dir', default='./result')
    parser.add_argument('-t', '--testing', action='store_true',
                        help="Test the trained model on testing dataset")
    parser.add_argument('--digit', default=5, type=int,
                        help="Digit to manipulate")
    parser.add_argument('-w', '--weights', default=None,
                        help="The path of the saved weights. Should be specified when testing")
    args = parser.parse_args()
    print(args)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    # load data
    (x_train, y_train), (x_test, y_test) = load_mnist()
    # define model
    model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
                                                  n_class=len(np.unique(np.argmax(y_train, 1))),
                                                  routings=args.routings)
    model.summary()
    # train or test
    if args.weights is not None:  # init the model weights with provided one
        model.load_weights(args.weights)
    if not args.testing:
        train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)
    else:  # as long as weights are given, will run testing
        if args.weights is None:
            print('No weights are provided. Will test using random initialized weights.')
        manipulate_latent(manipulate_model, (x_test, y_test), args)
        test(model=eval_model, data=(x_test, y_test), args=args)
2.网络结构
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0                                           
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 20, 20, 256)  20992      input_1[0][0]                   
__________________________________________________________________________________________________
primarycap_conv2d (Conv2D)      (None, 6, 6, 256)    5308672    conv1[0][0]                     
__________________________________________________________________________________________________
primarycap_reshape (Reshape)    (None, 1152, 8)      0          primarycap_conv2d[0][0]         
__________________________________________________________________________________________________
primarycap_squash (Lambda)      (None, 1152, 8)      0          primarycap_reshape[0][0]       
__________________________________________________________________________________________________
digitcaps (CapsuleLayer)        (None, 8, 16)        1179648    primarycap_squash[0][0]         
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 8)            0                                           
__________________________________________________________________________________________________
mask_1 (Mask)                  (None, 128)          0          digitcaps[0][0]                 
                                                                input_2[0][0]                   
__________________________________________________________________________________________________
capsnet (Length)                (None, 8)            0          digitcaps[0][0]                 
__________________________________________________________________________________________________
decoder (Sequential)            (None, 28, 28, 1)    1394960    mask_1[0][0]                   
==================================================================================================
Total params: 7,904,272
Trainable params: 7,904,272
Non-trainable params: 0
____________________________
此代码实在keras官方教程下修改而成:
https://blog.csdn.net/wyx100/article/details/80724501
用CapsNets做电能质量扰动分类(2019-08-05)的更多相关文章
- 单向LSTM笔记, LSTM做minist数据集分类
		
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
 - 做一个logitic分类之鸢尾花数据集的分类
		
做一个logitic分类之鸢尾花数据集的分类 Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例.数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都 ...
 - MyBatis 配置/注解 SQL CRUD 经典解决方案(2019.08.15持续更新)
		
本文旨在记录使用各位大神的经典解决方案. 2019.08.14 更新 Mybatis saveOrUpdate SelectKey非主键的使用 MyBatis实现SaveOrUpdate mybati ...
 - http://tedhacker.top/2016/08/05/Spring%E7%BA%BF%E7%A8%8B%E6%B1%A0%E4%BD%BF%E7%94%A8%E6%96%B9%E6%B3%95/
		
http://tedhacker.top/2016/08/05/Spring%E7%BA%BF%E7%A8%8B%E6%B1%A0%E4%BD%BF%E7%94%A8%E6%96%B9%E6%B3%9 ...
 - 新手C#类、对象、字段、方法的学习2018.08.05
		
类:具有相似属性和方法的对象的集合,如“人”是个类. 对象(实例):对象是具体的看得见摸得着的,如“张三”是“人”这个类的对象.(new Person()开辟了堆空间中,=开辟了栈空间,变量P存放在该 ...
 - 【2019年05月20日】A股滚动市盈率PE历史新低排名
		
2010年01月01日 到 2019年05月20日 之间,滚动市盈率历史新低排名. 上市三年以上的公司, 2019年05月20日市盈率在300以下的公司. 1 - 阳光照明(SH600261) - 历 ...
 - 2019.07.05 纪中_B
		
今日膜拜:czj大佬orz%%% 2019.07.05[NOIP提高组]模拟 B 组 今天做题的时候大概能判断出题人的考点,可是就是没学过...特别痛苦 T0:栈的定义,模拟就好了T1:感觉像是找规律 ...
 - http://www.blogjava.net/xylz/archive/2013/08/05/402405.html
		
http://www.blogjava.net/xylz/archive/2013/08/05/402405.html
 - 新手C#s.Split(),s.Substring(,)以及读取txt文件中的字符串的学习2018.08.05
		
s.split()用于字符串分割,具有多种重载方法,可以通过指定字符或字符串分割原字符串成为字符串数组. //s.Split()用于分割字符串为字符串数组,StringSplitOptions.Rem ...
 
随机推荐
- Codeforces Educational Codeforces Round 54 题解
			
题目链接:https://codeforc.es/contest/1076 A. Minimizing the String 题意:给出一个字符串,最多删掉一个字母,输出操作后字典序最小的字符串. 题 ...
 - Python 15__屏幕抓取
 - 1~100卡特兰数(存一下hhhh)
			
1 2 5 14 42 132 429 1430 4862 16796 58786 208012 742900 2674440 9694845 35357670 129644790 477638700 ...
 - DRF-视图类组件
			
补充 GET books-------->查看数据--------------------> 返回所有数据列表 :[{},{},{}] POST books-------->添加数 ...
 - RNN(二)——基于tensorflow的LSTM的实现
			
lstm的前向结构,不迭代 最基本的lstm结构.不涉及损失值和bp过程 import tensorflow as tf import numpy as np inputs = tf.placehol ...
 - 一图了解DLL和SYS的区别
			
The following diagram shows the device node, kernel-mode device stack, and the user-mode device stac ...
 - IDEA算法导包后 import javax.crypto.Cipher; import javax.crypto.KeyGenerator; import javax.crypto.SecretKey;报错
			
仔细查看报错原因就能知道,报错是因为包冲突的原因,可以每种只放一个jar包,就能过避免这种错误. 例如:只导入commons-codec-1.11-javadoc,jar和bcprov-jdk15on ...
 - CRT小键盘输入乱码
			
Options --> Session Options --> Terminal --> Emulation --> Modes 去选中 Enable keypad mode ...
 - vagrant系列三:vagrant搭建的php7环境
			
原文:https://blog.csdn.net/hel12he/article/details/51107236 前面已经把vagrant的基础知识已经基本过了一遍 了,相信只要按着教程来,你已经搭 ...
 - MySQL索引选择及添加原则
			
索引选择性就是结果个数与总个数的比值. 用sql语句表示为: SELECT COUNT(*) FROM table_name WHERE column_name/SELECT COUNT(*) FRO ...