AI

GCN 基础:GraphConv、GATConv、SAGEConv 的实现(PyG+DGL)

简介 GraphConv、GATConv 和 SAGEConv 是三种常用的图卷积层,功能都是类似的,用来学习图结构数据中的节点表示,以便于后续的图分析任务,比如说节点分类、图分类或链接预测等等。 三者的核心区别在于怎么聚合邻接节点的信息:GraphConv 采用平均池化,GATConv 通过注意力机制赋予不同邻居不同的重要性,而 SAGEConv 则提供了多种聚合函数选择。这些差异影响了导致训练出来的模型有不同的表现。 使用示例 在用法上都是类似的。一般来说使用 GATConv 我们会比较关注注意力头数,使用 SAGEConv 我们会比较关注聚合方式。 1import dgl.nn as dglnn 2 3# GraphConv 4conv1 = dglnn.GraphConv(in_feats, out_feats) 5x = conv1(g, x) 6 7# GATConv 8conv2 = dglnn.GATConv(in_feats, out_feats, num_heads) 9x = conv2(g, x) 10 11# SAGEConv 12conv3 = dglnn.SAGEConv(in_feats, out_feats, 'mean') 13x = conv3(g, x) 消息传递范式 设 $x_v\in\mathbb{R}^{d_1}$ 为节点 $v$ 的特征, $w_{e}\in\mathbb{R}^{d_2}$ 为边 $(u, v)$ 的特征。消息传递范式定义了如下的节点和边的计算过程: $$ \text{边计算: } m_{e}^{(t+1)} = \phi \left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \right) , (u, v, e) \in \mathcal{E} $$ $$ \text{节点计算: } x_v^{(t+1)} = \psi \left(x_v^{(t)}, ho\left(\left\lbrace m_{e}^{(t+1)} : (u, v, e) \in \mathcal{E} \right\rbrace \right)\right) $$ 其中, $\phi$ 是一个消息函数, 它根据边的特征和相邻节点的特征生成消息; $\psi$ 是一个更新函数, 它根据节点的当前特征和来自邻居的消息来更新节点的特征, 其中 $ho$ 是一个聚合函数。 Read more...
1 of 1