TorchLens:可用于可视化任何PyTorch模型,一个包用于在一行代码中提取和映射PyTorch模型中每个张量运算的结果。TorchLens功能非常强大,如果能够熟练掌握,算是可视化PyTorch模型的一把利剑。本文通过TorchLens可视化一个简单神经网络,算是抛砖引玉吧。

一.定义一个简单神经网络

import torch
import torch.nn as nn
import torch.optim as optim
import torchlens as tl
import os
os.environ["PATH"] += os.pathsep + 'D:/Program Files/Graphviz/bin/' # 定义神经网络类
class NeuralNetwork(nn.Module): # 继承nn.Module类
def __init__(self, input_size, hidden_size, output_size):
super(NeuralNetwork, self).__init__() # 调用父类的构造函数
# 定义输入层到隐藏层的线性变换
self.input_to_hidden = nn.Linear(input_size, hidden_size)
# 定义隐藏层到输出层的线性变换
self.hidden_to_output = nn.Linear(hidden_size, output_size)
# 定义激活函数
self.sigmoid = nn.Sigmoid() def forward(self, x):
# 前向传播
hidden = self.sigmoid(self.input_to_hidden(x))
output = self.sigmoid(self.hidden_to_output(hidden))
return output def NeuralNetwork_train(model):
# 训练神经网络
for epoch in range(10000):
optimizer.zero_grad() # 清零梯度
outputs = model(input_data) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播和优化
optimizer.step() # 更新参数 # 每100个epoch打印一次损失
if (epoch + 1) % 1000 == 0:
print(f'Epoch [{epoch + 1}/10000], Loss: {loss.item():.4f}') return model def NeuralNetwork_test(model):
# 在训练后,可以使用模型进行预测
with torch.no_grad():
test_input = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
predictions = model(test_input)
predicted_labels = (predictions > 0.5).float()
print("Predictions:", predicted_labels) if __name__ == '__main__':
# 定义神经网络的参数
input_size = 2 # 输入特征数量
hidden_size = 4 # 隐藏层神经元数量
output_size = 1 # 输出层神经元数量 # 创建神经网络实例
model = NeuralNetwork(input_size, hidden_size, output_size) # 定义损失函数和优化器
criterion = nn.BCELoss() # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.1) # 随机梯度下降优化器 # 准备示例输入数据和标签
input_data = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
labels = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32) # model:神经网络模型
# input_data:输入数据
# layers_to_save:需要保存的层
# vis_opt:rolled/unrolled,是否展开循环
model_history = tl.log_forward_pass(model, input_data, layers_to_save='all', vis_opt='unrolled') # 可视化神经网络
print(model_history)
# print(model_history['input_1'].tensor_contents)
# print(model_history['input_1']) tl.show_model_graph(model, input_data) # model = NeuralNetwork_train(model) # 训练神经网络
# NeuralNetwork_test(model) # 测试神经网络

1.神经网络结构

  输入层包括2个神经元,隐藏层包括4个神经元,输出层包括1个神经元。

2.log_forward_pass

  给定输入x,通过模型运行前向传播,并返回一个包含前向传播日志(层激活和相应的层元数据)的ModelHistory对象。如果vis_opt设置为rolled或unrolled并可视化模型图。

3.show_model_graph

  可视化模型图,而不保存任何激活。

4.查看神经网络模型参数

权重(12)+偏置(5)共计17个参数,如下所示:

二.输出结果分析

1.model_history输出结果

Log of NeuralNetwork forward pass: // 神经网络前向传播日志
Random seed: 1626722175 // 随机种子
Time elapsed: 1.742s (1.74s spent logging) // 耗时
Structure: // 结构
- purely feedforward, no recurrence // 纯前馈,无循环
- no branching // 无分支
- no conditional (if-then) branching // 无条件(if-then)分支
- 3 total modules // 3个模块
Tensor info: // 张量信息
- 6 total tensors (976 B) computed in forward pass. // 前向传播中计算的6个张量(976 B)
- 6 tensors (976 B) with saved activations. // 6个张量(976 B)保存了激活
Parameters: 2 parameter operations (17 params total; 548 B) // 参数:2个参数操作(总共17个参数;548 B)
Module Hierarchy: // 模块层次
input_to_hidden // 输入到隐藏
sigmoid:1 // sigmoid:1
hidden_to_output // 隐藏到输出
sigmoid:2 // sigmoid:2
Layers (all have saved activations): // 层(所有层都有保存的激活)
(0) input_1 // 输入
(1) linear_1_1 // 线性
(2) sigmoid_1_2 // sigmoid
(3) linear_2_3 // 线性
(4) sigmoid_2_4 // sigmoid
(5) output_1 // 输出

