# -*- coding=utf-8 -*-
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,Dropout
from keras.optimizers import Adadelta
from keras.datasets import cifar10
from keras import applications

import matplotlib.pyplot as plt
%matplotlib inline

vgg_model=applications.VGG19(include_top=False,weights='imagenet')
vgg_model.summary()

(train_x,train_y),(test_x,test_y)=cifar10.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

n_classes=10
train_y=keras.utils.to_categorical(train_y,n_classes)
test_y=keras.utils.to_categorical(test_y,n_classes)

bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

my_model=Sequential()
my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
my_model.add(Dense(512,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(256,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(n_classes,activation='softmax'))
my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",\
metrics=['accuracy'])
my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
print("loss:",evaluation[0],"accuracy:",evaluation[1])

def predict_label(img_idx,show_proba=True):
plt.imshow(train_x[img_idx],aspect='auto')
plt.title("Image to be labeled")
plt.show()
img_4D=(bottleneck_feature_train[img_idx])[np.newaxis,:,:,:]
prediction=my_model.predict_classes(img_4D,batch_size=1,verbose=0)
print("Actual class:{0}\nPredict class:{1}".format(np.argmax(train_y[img_idx],0),prediction))

if show_proba:
pred=my_model.predict_proba(img_4D,batch_size=1,verbose=0)
print(pred)

for i in range(3):
predict_label(i,show_proba=True)

吴裕雄 python神经网络(8)的更多相关文章

  1. 吴裕雄 python神经网络 花朵图片识别(10)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  2. 吴裕雄 python神经网络 花朵图片识别(9)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  3. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  4. 吴裕雄 python神经网络 水果图片识别(4)

    # coding: utf-8 # In[1]:import osimport numpy as npfrom skimage import color, data, transform, io # ...

  5. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  6. 吴裕雄 python神经网络 水果图片识别(2)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom skimage import color,data,transform,i ...

  7. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  10. 吴裕雄 python 神经网络——TensorFlow实现搭建基础神经网络

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

随机推荐

  1. switch嵌套--猜拳游戏

    <!DOCTYPE html> <html>     <head>         <meta charset="UTF-8">   ...

  2. 前端笔记二:CSS盒模型

    1.标准模型和IE模型 2.标准模型和IE模型的区别 标准模型的height和width只是content的: IE模型的height和width是包含padding和border的 3.CSS如何设 ...

  3. 微信小程序开发踩坑日记

    2017.12.29  踩坑记录 引用图片名称不要使用中文,尽量使用中文命名,IDE中图片显示无异样,手机上图片可能出现不显示的情况. 2018.1.5  踩坑记录 微信小程序设置元素满屏,横向直接w ...

  4. [Unity算法]斜抛运动

    斜抛运动: 1.物体以一定的初速度斜向射出去,物体所做的这类运动叫做斜抛运动. 2.斜抛运动看成是作水平方向的匀速直线运动和竖直方向的竖直上抛运动的合运动. 3.它的运动轨迹是抛物线. Oblique ...

  5. 自定义UITableViewCell上的delete按钮

    http://blog.csdn.net/xiaoxuan415315/article/details/7834940

  6. 8.Appium的基本使用-2(安装node.js)

    node.js 下载地址:https://nodejs.org/en/download/下载 64-bit 下载包下载完成双击安装:

  7. jquery给按钮绑定事件

    JQuery: $(function(){ $("#btn1").bind("click",function(){ $("#div1").s ...

  8. Linux设置时间

    设置时间为2017年5月18号9:55:15 date -s "2017-05-18 09:55:15" 修改完后执行clock -w,把系统时间写入CMOS clock -w

  9. linux常用命令以及快捷键

    find命令查找某些文件并将其拷贝到指定目录 [root@host lib]# find -name "*hbase*.jar" |xargs -i cp {}  /root/aa ...

  10. 36. Oracle查询数据库中所有表的记录数

    select t.table_name,t.num_rows from user_tables t