# -*- coding=utf-8 -*-
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,Dropout
from keras.optimizers import Adadelta
from keras.datasets import cifar10
from keras import applications

import matplotlib.pyplot as plt
%matplotlib inline

vgg_model=applications.VGG19(include_top=False,weights='imagenet')
vgg_model.summary()

(train_x,train_y),(test_x,test_y)=cifar10.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

n_classes=10
train_y=keras.utils.to_categorical(train_y,n_classes)
test_y=keras.utils.to_categorical(test_y,n_classes)

bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

my_model=Sequential()
my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
my_model.add(Dense(512,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(256,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(n_classes,activation='softmax'))
my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",\
metrics=['accuracy'])
my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
print("loss:",evaluation[0],"accuracy:",evaluation[1])

def predict_label(img_idx,show_proba=True):
plt.imshow(train_x[img_idx],aspect='auto')
plt.title("Image to be labeled")
plt.show()
img_4D=(bottleneck_feature_train[img_idx])[np.newaxis,:,:,:]
prediction=my_model.predict_classes(img_4D,batch_size=1,verbose=0)
print("Actual class:{0}\nPredict class:{1}".format(np.argmax(train_y[img_idx],0),prediction))

if show_proba:
pred=my_model.predict_proba(img_4D,batch_size=1,verbose=0)
print(pred)

for i in range(3):
predict_label(i,show_proba=True)

吴裕雄 python神经网络(8)的更多相关文章

  1. 吴裕雄 python神经网络 花朵图片识别(10)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  2. 吴裕雄 python神经网络 花朵图片识别(9)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  3. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  4. 吴裕雄 python神经网络 水果图片识别(4)

    # coding: utf-8 # In[1]:import osimport numpy as npfrom skimage import color, data, transform, io # ...

  5. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  6. 吴裕雄 python神经网络 水果图片识别(2)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom skimage import color,data,transform,i ...

  7. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  10. 吴裕雄 python 神经网络——TensorFlow实现搭建基础神经网络

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

随机推荐

  1. linux命令之vi文本编辑器

    vi filename :打开或新建文件,并将光标置于第一行首 按i,开始输入(insert) d删除整行 u   撤销上一步的操作Ctrl+r 恢复上一步被撤销的操作 ESC退出输入 按ESC键 跳 ...

  2. html中header,footer分别固定在顶部和底部

    1 <!DOCTYPE html> 2 <html> 3 <head> 4 <title>page01</title> 5 <styl ...

  3. spring boot整合quartz实现多个定时任务

        版权声明:本文为博主原创文章,转载请注明出处. https://blog.csdn.net/liuchuanhong1/article/details/78543574 最近收到了很多封邮件, ...

  4. Git上传项目失败:Push rejected: Push to origin/master was rejected

    解决方案如下: 1.切换到自己项目所在的目录,右键选择GIT BASH Here,Idea中可使用Alt+F12 打开终端 2.在terminl窗口中依次输入命令: git pull git pull ...

  5. mysql给查询的结果添加序号

    1.法一: select  (@i:=@i+1)  i,a.url from  base_api_resources a  ,(select   @i:=0)  t2 order by a.id de ...

  6. oracle 多行合并为一行

    sys_connect_by_path select i,ltrim(max(sys_connect_by_path(a,',')),',') afrom(select i,a,d,min(d) ov ...

  7. 制作签名jar放置到前端资源目录下

    给jar包打签名keytool -genkey -keystore myKeystore -alias jwstest查看签名信息jarsigner -keystore myKeystore data ...

  8. 【Social listening实操】用大数据文本挖掘,来洞察“共享单车”的行业现状及走势

    本文转自知乎 作者:苏格兰折耳喵 ----------------------------------------------------- 对于当下共享单车在互联网界的火热状况,笔者想从大数据文本挖 ...

  9. 33.纯 CSS 创作牛奶文字变换效果

    原文地址:https://segmentfault.com/a/1190000015037234 感想:transform: translateY(50% & -50%);  animatio ...

  10. 转载:Bootstrap 源码解析

    Bootstrap 源码解析 前言 Bootstrap 是个CSS库,简单,高效.很多都可以忘记了再去网站查.但是有一些核心的东西需要弄懂.个人认为弄懂了这些应该就算是会了.源码看一波. 栅格系统 所 ...