GCN代码分析

 

1 代码结构

  1. .
  2. ├── data // 图数据
  3. ├── inits // 初始化的一些公用函数
  4. ├── layers // GCN层的定义
  5. ├── metrics // 评测指标的计算
  6. ├── models // 模型结构定义
  7. ├── train // 训练
  8. └── utils // 工具函数的定义

utils.py

def parse_index_file(filename) # 处理index文件并返回index矩阵

def sample_mask(idx, l) #创建 mask 并返回mask矩阵

def load_data(dataset_str) # 读取数据

  • 从gcn/data文件夹下读取数据,文件包括有:

  • ind.dataset_str.x => 训练实例的特征向量,如scipy.sparse.csr.csr_matrix类的实例

  • ind.dataset_str.tx => 测试实例的特征向量,如scipy.sparse.csr.csr_matrix类的实例

  • ind.dataset_str.allx => 有标签的+无无标签训练实例的特征向量,是ind.dataset_str.x的超集

  • ind.dataset_str.y => 训练实例的标签,独热编码,numpy.ndarray类的实例

  • ind.dataset_str.ty => 测试实例的标签,独热编码,numpy.ndarray类的实例

  • ind.dataset_str.ally => 有标签的+无无标签训练实例的标签,独热编码,numpy.ndarray类的实例

  • ind.dataset_str.graph => 图数据,collections.defaultdict类的实例,格式为 {index:[index_of_neighbor_nodes]}

  • ind.dataset_str.test.index => 测试实例的id

​ 上述文件必须都用python的pickle模块存储

  • 返回: adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask

def sparse_to_tuple(sparse_mx) # 将矩阵转换成tuple格式并返回

def preprocess_features(features) # 处理特征:将特征进行归一化并返回tuple (coords, values, shape)

def normalize_adj(adj) # 图归一化并返回

def preprocess_adj(adj) # 处理得到GCN中的归一化矩阵并返回

def construct_feed_dict(features, support, labels, labels_mask, placeholders) # 构建输入字典并返回

def chebyshev_polynomials(adj, k) # 切比雪夫多项式近似:计算K阶的切比雪夫近似矩阵

  1. def chebyshev_polynomials(adj, k):
  2. """Calculate Chebyshev polynomials up to order k. Return a list of sparse matrices (tuple representation)."""
  3. print("Calculating Chebyshev polynomials up to order {}...".format(k))
  4. adj_normalized = normalize_adj(adj) # D^{-1/2}AD^{1/2}
  5. laplacian = sp.eye(adj.shape[0]) - adj_normalized # L = I_N - D^{-1/2}AD^{1/2}
  6. largest_eigval, _ = eigsh(laplacian, 1, which='LM') # \lambda_{max}
  7. scaled_laplacian = (2. / largest_eigval[0]) * laplacian - sp.eye(adj.shape[0]) # 2/\lambda_{max}L-I_N
  8. # 将切比雪夫多项式的 T_0(x) = 1和 T_1(x) = x 项加入到t_k中
  9. t_k = list()
  10. t_k.append(sp.eye(adj.shape[0]))
  11. t_k.append(scaled_laplacian)
  12. # 依据公式 T_n(x) = 2xT_n(x) - T_{n-1}(x) 构造递归程序,计算T_2 -> T_k项目
  13. def chebyshev_recurrence(t_k_minus_one, t_k_minus_two, scaled_lap):
  14. s_lap = sp.csr_matrix(scaled_lap, copy=True)
  15. return 2 * s_lap.dot(t_k_minus_one) - t_k_minus_two
  16. for i in range(2, k+1):
  17. t_k.append(chebyshev_recurrence(t_k[-1], t_k[-2], scaled_laplacian))
  18. return sparse_to_tuple(t_k)

