导入必要的库:

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
tf.random.set_seed(2345)

其中os.environ部分是为了减少Tensorflow打印的信息

构建网络结构:

conv_layers=[
layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2,2],strides=2,padding="same"), layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"), layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"), layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"), layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
]

优化器:

def preprocess(x,y):
x=tf.cast(x,dtype=tf.float32)/255.
y=tf.cast(y,dtype=tf.int32)
return x,y

加载数据:

这里使用比较常见的CIFAR10的数据集

(x_train,y_train),(x_test,y_test)=datasets.cifar10.load_data()
y_train=tf.squeeze(y_train,axis=1)
y_test=tf.squeeze(y_test,axis=1)
# print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
train_data=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_data=train_data.shuffle(1000).map(preprocess).batch(64) test_data=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_data=test_data.map(preprocess).batch(64) sample=next(iter(train_data))
print('sample:',sample[0].shape,sample[1].shape,
tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))

sample=next(iter(train_data))

这一部分是打印train_data的信息

完善网络:

def main():
conv_net=Sequential(conv_layers)
# x=tf.random.normal([4,32,32,3])
# out=conv_net(x)
# print(out.shape)
fc_net=Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10,activation=None),
])
conv_net.build(input_shape=[None, 32, 32, 3])
fc_net.build(input_shape=[None,512])
optimizer=optimizers.Adam(lr=1e-4)

计算loss:

variables=conv_net.trainable_variables+fc_net.trainable_variables
for epoch in range(50):
for step,(x,y) in enumerate(train_data):
with tf.GradientTape() as tape:
out=conv_net(x)
out=tf.reshape(out,[-1,512])
logits=fc_net(out)
y_onehot=tf.one_hot(y,depth=10)
loss=tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss=tf.reduce_mean(loss)
grads=tape.gradient(loss,variables)
optimizer.apply_gradients(zip(grads,variables))
if step%100==0:
print(epoch,step,'loss',float(loss))

测试:

total_num=0
total_correct=0
for x,y in test_data:
out=conv_net(x)
out=tf.reshape(out,[-1,512])
logits=fc_net(out)
prob=tf.nn.softmax(logits,axis=1)
pred=tf.argmax(prob,axis=1)
pred=tf.cast(pred,dtype=tf.int32)
correct=tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct=tf.reduce_sum(correct)
total_num+=x.shape[0]
total_correct+=int(correct)
acc=total_correct/total_num
print(epoch,'acc:',acc)
if __name__ == '__main__':
main()

训练数据:

0 0 loss 2.302990436553955
0 100 loss 1.9521405696868896
0 200 loss 1.9435423612594604
0 300 loss 1.6067744493484497
0 400 loss 1.5959546566009521
0 500 loss 1.734712839126587
0 600 loss 1.2384529113769531
0 700 loss 1.3307044506072998
0 acc: 0.4787
5 0 loss 0.6936513185501099
5 100 loss 0.7874761819839478
5 200 loss 0.7884306907653809
5 300 loss 0.6663026809692383
5 400 loss 0.4075947105884552
5 500 loss 0.6752095222473145
5 600 loss 0.5246847867965698
5 700 loss 0.5275574922561646
5 acc: 0.7299
10 0 loss 0.7874808311462402
10 100 loss 0.5072851181030273
10 200 loss 0.4451877772808075
10 300 loss 0.177499920129776
10 400 loss 0.13723205029964447
10 500 loss 0.2971668243408203
10 600 loss 0.25279730558395386
10 700 loss 0.36453887820243835
10 acc: 0.7355
15 0 loss 0.2800075113773346
15 100 loss 0.1841358095407486
15 200 loss 0.040746696293354034
15 300 loss 0.06615383923053741
15 400 loss 0.1183178648352623
15 500 loss 0.07481158524751663
15 600 loss 0.09398414194583893
15 700 loss 0.03665520250797272
15 acc: 0.7469
20 0 loss 0.02290465496480465
20 100 loss 0.008633529767394066
20 200 loss 0.21534058451652527
20 300 loss 0.011568240821361542
20 400 loss 0.08179830759763718
20 500 loss 0.02673691138625145
20 600 loss 0.06506452709436417
20 700 loss 0.026200752705335617
20 acc: 0.7621

训练大概50epoch,这里仅仅展示20个,可以看到,验证准确率是在不断的上升的,后面的数据就不展示了,我也没训练完,有兴趣的可以接着跑将模型保存一下,有时间再接着训练

