# -*- coding=utf-8 -*-
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense,Flatten,Dropout
from keras.optimizers import Adadelta
from keras.datasets import cifar10
from keras import applications

import matplotlib.pyplot as plt
%matplotlib inline

vgg_model=applications.VGG19(include_top=False,weights='imagenet')
vgg_model.summary()

(train_x,train_y),(test_x,test_y)=cifar10.load_data()
print(train_x.shape,train_y.shape,test_x.shape,test_y.shape)

n_classes=10
train_y=keras.utils.to_categorical(train_y,n_classes)
test_y=keras.utils.to_categorical(test_y,n_classes)

bottleneck_feature_train=vgg_model.predict(train_x,verbose=1)
bottleneck_feature_test=vgg_model.predict(test_x,verbose=1)

print(bottleneck_feature_train.shape,bottleneck_feature_test.shape)

my_model=Sequential()
my_model.add(Flatten())###my_model.add(Flatten(input_shape=?))
my_model.add(Dense(512,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(256,activation='relu'))
my_model.add(Dropout(0.5))
my_model.add(Dense(n_classes,activation='softmax'))
my_model.compile(optimizer=Adadelta(),loss="categorical_crossentropy",\
metrics=['accuracy'])
my_model.fit(bottleneck_feature_train,train_y,batch_size=128,epochs=50,verbose=1)

evaluation=my_model.evaluate(bottleneck_feature_test,test_y,batch_size=128,verbose=0)
print("loss:",evaluation[0],"accuracy:",evaluation[1])

def predict_label(img_idx,show_proba=True):
plt.imshow(train_x[img_idx],aspect='auto')
plt.title("Image to be labeled")
plt.show()
img_4D=(bottleneck_feature_train[img_idx])[np.newaxis,:,:,:]
prediction=my_model.predict_classes(img_4D,batch_size=1,verbose=0)
print("Actual class:{0}\nPredict class:{1}".format(np.argmax(train_y[img_idx],0),prediction))

if show_proba:
pred=my_model.predict_proba(img_4D,batch_size=1,verbose=0)
print(pred)

for i in range(3):
predict_label(i,show_proba=True)

吴裕雄 python神经网络(8)的更多相关文章

  1. 吴裕雄 python神经网络 花朵图片识别(10)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  2. 吴裕雄 python神经网络 花朵图片识别(9)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom PIL import Image, ImageChopsfrom skim ...

  3. 吴裕雄 python神经网络 手写数字图片识别(5)

    import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...

  4. 吴裕雄 python神经网络 水果图片识别(4)

    # coding: utf-8 # In[1]:import osimport numpy as npfrom skimage import color, data, transform, io # ...

  5. 吴裕雄 python神经网络 水果图片识别(3)

    import osimport kerasimport timeimport numpy as npimport tensorflow as tffrom random import shufflef ...

  6. 吴裕雄 python神经网络 水果图片识别(2)

    import osimport numpy as npimport matplotlib.pyplot as pltfrom skimage import color,data,transform,i ...

  7. 吴裕雄 python 神经网络——TensorFlow 循环神经网络处理MNIST手写数字数据集

    #加载TF并导入数据集 import tensorflow as tf from tensorflow.contrib import rnn from tensorflow.examples.tuto ...

  8. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  9. 吴裕雄 python 神经网络——TensorFlow 训练过程的可视化 TensorBoard的应用

    #训练过程的可视化 ,TensorBoard的应用 #导入模块并下载数据集 import tensorflow as tf from tensorflow.examples.tutorials.mni ...

  10. 吴裕雄 python 神经网络——TensorFlow实现搭建基础神经网络

    import numpy as np import tensorflow as tf import matplotlib.pyplot as plt def add_layer(inputs, in_ ...

随机推荐

  1. ES6学习笔记<二>arrow functions 箭头函数、template string、destructuring

    接着上一篇的说. arrow functions 箭头函数 => 更便捷的函数声明 document.getElementById("click_1").onclick = ...

  2. js: 字符集

    用js生成字符集 一般网页制作中需要一些向上向下的小箭头,用图片非常不合算(一个页面发起多个http请求.css.文件大小等方面考虑) 所以用一些字符集的字符图形,效果很好 下面是用js生成字符集,以 ...

  3. spring 之 factory-bean & factory-method

    这两者常常是一起出现的,或者说他们经常是一起被使用的.但是其实是分为了两种情况: 1 同时使用factory-bean 和 factory-method 如果,我们在一个bean 元素上同时配置 fa ...

  4. Mybatis学习5标签:if,where,sql,foreach

    包装类:QueryVO.java package pojo; import java.util.ArrayList; import java.util.List; public class Query ...

  5. 1. @ModelAttribute注解

    添加@ModelAttribute修饰的方法,在每个目标方法调用前都会执行该方法. 一般情况下,在form表单修改的时,某项字段规定为不可更改,就需要使用该注解标注的方法,根据id的获取与否,来从数据 ...

  6. journalctl

    systemd 提供了自己的日志系统(logging system),称为 journal.使用 systemd 日志,无需额外安装日志服务(syslog).读取日志的命令: # journalctl ...

  7. 使用expect解决shell交互问题

    比如ssh的时候,如果没设置免密登陆,那么就需要输入密码.使用expect可以做成自动应答 1.expect检测和安装 sudo apt-get install tcl tk expect 2.脚本样 ...

  8. centos LVM详解

    title: centos LVM详解 date: 2018-04-24 14:00:03 tags: [linux,centos,LVM] --- 知识了解 LVM关系图 fdisk命令详解 [ro ...

  9. smarty获取php中的变量

    {$smarty}保留变量不需要从PHP脚本中分配,是可以在模板中直接访问的数组类型变量,通常被用于访问一些特殊的模板变量.例如,直接在模板中访问页面请求变量.获取访问模板时的时间邮戳.直接访问PHP ...

  10. EmEditor

    姓 名:ttrar.com 序 列 号:DKAZQ-R9TYP-5SM2A-9Z8KD-3E2RK 免费版地址:https://zh-cn.emeditor.com/#download