layers.py

  • 各层定义的方式与keras类似

  • 定义基类 Layer

    属性:name (String) => 定义了变量范围;logging (Boolean) => 打开或关闭TensorFlow直方图日志记录

    方法:__init__()(初始化),_call()(定义计算),__call__()(调用_call()函数),_log_vars()

  • 定义Dense Layer类,继承自Layer类

  • 定义GraphConvolution类,继承自Layer类。重点来看一下这个类的实现。

  1. class GraphConvolution(Layer):
  2. """Graph convolution layer."""
  3. def __init__(self, input_dim, output_dim, placeholders, dropout=0.,
  4. sparse_inputs=False, act=tf.nn.relu, bias=False,
  5. featureless=False, **kwargs):
  6. super(GraphConvolution, self).__init__(**kwargs)
  7. if dropout:
  8. self.dropout = placeholders['dropout']
  9. else:
  10. self.dropout = 0.
  11. self.act = act
  12. self.support = placeholders['support']
  13. self.sparse_inputs = sparse_inputs
  14. self.featureless = featureless
  15. self.bias = bias
  16. # helper variable for sparse dropout
  17. self.num_features_nonzero = placeholders['num_features_nonzero']
  18. # 下面是定义变量,主要是通过调用utils.py中的glorot函数实现
  19. with tf.variable_scope(self.name + '_vars'):
  20. for i in range(len(self.support)):
  21. self.vars['weights_' + str(i)] = glorot([input_dim, output_dim],
  22. name='weights_' + str(i))
  23. if self.bias:
  24. self.vars['bias'] = zeros([output_dim], name='bias')
  25. if self.logging:
  26. self._log_vars()
  27. def _call(self, inputs):
  28. x = inputs
  29. # dropout 设置dropout
  30. if self.sparse_inputs:
  31. x = sparse_dropout(x, 1-self.dropout, self.num_features_nonzero)
  32. else:
  33. x = tf.nn.dropout(x, 1-self.dropout)
  34. # convolve 卷积的实现。主要是根据论文中公式Z = \tilde{D}^{-1/2}\tilde{A}^{-1/2}X\theta实现
  35. supports = list()
  36. for i in range(len(self.support)):
  37. if not self.featureless:
  38. pre_sup = dot(x, self.vars['weights_' + str(i)],
  39. sparse=self.sparse_inputs)
  40. else:
  41. pre_sup = self.vars['weights_' + str(i)]
  42. support = dot(self.support[i], pre_sup, sparse=True)
  43. supports.append(support)
  44. output = tf.add_n(supports)
  45. # bias
  46. if self.bias:
  47. output += self.vars['bias']
  48. return self.act(output)

model.py

定义了一个model基类,以及两个继承自model类的MLP、GCN类。重点来看看GCN类的定义

  1. class GCN(Model):
  2. def __init__(self, placeholders, input_dim, **kwargs):
  3. super(GCN, self).__init__(**kwargs)
  4. self.inputs = placeholders['features']
  5. self.input_dim = input_dim
  6. # self.input_dim = self.inputs.get_shape().as_list()[1] # To be supported in future Tensorflow versions
  7. self.output_dim = placeholders['labels'].get_shape().as_list()[1]
  8. self.placeholders = placeholders
  9. self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
  10. self.build()
  11. # 损失计算
  12. def _loss(self):
  13. # Weight decay loss # 正则化项
  14. for var in self.layers[0].vars.values():
  15. self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
  16. # Cross entropy error # 交叉熵损失函数
  17. self.loss += masked_softmax_cross_entropy(self.outputs, self.placeholders['labels'],
  18. self.placeholders['labels_mask'])
  19. # 计算模型准确度
  20. def _accuracy(self):
  21. self.accuracy = masked_accuracy(self.outputs, self.placeholders['labels'],
  22. self.placeholders['labels_mask'])
  23. # 构建模型:两层GCN
  24. def _build(self):
  25. self.layers.append(GraphConvolution(input_dim=self.input_dim,
  26. output_dim=FLAGS.hidden1,
  27. placeholders=self.placeholders,
  28. act=tf.nn.relu,
  29. dropout=True,
  30. sparse_inputs=True,
  31. logging=self.logging))
  32. self.layers.append(GraphConvolution(input_dim=FLAGS.hidden1,
  33. output_dim=self.output_dim,
  34. placeholders=self.placeholders,
  35. act=lambda x: x,
  36. dropout=True,
  37. logging=self.logging))
  38. # 模型预测
  39. def predict(self):
  40. return tf.nn.softmax(self.outputs)

2 实践

更新中...