2.show_model_graph输出结果



(1)总共包含6层

  分别为input_1、linear_1_1、sigmoid_1_2、linear_2_3、sigmoid_2_4和output_1。

(2)总共6个张量

  指的是input_1(160B)、linear_1_1(192B)、sigmoid_1_2(192B)、linear_2_3(144B)、sigmoid_2_4(144B)和output_1(144B)。共计976B。

(3)input_1 4*2(160B)

  4*2表示input_1的shape,而160B指的是该张量在内存中占用空间大小,以字节(B)为单位。知道张量的形状和内存占用情况,对于模型内存管理和优化来说是很有用的信息。其它张量信息如下所示:



(4)共计17参数

  linear_1_1参数信息为42和4,linear_1_1参数信息为14和1,共计17参数,内存占用548B。

三.遇到的问题

1.需要安装和设置graphviz

subprocess.CalledProcessError: Command '[WindowsPath('dot'), '-Kdot', '-Tpdf', '-O', 'graph.gv']' returned non-zero exit status 1.

解决方案是将D:\Program Files\Graphviz\bin添加到系统环境变量PATH中。

2.AlexNet神经网络

因为BP神经网络过于简单,接下来可视化一个稍微复杂点儿的AlexNet神经网络,如下所示:

参考文献:

[1]torchlens_tutorial.ipynb:https://colab.research.google.com/drive/1ORJLGZPifvdsVPFqq1LYT3t5hV560SoW?usp=sharing#scrollTo=W_94PeNdQsUN

[2]Extracting and visualizing hidden activations and computational graphs of PyTorch models with TorchLens:https://www.nature.com/articles/s41598-023-40807-0

[3]torchlens:https://github.com/johnmarktaylor91/torchlens

[4]Torchlens Model Menagerie:https://drive.google.com/drive/folders/1BsM6WPf3eB79-CRNgZejMxjg38rN6VCb

[5]使用TorchLens可视化一个简单的神经网络:github.com/ai408/nlp-engineering/tree/main/20230917_NLP工程化公众号文章/使用torchlens可视化一个简单的神经网络