Tensorflow2.0实现VGG13的更多相关文章

  1. 基于tensorflow2.0 使用tf.keras实现Fashion MNIST

    本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...

  2. Google工程师亲授 Tensorflow2.0-入门到进阶

    第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...

  3. TensorFlow2.0(1):基本数据结构—张量

    1 引言 TensorFlow2.0版本已经发布,虽然不是正式版,但预览版都发布了,正式版还会远吗?相比于1.X,2.0版的TensorFlow修改的不是一点半点,这些修改极大的弥补了1.X版本的反人 ...

  4. 『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程

    0 前言 TensorFlow 2.0,今天凌晨,正式放出了2.0版本. 不少网友表示,TensorFlow 2.0比PyTorch更好用,已经准备全面转向这个新升级的深度学习框架了. ​ 本篇文章就 ...

  5. 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备

    此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...

  6. TensorFlow2.0(9):TensorBoard可视化

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  7. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  8. tensorflow2.0安装

    版本: python3.5 Anaconda 4.2.0 tensorflow2.0 cpu版本 1.安装命令 pip3 install tensorflow==2.0.0.0a0 -i https: ...

  9. TensorFlow2.0初体验

    TF2.0默认为动态图,即eager模式.意味着TF能像Pytorch一样不用在session中才能输出中间参数值了,那么动态图和静态图毕竟是有区别的,tf2.0也会有写法上的变化.不过值得吐槽的是, ...

  10. tensorflow2.0 学习(三)

    用tensorflow2.0 版回顾了一下mnist的学习 代码如下,感觉这个版本下的mnist学习更简洁,更方便 关于tensorflow的基础知识,这里就不更新了,用到什么就到网上取搜索相关的知识 ...

随机推荐

  1. 《深入理解Java虚拟机》读书笔记:基于栈的字节码解释执行引擎

      虚拟机是如何调用方法的内容已经讲解完毕,从本节开始,我们来探讨虚拟机是如何执行方法中的字节码指令的.上文中提到过,许多Java虚拟机的执行引擎在执行Java代码的时候都有解释执行(通过解释器执行) ...

  2. 两种方式,轻松实现ChatGPT联网

    两种方式效果: 方式一:浏览器搜索内嵌插件 方式二:官方聊天页内嵌插件 首先,要有一个谷歌浏览器,然后再安装一个叫ChatGPT for Google,直接在谷歌里搜一下就能找,也可以Chrome应用 ...

  3. 论文解读(TAMEPT)《A Two-Stage Framework with Self-Supervised Distillation For Cross-Domain Text Classification》

    论文信息 论文标题:A Two-Stage Framework with Self-Supervised Distillation For Cross-Domain Text Classificati ...

  4. python入门基础(13)--类、对象

    面向过程的编程语言,如C语言,所使用的数据和函数之间是没有任何直接联系的,它们之间是通过函数调用提供参数的形式将数据传入函数进行处理. 但可能因为错误的传递参数.错误地修改了数据而导致程序出错,甚至是 ...

  5. 「loj - 3022」「cqoi 2017」老 C 的方块

    link. good题,考虑像 国家集训队 - happiness 一样在棋盘上搞染色,我毛张 @shadowice1987 的图给你看啊 你像这样奇数层以 red -> blue -> ...

  6. Redis持久化 (RDB和AOF) 梳理

    Redis有两种持久化方案: RDB持久化 AOF持久化 RDB持久化 RDB全称Redis Database Backup file(Redis数据备份文件),也被叫做Redis数据快照.简单来说就 ...

  7. destoon上做纯js实现html指定页面导出word

    因为最近做了范文网站需要,所以要下载为word文档,如果php进行处理,很吃后台服务器,所以想用前端进行实现.查询github发现,确实有这方面的插件. js导出word文档所需要的两个插件: 1 2 ...

  8. Go 常用标准库之 fmt 介绍与基本使用

    Go 常用标准库之 fmt 介绍与基本使用 目录 Go 常用标准库之 fmt 介绍与基本使用 一.介绍 二.向外输出 2.1 Print 系列 2.2 Fprint 系列 2.3 Sprint 系列 ...

  9. Kafka 在分布式系统中的 7 大应用场景

    Kafka 介绍 Kafka 是一个开源的分布式流式平台,它可以处理大量的实时数据,并提供高吞吐量,低延迟,高可靠性和高可扩展性.Kafka 的核心组件包括生产者(Producer),消费者(Cons ...

  10. 长程 Transformer 模型

    Tay 等人的 Efficient Transformers taxonomy from Efficient Transformers: a Survey 论文 本文由 Teven Le Scao.P ...