GCN代码分析 2019.03.12 22:34:54字数 560阅读 5714 本文主要对GCN源码进行分析。的更多相关文章

  1. 自定义View系列教程02--onMeasure源码详尽分析

    深入探讨Android异步精髓Handler 站在源码的肩膀上全解Scroller工作机制 Android多分辨率适配框架(1)- 核心基础 Android多分辨率适配框架(2)- 原理剖析 Andr ...

  2. MongoDB源码分析——mongod程序源码入口分析

    Edit 说明:第一次写笔记,之前都是看别人写的,觉得很简单,开始写了之后才发现真的很难,不知道该怎么分析,这篇文章也参考了很多前辈对MongoDB源码的分析,也有一些自己的理解,后续将会继续分析其他 ...

  3. FFmpeg的HEVC解码器源码简单分析:解析器(Parser)部分

    ===================================================== HEVC源码分析文章列表: [解码 -libavcodec HEVC 解码器] FFmpeg ...

  4. JUC同步器框架AbstractQueuedSynchronizer源码图文分析

    JUC同步器框架AbstractQueuedSynchronizer源码图文分析 前提 Doug Lea大神在编写JUC(java.util.concurrent)包的时候引入了java.util.c ...

  5. Proxy Server源码及分析(TCP Proxy源码 Socket实现端口映射)

    版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/u014530704/article/de ...

  6. 自定义View系列教程03--onLayout源码详尽分析

    深入探讨Android异步精髓Handler 站在源码的肩膀上全解Scroller工作机制 Android多分辨率适配框架(1)- 核心基础 Android多分辨率适配框架(2)- 原理剖析 Andr ...

  7. Spring Ioc源码分析系列--Ioc源码入口分析

    Spring Ioc源码分析系列--Ioc源码入口分析 本系列文章代码基于Spring Framework 5.2.x 前言 上一篇文章Spring Ioc源码分析系列--Ioc的基础知识准备介绍了I ...

  8. 【lwip】09-IPv4协议&超全源码实现分析

    目录 前言 9.1 IP协议简述 9.2 IP地址分类 9.2.1 私有地址 9.2.2 受限广播地址 9.2.3 直接广播地址 9.2.4 多播地址 9.2.5 环回地址 9.2.6 本地链路地址 ...

  9. MapReduce的ReduceTask任务的运行源码级分析

    MapReduce的MapTask任务的运行源码级分析 这篇文章好不容易恢复了...谢天谢地...这篇文章讲了MapTask的执行流程.咱们这一节讲解ReduceTask的执行流程.ReduceTas ...

随机推荐

  1. Python中二维数组的创建

    习惯了java的Matrix = [][]不知道python怎么创二维数组. 先看 python中的二维数组操作 对最后提出的二维数组创建方式存在疑问 Matrix = [([0] * 3) for ...

  2. Qt 字符串QString arg()用法总结

    1.QString::arg()//用字符串变量参数依次替代字符串中最小数值 QString i = "iTest";           // current file's nu ...

  3. Java 比较两个字符串的相似度算法(Levenshtein Distance)

    转载自: https://blog.csdn.net/JavaReact/article/details/82144732 算法简介: Levenshtein Distance,又称编辑距离,指的是两 ...

  4. react 闲谈 之 JSX

    jsx元素-> React.createElement -> 虚拟dom对象 -> render方法 1.在react中想将js当作变了引入到jsx中需要使用{} 2.在jsx中,相 ...

  5. 前端知识点回顾之重点篇——CORS

    CORS(cross origin resource sharing)跨域资源共享 来源:http://www.ruanyifeng.com/blog/2016/04/cors.html 它允许浏览器 ...

  6. Android:JNA实践(附Demo)

    一.JNA和JNI的对比   1.JNI的调用流程 Android应用开发中要实现Java和C,C++层交互时,想必首先想到的是JNI,但是JNI的使用过程十分繁琐,需要自己再封装一层JNI接口进行转 ...

  7. 用第三方工具类,将JavaBean、List、Map<String,Object>转成JSON文本

    导入第三方jar包: >commons-beanutils-1.7.0.jar >commons-collections-3.1.jar >commons-lang-2.5.jar ...

  8. Java端使用Batik将SVG转为PNG

    在上篇中,我们需要将Highcharts生成的图通过后台保存到pdf文件中,就需要对SVG进行转换. 这里就介绍一下使用Batik处理SVG代码的方法. 首先是jar包的获取地址,https://xm ...

  9. "挡位"还是"档位",究竟谁错了

    http://baijiahao.baidu.com/s?id=1581395663965196858&wfr=spider&for=pc 对于“挡”与“档”两个字,我一直并没有给以太 ...

  10. kettle转换和任务的基本使用

    0 创建转换 并保存0816_em.ktr 1 主对象树中选择DB连接,创建2个DB连接 2 创建表输入 核心对象树中选择输入>表输入 3 核心对象树中选择输出>插入/更新表 并连线 4 ...