使用TorchLens可视化一个简单的神经网络的更多相关文章

  1. tensorflow笔记(二)之构造一个简单的神经网络

    tensorflow笔记(二)之构造一个简单的神经网络 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7425200.html ...

  2. Python实现一个简单三层神经网络的搭建并测试

    python实现一个简单三层神经网络的搭建(有代码) 废话不多说了,直接步入正题,一个完整的神经网络一般由三层构成:输入层,隐藏层(可以有多层)和输出层.本文所构建的神经网络隐藏层只有一层.一个神经网 ...

  3. python日记:用pytorch搭建一个简单的神经网络

    最近在学习pytorch框架,给大家分享一个最最最最基本的用pytorch搭建神经网络并且训练的方法.本人是第一次写这种分享文章,希望对初学pytorch的朋友有所帮助! 一.任务 首先说下我们要搭建 ...

  4. pytorch定义一个简单的神经网络

    刚学习pytorch,简单记录一下 """ test Funcition """ import torch from torch.autog ...

  5. 使用RStudio学习一个简单神经网络

    数据准备 1.收集数据 UC Irvine Machine Learning Repository-Concrete Compressive Strength Data Set 把下载到的Concre ...

  6. 从程序员的角度设计一个Java的神经网络

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 来自维基百科: 人工神经网络(ANN)或连接系统是受生物神经网络启发构成生物大脑的计算系统.这样的系统通过考虑例子来学习(逐步提高性能)来完成任 ...

  7. pytorch学习笔记(8)--搭建简单的神经网络以及Sequential的使用

    1.神经网络图 输入图像是3通道的32×32的,先后经过卷积层(5×5的卷积核).最大池化层(2×2的池化核).卷积层(5×5的卷积核).最大池化层(2×2的池化核).卷积层(5×5的卷积核).最大池 ...

  8. C++从零实现简单深度神经网络(基于OpenCV)

    代码地址如下:http://www.demodashi.com/demo/11138.html 一.准备工作 需要准备什么环境 需要安装有Visual Studio并且配置了OpenCV.能够使用Op ...

  9. 使用Python来编写一个简单的感知机

    来表示.第二个元素是表示期望输出的值. 这个数组定义例如以下: training_data = [  (array([0,0,1]), 0),  (array([0,1,1]), 1),  (arra ...

  10. tensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试

    刚开始学习tf时,我们从简单的地方开始.卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始. 神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输 ...

随机推荐

  1. HyperPlatform

    之前也写过一个vt的框架,但是比较简单,写的比较乱迁移什么的比较麻烦,于是阅读下HyperPlatform的源码学习下. 本文只对主体框架分析. vt的流程大概如下 1:检测是否支持VT. 2:vmx ...

  2. 【python基础】复杂数据类型-列表类型(列表切片)

    1.列表切片 前面学习的是如何处理列表的所有数据元素.python还可以处理列表的部分元素,python称之为切片. 1.1创建切片 创建切片,可指定要使用的第一个数据元素的索引和最后一个数据元素的索 ...

  3. 如何制作 Storybook Day 网页上的 3D 效果?

    Storybook 刚刚达到了一个重要的里程牌:7.0 版本!为了庆祝,该团队举办了他们的第一次用户大会 - Storybook Day.为了更特别,在活动页面中添加了一个视觉上令人惊叹的 3D 插图 ...

  4. 拥抱jsx,开启vue3用法的另一种选择🔥🔥

    背景 公司高级表单组件ProForm高阶组件都建立在jsx的运用配置上,项目在实践落地过程中积累了丰富的经验,也充分感受到了jsx语法的灵活便捷和可维护性强大,享受到了用其开发的乐趣,独乐乐不如众乐乐 ...

  5. IM1281B电能计量模块_C语言例程

    一.前言 毕设采用了艾锐达公司的IM1281B电量计能模块,找了一圈没发现具体的51单片机的例程,现在写个能使用的C语言例程,方便以后的开发者们. 二.事前准备 引脚定义: 引脚 功能说明 V+ 供电 ...

  6. 喜报 | ShowMeBug获国家高新技术企业认证!

    近日,深圳至简天成科技有限公司(以下简称至简天成)顺利通过国家高新技术企业认证! 国家高新技术企业是由国务院主导.科技部牵头的国家级荣誉资质,是我国科技类企业中的"国"字号招牌,完 ...

  7. C++面试八股文:std::deque用过吗?

    某日二师兄参加XXX科技公司的C++工程师开发岗位第26面: 面试官:deque用过吗? 二师兄:说实话,很少用,基本没用过. 面试官:为什么? 二师兄:因为使用它的场景很少,大部分需要性能.且需要自 ...

  8. 把langchain跑起来的3个方法

    使用LangChain开发LLM应用时,需要机器进行GLM部署,好多同学第一步就被劝退了,那么如何绕过这个步骤先学习LLM模型的应用,对Langchain进行快速上手?本片讲解3个把LangChain ...

  9. zip文件结构

    转starshine博客 一个zip文件由三个部分组成:压缩源文件数据区.压缩源文件目录区.压缩源文件目录结束标志 压缩源文件数据区: 50 4B 03 04:这是头文件标记(0x04034b50) ...

  10. Sa-Token 多账号认证:同时为系统的 Admin 账号和 User 账号提供鉴权操作

    Sa-Token 是一个轻量级 java 权限认证框架,主要解决登录认证.权限认证.单点登录.OAuth2.微服务网关鉴权 等一系列权限相关问题. Gitee 开源地址:https://gitee.c ...