Tensorflow2.0实现VGG13
导入必要的库:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
tf.random.set_seed(2345)
其中os.environ部分是为了减少Tensorflow打印的信息
构建网络结构:
conv_layers=[
layers.Conv2D(64,kernel_size=[3,3],padding="same",activation=tf.nn.relu),
layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2,2],strides=2,padding="same"),
layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
]
优化器:
def preprocess(x,y):
x=tf.cast(x,dtype=tf.float32)/255.
y=tf.cast(y,dtype=tf.int32)
return x,y
加载数据:
这里使用比较常见的CIFAR10的数据集
(x_train,y_train),(x_test,y_test)=datasets.cifar10.load_data()
y_train=tf.squeeze(y_train,axis=1)
y_test=tf.squeeze(y_test,axis=1)
# print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
train_data=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_data=train_data.shuffle(1000).map(preprocess).batch(64)
test_data=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_data=test_data.map(preprocess).batch(64)
sample=next(iter(train_data))
print('sample:',sample[0].shape,sample[1].shape,
tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))
sample=next(iter(train_data))
这一部分是打印train_data的信息
完善网络:
def main():
conv_net=Sequential(conv_layers)
# x=tf.random.normal([4,32,32,3])
# out=conv_net(x)
# print(out.shape)
fc_net=Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10,activation=None),
])
conv_net.build(input_shape=[None, 32, 32, 3])
fc_net.build(input_shape=[None,512])
optimizer=optimizers.Adam(lr=1e-4)
计算loss:
variables=conv_net.trainable_variables+fc_net.trainable_variables
for epoch in range(50):
for step,(x,y) in enumerate(train_data):
with tf.GradientTape() as tape:
out=conv_net(x)
out=tf.reshape(out,[-1,512])
logits=fc_net(out)
y_onehot=tf.one_hot(y,depth=10)
loss=tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss=tf.reduce_mean(loss)
grads=tape.gradient(loss,variables)
optimizer.apply_gradients(zip(grads,variables))
if step%100==0:
print(epoch,step,'loss',float(loss))
测试:
total_num=0
total_correct=0
for x,y in test_data:
out=conv_net(x)
out=tf.reshape(out,[-1,512])
logits=fc_net(out)
prob=tf.nn.softmax(logits,axis=1)
pred=tf.argmax(prob,axis=1)
pred=tf.cast(pred,dtype=tf.int32)
correct=tf.cast(tf.equal(pred,y),dtype=tf.int32)
correct=tf.reduce_sum(correct)
total_num+=x.shape[0]
total_correct+=int(correct)
acc=total_correct/total_num
print(epoch,'acc:',acc)
if __name__ == '__main__':
main()
训练数据:
0 0 loss 2.302990436553955
0 100 loss 1.9521405696868896
0 200 loss 1.9435423612594604
0 300 loss 1.6067744493484497
0 400 loss 1.5959546566009521
0 500 loss 1.734712839126587
0 600 loss 1.2384529113769531
0 700 loss 1.3307044506072998
0 acc: 0.4787
5 0 loss 0.6936513185501099
5 100 loss 0.7874761819839478
5 200 loss 0.7884306907653809
5 300 loss 0.6663026809692383
5 400 loss 0.4075947105884552
5 500 loss 0.6752095222473145
5 600 loss 0.5246847867965698
5 700 loss 0.5275574922561646
5 acc: 0.7299
10 0 loss 0.7874808311462402
10 100 loss 0.5072851181030273
10 200 loss 0.4451877772808075
10 300 loss 0.177499920129776
10 400 loss 0.13723205029964447
10 500 loss 0.2971668243408203
10 600 loss 0.25279730558395386
10 700 loss 0.36453887820243835
10 acc: 0.7355
15 0 loss 0.2800075113773346
15 100 loss 0.1841358095407486
15 200 loss 0.040746696293354034
15 300 loss 0.06615383923053741
15 400 loss 0.1183178648352623
15 500 loss 0.07481158524751663
15 600 loss 0.09398414194583893
15 700 loss 0.03665520250797272
15 acc: 0.7469
20 0 loss 0.02290465496480465
20 100 loss 0.008633529767394066
20 200 loss 0.21534058451652527
20 300 loss 0.011568240821361542
20 400 loss 0.08179830759763718
20 500 loss 0.02673691138625145
20 600 loss 0.06506452709436417
20 700 loss 0.026200752705335617
20 acc: 0.7621
训练大概50epoch,这里仅仅展示20个,可以看到,验证准确率是在不断的上升的,后面的数据就不展示了,我也没训练完,有兴趣的可以接着跑将模型保存一下,有时间再接着训练
Tensorflow2.0实现VGG13的更多相关文章
- 基于tensorflow2.0 使用tf.keras实现Fashion MNIST
本次使用的是2.0测试版,正式版估计会很快就上线了 tf2好像更新了蛮多东西 虽然教程不多 还是找了个试试 的确简单不少,但是还是比较喜欢现在这种写法 老样子先导入库 import tensorflo ...
- Google工程师亲授 Tensorflow2.0-入门到进阶
第1章 Tensorfow简介与环境搭建 本门课程的入门章节,简要介绍了tensorflow是什么,详细介绍了Tensorflow历史版本变迁以及tensorflow的架构和强大特性.并在Tensor ...
- TensorFlow2.0(1):基本数据结构—张量
1 引言 TensorFlow2.0版本已经发布,虽然不是正式版,但预览版都发布了,正式版还会远吗?相比于1.X,2.0版的TensorFlow修改的不是一点半点,这些修改极大的弥补了1.X版本的反人 ...
- 『TensorFlow2.0正式版教程』极简安装TF2.0正式版(CPU&GPU)教程
0 前言 TensorFlow 2.0,今天凌晨,正式放出了2.0版本. 不少网友表示,TensorFlow 2.0比PyTorch更好用,已经准备全面转向这个新升级的深度学习框架了. 本篇文章就 ...
- 『TensorFlow2.0正式版』TF2.0+Keras速成教程·零:开篇简介与环境准备
此篇教程参考自TensorFlow 2.0 + Keras Crash Course,在原文的基础上进行了适当的总结与改编,以适应于国内开发者的理解与使用,水平有限,如果写的不对的地方欢迎大家评论指出 ...
- TensorFlow2.0(9):TensorBoard可视化
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- TensorFlow2.0(11):tf.keras建模三部曲
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- tensorflow2.0安装
版本: python3.5 Anaconda 4.2.0 tensorflow2.0 cpu版本 1.安装命令 pip3 install tensorflow==2.0.0.0a0 -i https: ...
- TensorFlow2.0初体验
TF2.0默认为动态图,即eager模式.意味着TF能像Pytorch一样不用在session中才能输出中间参数值了,那么动态图和静态图毕竟是有区别的,tf2.0也会有写法上的变化.不过值得吐槽的是, ...
- tensorflow2.0 学习(三)
用tensorflow2.0 版回顾了一下mnist的学习 代码如下,感觉这个版本下的mnist学习更简洁,更方便 关于tensorflow的基础知识,这里就不更新了,用到什么就到网上取搜索相关的知识 ...
随机推荐
- Eclipse修改Web项目名称
Eclipse修改Web项目名称需要两步: 1:修改该项目目录下:.project文件 <projectDescription><name>SpringMVC-Annotati ...
- Domain Admin域名和SSL证书过期监控到期提醒
基于Python3 + Vue3.js 技术栈实现的域名和SSL证书监测平台 用于解决,不同业务域名SSL证书,申请自不同的平台,到期后不能及时收到通知,导致线上访问异常,被老板责骂的问题 核心功能: ...
- Solution -「洛谷 P5659」「CSP-S 2019」树上的数
Description Link. 联赛原题应该都读过吧-- Solution Part 0 大致思路 主要的思路就是逐个打破,研究特殊的数据得到普通的结论. Part 1 暴力的部分分 暴力的部分分 ...
- Ds100p -「数据结构百题」31~40
31.P2163 [SHOI2007]园丁的烦恼] 很久很久以前,在遥远的大陆上有一个美丽的国家.统治着这个美丽国家的国王是一个园艺爱好者,在他的皇家花园里种植着各种奇花异草. 有一天国王漫步在花园里 ...
- Go 语言开发环境搭建
Go 语言开发环境搭建 目录 Go 语言开发环境搭建 一. GO 环境安装 1.1 下载 1.2 Go 版本的选择 1.3 安装 1.3.1 Windows安装 1.3.2 Linux下安装 1.3. ...
- 快速启动Stable Diffusion WebUI
快速启动Stable Diffusion WebUI详情 产品文档 输入文档关键字查找 机器学习PAI 产品概述 快速入门 操作指南 准备工作 开通PAI并创建默认工作空间 开通并授权依 ...
- Sunshine on my shoulders
https://music.163.com/#/song?id=1477706 Sunshine on my shoulders makes me happy照在我肩上的阳光让我欢乐Sunshine ...
- MySQL的index merge(索引合并)导致数据库死锁分析与解决方案
背景 在DBS-集群列表-更多-连接查询-死锁中,看到9月22日有数据库死锁日志,后排查发现是因为mysql的优化-index merge(索引合并)导致数据库死锁. 定义 index merge(索 ...
- SP3377
题目简化和分析: 前言:这题目背景真奇怪. 我们可以将每种关系,看成一条边,如果出现奇数边环就不满足. 例如:\(a,b\) 异性 \(a,c\) 异性 \(b,c\)异性 这种情况是不满足的. 相当 ...
- 从链接器的角度详细分析g++报错: (.text+0x24): undefined reference to `main'
/usr/bin/ld: /usr/lib/gcc/x86_64-linux-gnu/9/../../../x86_64-linux-gnu/Scrt1.o: in function `_start' ...