[论文翻译]基于端到端注意力机制的图学习方法


原文地址:https://arxiv.org/pdf/2402.10793v2


An end-to-end attention-based approach for learning on graphs

基于端到端注意力机制的图学习方法

Abstract

摘要

There has been a recent surge in transformer-based architectures for learning on graphs, mainly motivated by attention as an effective learning mechanism and the desire to supersede handcrafted operators characteristic of message passing schemes. However, concerns over their empirical effectiveness, scalability, and complexity of the pre-processing steps have been raised, especially in relation to much simpler graph neural networks that typically perform on par with them across a wide range of benchmarks. To tackle these shortcomings, we consider graphs as sets of edges and propose a purely attention-based approach consisting of an encoder and an attention pooling mechanism. The encoder vertically interleaves masked and vanilla self-attention modules to learn effective representations of edges, while allowing for tackling possible misspecifications in input graphs. Despite its simplicity, the approach outperforms fine-tuned message passing baselines and recently proposed transformer-based methods on more than 70 node and graph-level tasks, including challenging long-range benchmarks. Moreover, we demonstrate state-of-the-art performance across different tasks, ranging from molecular to vision graphs, and heterophilous node classification. The approach also outperforms graph neural networks and transformers in transfer learning settings, and scales much better than alternatives with a similar performance level or expressive power.

近年来,基于Transformer的图学习架构激增,主要驱动力在于注意力机制作为一种高效学习方式,以及取代手工设计消息传递算子的需求。然而,这些架构在实证效果、可扩展性和预处理复杂度等方面存在争议,尤其是在性能与更简单的图神经网络(GNN)相当的广泛基准测试中。为解决这些缺陷,我们将图视为边集合,提出了一种纯注意力架构,包含编码器和注意力池化机制。该编码器通过垂直堆叠掩码自注意力与常规自注意力模块,既能学习有效的边表征,又能处理输入图可能存在的错误定义。尽管结构简单,该方法在70多项节点级和图级任务(包括具有挑战性的长程基准测试)中超越了精细调优的消息传递基线和近期提出的Transformer方法。此外,我们在从分子图到视觉图、异配节点分类等不同任务中实现了最先进性能。在迁移学习场景下,该方法同样优于图神经网络和Transformer模型,且在同等性能或表达能力下展现出更优的扩展性。

1 Introduction

1 引言

We investigate empirically the potential of a purely attention-based approach for learning effective representations of graph structured data. Typically, learning on graphs is modelled as message passing – an iterative process that relies on a message function to aggregate information from a given node’s neighbourhood and an update function to incorporate the encoded message into the output representation of the node. The resulting graph neural networks (GNNs) typically stack multiple such layers to learn node representations based on vertex rooted sub-trees, essentially mimicking the one-dimensional Weisfeiler–Lehman (1-WL) graph isomorphism test [1, 2]. Variations of message passing have been applied effectively in different fields such as life sciences [3–8], electrical engineering [9], and weather prediction [10].

我们通过实证研究探讨了纯基于注意力(attention)的方法在学习图结构数据有效表征方面的潜力。通常,图学习被建模为消息传递(message passing)——这是一个依赖消息函数聚合给定节点邻域信息,并利用更新函数将编码信息整合到节点输出表征中的迭代过程。由此产生的图神经网络(GNNs)通常会堆叠多个此类层,基于顶点根子树学习节点表征,本质上是在模拟一维Weisfeiler-Lehman(1-WL)图同构测试[1,2]。消息传递的变体已成功应用于生命科学[3-8]、电气工程[9]和天气预报[10]等不同领域。

Despite the overall success and wide adoption of graph neural networks, several practical challenges have been identified over time. While the message passing framework is highly flexible, the design of new layers is a challenging research problem where improvements take years to achieve and often rely on hand-crafted operators. This is particularly the case for general purpose graph neural networks that do not exploit additional input modalities such as atomic coordinates. For instance, principal neighbourhood aggregation (PNA) is regarded as one of the most powerful message passing layers [11], but it is built using a collection of manually selected neighbourhood aggregation functions, requires a dataset degree histogram which must be pre-computed prior to learning, and further uses manually selected degree scaling. The nature of message passing also imposes certain limitations which have shaped the majority of the literature. One of the most prominent examples is the readout function used to combine node-level features into a single graph-level representation, and which is required to be permutation invariant with respect to the node order. Thus, the default choice for graph neural networks and even graph transformers remains a simple, non-learnable function such as sum, mean, or max [12–14]. Limitations of this approach have been identified by Wagstaff et al. [15], who have shown that simple readout functions might require complex item embedding functions that are difficult to learn using standard neural networks. Additionally, graph neural networks have shown limitations in terms of over-smoothing [16], linked to node representations becoming similar with increased depth, and over-squashing [17] due to information compression through bottleneck edges. The proposed solutions typically take the form of message regularisation schemes [18–20]. However, there is generally no consensus on the right architectural choices for building effective deep message passing neural networks. Transfer learning and strategies like pre-training and fine-tuning are also less ubiquitous in graph neural networks because of modest or ambiguous benefits, as opposed to large language models [21].

尽管图神经网络(Graph Neural Networks)取得了整体成功并得到广泛应用,但随着时间的推移也暴露出若干实际挑战。虽然消息传递框架具有高度灵活性,但新层的设计仍是艰巨的研究难题,其改进往往需要数年时间且通常依赖手工设计的算子。这种情况在不利用原子坐标等额外输入模态的通用图神经网络中尤为明显。例如,主邻域聚合(PNA)被视为最强大的消息传递层之一[11],但它采用了一组人工选择的邻域聚合函数,需要预先计算数据集度直方图,并进一步使用人工选择的度缩放参数。消息传递的本质特性也带来了一些限制,这些限制塑造了该领域的大部分研究。最典型的例子是将节点级特征合并为图级表征的读出函数(readout function),该函数必须对节点顺序保持排列不变性。因此,图神经网络甚至图Transformer的默认选择仍是简单的不可学习函数,如求和、均值或最大值[12-14]。Wagstaff等人[15]指出这种方法的局限性,证明简单读出函数可能需要复杂的项嵌入函数,而标准神经网络难以学习这种函数。此外,图神经网络还存在过平滑(over-smoothing)[16]和过挤压(over-squashing)[17]的局限:前者随深度增加导致节点表征趋同,后者源于瓶颈边的信息压缩。现有解决方案通常采用消息正则化方案[18-20],但对于构建高效深度消息传递神经网络的架构选择仍缺乏共识。与大语言模型[21]不同,由于收益有限或效果不明确,迁移学习及预训练-微调等策略在图神经网络中的应用也相对较少。

The attention mechanism [22] is one of the main sources of innovation within graph learning, either by directly incorporating attention within message passing [23, 24], by formulating graph learning as a language processing task [25, 26], or by combining vanilla GNN layers with attention layers [12, 13, 27, 28]. However, several concerns have been raised regarding the performance, scalability, and complexity of such methods.

注意力机制 [22] 是图学习创新的主要来源之一,其实现方式包括:直接在消息传递中引入注意力 [23, 24]、将图学习建模为语言处理任务 [25, 26],或将传统 GNN 层与注意力层结合 [12, 13, 27, 28]。然而,这类方法在性能、可扩展性和复杂度方面仍存在争议。

Performance-wise, recent reports indicate that sophisticated graph transformers underperform compared to simple but tuned GNNs [29, 30]. This line of work highlights the importance of empirically evaluating new methods relative to strong baselines. Separately, recent graph transformers have focused on increasingly more complex helper mechanisms, such as computationally expensive pre-processing and learning steps [25], various different encodings (e.g., positional, structural, and relational) [25–27, 31], inclusion of virtual nodes and edges [25, 28], conversion of the problem to natural language processing [25, 26], and other non-trivial graph transformations [28, 31]. These complications can significantly increase the computational requirements, reducing the chance of being widely adopted and replacing GNNs.

性能方面,近期报告表明,复杂的图Transformer模型表现不如经过调优的简单图神经网络(GNN) [29, 30]。这项工作强调了新方法相对于强基线进行实证评估的重要性。另一方面,近期的图Transformer研究聚焦于日益复杂的辅助机制,例如计算成本高昂的预处理和学习步骤 [25]、多种编码方式(如位置编码、结构编码和关系编码)[25–27, 31]、引入虚拟节点和边 [25, 28]、将问题转化为自然语言处理任务 [25, 26] 以及其他非平凡的图变换操作 [28, 31]。这些复杂性会显著增加计算需求,降低其被广泛采用并取代GNN的可能性。

Motivated by the effectiveness of attention as a learning mechanism and recent advances in efficient and exact attention, we introduce an end-to-end attention-based architecture for learning on graphs that is simple to implement, scalable, and achieves state-of-the-art results. The proposed architecture considers graphs as sets of edges, leveraging an encoder that interleaves masked and self-attention mechanisms to learn effective representations. The attention-based pooling component mimics the functionality of a readout function and it is responsible for aggregating the edge-level features into a permutation invariant graph-level representation. The masked attention mechanism allows for learning effective edge representations originating from the graph connectivity and the vertical combination with self-attention layers allows for expanding on this information while having a strong prior. Masking can, thus, be seen as leveraging specified relational information and its vertical combination with self-attention as a means to overcome possible misspecification of the input graph. The masking operator is injected into the pairwise attention weight matrix and allows only for attention between linked primitives. For a pair of edges, the connectivity translates to having a shared node between them. We focus primarily on learning through edge sets due to empirically high performance, and refer to our architecture as edge-set attention (ESA). We also demonstrate that the overall architecture is effective for propagating information across nodes through dedicated node-level benchmarks. The ESA architecture is general purpose, in the sense that it only relies on the graph structure and possibly node and edge features, and it is not restricted to any particular domain. Furthermore, ESA does not use positional, structural, relative, or similar encodings, it does not encode graph structures as tokens or other language (sequence) specific concepts, and it does not require any pre-computations.

受到注意力机制作为学习机制的有效性以及高效精确注意力最新进展的启发,我们提出了一种基于注意力的端到端图学习架构。该架构实现简单、可扩展性强,并能取得最先进的性能表现。我们提出的方法将图视为边集合,通过交替使用掩码注意力和自注意力机制的编码器来学习有效表征。基于注意力的池化组件模拟了读出函数的功能,负责将边级别特征聚合为排列不变的图级别表征。

掩码注意力机制能够基于图连通性学习有效的边表征,与自注意力层的纵向组合则能在保持强先验的前提下扩展信息。因此,掩码操作可视为利用特定关系信息的手段,其与自注意力的纵向结合能够克服输入图可能存在的错误设定。掩码算子被注入成对注意力权重矩阵,仅允许相连图元之间建立注意力连接。对于边对而言,连通性表现为两者共享一个节点。由于实证研究显示边集合学习具有卓越性能,我们主要聚焦于此方法,并将该架构命名为边集注意力(ESA)。通过专门的节点级基准测试,我们还证明了该架构在跨节点信息传播方面的有效性。

ESA架构具有通用性,仅依赖于图结构及可能的节点/边特征,不受特定领域限制。此外,ESA不使用位置编码、结构编码、相对编码等附加信息,不将图结构编码为token或其他语言(序列)特定概念,也无需任何预计算步骤。

Despite its apparent simplicity, ESA-based learning overwhelmingly outperforms strong and tuned GNN baselines and much more involved transformer-based models. Our evaluation is extensive, totalling 70 datasets and benchmarks from different domains such as quantum mechanics, molecular docking, physical chemistry, biophysics, bioinformatics, computer vision, social networks, functional call graphs, and synthetic graphs. At the node level, we include both homophilous and heterophilous graph tasks, as well as shortest path problems, as these require modelling long-range interactions. Beyond supervised learning tasks, we explore the potential for transfer learning in the context of drug discovery and quantum mechanics [32] and show that ESA is a viable transfer learning strategy compared to vanilla GNNs and graph transformers.

尽管看似简单,基于ESA(边集注意力,Edge-Set Attention)的学习方法显著优于经过调优的GNN基线模型和更复杂的基于Transformer的模型。我们的评估范围广泛,涵盖来自量子力学、分子对接、物理化学、生物物理学、生物信息学、计算机视觉、社交网络、函数调用图和合成图等不同领域的70个数据集和基准测试。在节点级别,我们同时包含同配性(homophilous)和异配性(heterophilous)图任务,以及最短路径问题,因为这些任务需要对长程交互进行建模。除了监督学习任务外,我们还探索了在药物发现和量子力学[32]背景下迁移学习的潜力,并表明与普通GNN和图Transformer相比,ESA是一种可行的迁移学习策略。

2 Related Work

2 相关工作

An attention mechanism that mimics message passing and limits the attention computations to neighbouring nodes has been first proposed in GAT [23]. We consider masking as an abstraction of the GAT attention operator that allows for building general purpose relational structures between items in a set (e.g., $k$ -hop neighbourhoods or conformational masks extracted from 3D molecular structures). The attention mechanism in GAT is implemented as a single linear projection matrix that does not explicitly distinguish between keys, queries, and values as in standard dot product attention and an additional linear layer after concatenating the representations of connected nodes, along with a non-linearity. This type of simplified attention has been labelled by subsequent work as static and was shown to have limited expressive power [24]. Brody et al. instead proposed dynamic attention, a simple reordering of operations in GAT, resulting in a more expressive GATv2 model [24] – however, at the price of doubling the parameter count and the corresponding memory consumption. A high-level overview of GAT in the context of masked attention is provided in Figure 1, along with the main differences to the proposed architecture. Similar to GAT, several adaptations of the original scaled dot product attention have been proposed for graphs [33, 34], where the focus was on defining an attention mechanism constrained by the node connectivity and replacing the positional encodings of the original transformer model with more appropriate graph alternatives, such as Laplacian eigenvectors. These approaches, while interesting and forward-looking, did not convincingly outperform simple GNNs. Building on this line of work, an architecture that can be seen as an instance of masked transformers has been proposed in SAN [35], illustrated in Figure 2. Attention coefficients there are defined as a convex combination of the scores (controlled by hyperparameter $\gamma$ ) associated with the original graph and its complement. Min et al. [36] have also considered a masking mechanism for standard transformers. However, the graph structure itself is not used directly in the masking process as the devised masks correspond to four types of prior interaction graphs (induced, similarity, cross neighbourhood, and complete sub-graphs), acting as an inductive bias. Furthermore, the method was not designed for general purpose graph learning as it relies on helper mechanisms such as neighbour sampling and a heterogeneous information network.

一种模仿消息传递并将注意力计算限制在相邻节点上的注意力机制最初在GAT [23]中被提出。我们将掩码视为GAT注意力算子的一种抽象,它允许在集合中的项目之间构建通用的关系结构(例如,$k$-跳邻域或从3D分子结构中提取的构象掩码)。GAT中的注意力机制被实现为单一的线性投影矩阵,不像标准点积注意力那样明确区分键、查询和值,并且在连接相连节点的表示后还有一个额外的线性层和非线性激活。这种简化的注意力被后续工作标记为静态的,并被证明表达能力有限[24]。Brody等人提出了动态注意力,这是对GAT中操作顺序的简单重新排列,从而产生了更具表达力的GATv2模型[24]——然而,代价是参数数量和相应的内存消耗翻倍。图1提供了在掩码注意力背景下GAT的高层次概述,以及与所提出架构的主要区别。与GAT类似,针对图结构提出了几种原始缩放点积注意力的改编版本[33, 34],重点是定义受节点连接性约束的注意力机制,并用更合适的图替代方案(如拉普拉斯特征向量)替换原始Transformer模型的位置编码。这些方法虽然有趣且具有前瞻性,但并未明显超越简单的GNN。基于这一系列工作,SAN [35]提出了一种可被视为掩码Transformer实例的架构,如图2所示。其中的注意力系数被定义为与原始图及其补图相关的分数(由超参数$\gamma$控制)的凸组合。Min等人[36]也考虑了标准Transformer的掩码机制。然而,图结构本身并未直接用于掩码过程,因为设计的掩码对应于四种先验交互图(诱导图、相似图、交叉邻域图和完全子图),作为归纳偏置。此外,该方法并非为通用图学习而设计,因为它依赖于邻居采样和异构信息网络等辅助机制。


Figure 1: A high level overview of the GAT message passing algorithm [23]. Panel A: A GAT layer receives node representations and the adjacency matrix as inputs. First a projection matrix is applied to all the nodes, which are then concatenated pairwise and passed through another linear layer, followed by a non-linearity. The final attention score is computed by the softmax function. Panel B: A GAT model stacks multiple GAT layers and uses a readout function over nodes to generate a graph-level representation. Residual connections are also illustrated as a modern enhancement of GNNs (not included in the original approach). GAT vs ESA: Our masked self-attention module relies on masking within classical scaled dot product attention which comes with different projection matrices for keys, queries, and values. Additionally, we wrap the masked scaled dot product attention with layer normalization and skip connections, the output of which is passed through an MLP. The latter is not done classically in GAT as part of its attention modules nor in the overall architecture. In contrast to GAT that operates over nodes, our masking operator is defined over sets of edges that are connected if they share a node in common. When it comes to the readout, GAT uses sum, mean or max and aggregates over nodes whereas our architecture relies on pooling by multi-head attention and aggregation over learned representations of seed vectors inherent to that module. In terms of the encoder that is responsible for learning representations of set items, ours interleaves masked and self-attention layers vertically whereas in GATs the stacked layers are entirely based on neighbourhood attention.

图 1: GAT消息传递算法的高层概览 [23]。A部分: GAT层接收节点表示和邻接矩阵作为输入。首先对所有节点应用投影矩阵,然后成对拼接并通过另一个线性层,接着经过非线性变换。最终注意力分数由softmax函数计算得出。B部分: GAT模型堆叠多个GAT层,并使用节点上的读出函数生成图级表示。残差连接作为GNN的现代增强技术也被展示(原始方法中未包含)。GAT与ESA对比: 我们的掩码自注意力模块基于经典缩放点积注意力中的掩码机制,并为键、查询和值使用不同的投影矩阵。此外,我们用层归一化和跳跃连接封装掩码缩放点积注意力,其输出通过MLP传递。这一设计既不是GAT注意力模块的经典做法,也不属于其整体架构。与作用于节点的GAT不同,我们的掩码算子定义在共享同一节点的边集合上。在读出阶段,GAT使用求和、均值或最大值对节点进行聚合,而我们的架构依赖于通过多头注意力池化,并聚合该模块固有种子向量的学习表示。就负责学习集合项表示的编码器而言,我们的方法垂直交错掩码层和自注意力层,而GAT的堆叠层完全基于邻域注意力。

Recent trends in learning on graphs are dominated by architectures based on standard self-attention layers as the only learning mechanism, with a significant amount of effort put into representing the graph structure exclusively through positional and structural encodings. Graphormer [25] is one of the most prominent approaches from this class. It comes with an involved and computation-heavy suite of pre-processing steps, involving a centrality, spatial, and edge encoding. For instance, spatial encodings rely on the Floyd–Warshall algorithm that has cubic time complexity in the number of nodes and quadratic memory complexity. Also, the model employs a virtual mean-readout node that is connected to all other nodes in the graph. While Graphormer has originally been evaluated only on four datasets, the results were promising relative to GNNs and SAN. Another related approach is the Tokenized Graph Transformer (TokenGT) [26], which treats all the nodes and edges as independent tokens. To adapt sequence learning to the graph domain, TokenGT encodes the graph information using node identifiers derived from orthogonal random features or Laplacian eigenvectors, and learnable type identifiers for nodes and edges. TokenGT is provably more expressive than standard GNNs, and can approximate $k$ -WL tests with the appropriate architecture (number of layers, adequate pooling, etc.). However, this theoretical guarantee holds only with the use of positional encodings that typically break the permutation invariance over nodes that is required for consistent predictions over the same graph presented with a different node order. The main strength of TokenGT is its theoretical expressiveness, as it has only been evaluated on a single dataset where it did not outperform Graphormer.

近年来,图学习领域的主流架构主要基于标准自注意力(self-attention)层作为唯一学习机制,并通过大量工作将图结构信息仅通过位置编码和结构编码来表示。Graphormer [25]是该类方法中最具代表性的方案之一,它采用了一套复杂且计算量大的预处理流程,包括中心性编码、空间编码和边编码。例如,其空间编码依赖于Floyd-Warshall算法,该算法具有节点数量的三次方时间复杂度和二次方内存复杂度。此外,该模型还引入了一个与图中所有节点相连的虚拟平均读出节点。虽然Graphormer最初仅在四个数据集上进行评估,但其结果相比图神经网络(GNNs)和SAN展现出优势。另一相关工作是Token化图Transformer(TokenGT) [26],该方法将所有节点和边视为独立token。为适应图域序列学习,TokenGT采用基于正交随机特征或拉普拉斯特征向量的节点标识符,以及可学习的节点/边类型标识符来编码图信息。理论证明TokenGT比标准GNNs更具表达能力,并能通过适当架构(层数、池化方式等)逼近$k$-WL测试。但该理论保证仅在使用了通常会破坏节点排列不变性的位置编码时成立,而这种不变性是对同一图结构在不同节点顺序下保持预测一致性所必需的。TokenGT的主要优势在于理论表达能力,目前仅在单个数据集上评估且未超越Graphormer。

A route involving a hybrid between classical transformers and message passing has been pursued in the GraphGPS framework [37], which combines message passing and transformer layers. As in previous works, GraphGPS puts a large emphasis on different types of encodings, proposing and analysing positional and structural encodings, further divided into local, global, and relative encodings. Exphormer [28] is an evolution of GraphGPS that adds virtual global nodes and sparse attention based on expander graphs. While effective, such frameworks do still rely on message passing and are thus not purely attention-based solutions. Limitations include dependence on approximations (Performer [38] for both, expander graphs for Exphormer), and decreased performance when encodings (GraphGPS) or special nodes (Exphormer) are removed. Notably, Exphormer is the first approach from this class to consider custom attention patterns given by node neighbourhoods.

图GPS框架[37]探索了一条结合经典Transformer与消息传递的混合路径,该框架整合了消息传递层和Transformer层。与先前研究类似,图GPS重点关注多种编码类型,提出并分析了位置编码与结构编码(进一步细分为局部、全局和相对编码)。Exphormer[28]作为图GPS的改进版本,引入了基于扩展图的虚拟全局节点和稀疏注意力机制。尽管效果显著,这类框架仍依赖消息传递机制,并非纯粹的注意力解决方案。其局限性包括对近似方法的依赖(图GPS使用Performer[38],Exphormer采用扩展图),以及当移除编码(图GPS)或特殊节点(Exphormer)时性能下降。值得注意的是,Exphormer是该类方法中首个考虑节点邻域定制注意力模式的方案。

3 Methods

3 方法

In this section, we provide a high-level overview of our architecture and describe the masked attention module that allows for information propagation across edges as primitives. An alternative implementation for the more traditional node-based propagation is also presented. The section starts with a formal definition of the masked self-attention mechanism and then proceeds to describe our end-to-end attention-based approach for learning on graphs, involving the encoder and pooling components (illustrated in Figure 3).

在本节中,我们将从高层次概述架构,并描述以边为基元实现跨边信息传播的掩码注意力模块。同时提供了基于传统节点传播的替代实现方案。本节首先对掩码自注意力机制进行形式化定义,随后阐述我们基于注意力机制的图学习端到端方法,该方法包含编码器和池化组件(如图 3 所示)。


Figure 2: A high level overview of the SAN architecture [35]. Panel A: Two independent attention modules with a tied/shared value projection matrix are used for the input graph and its complement. The outputs of these two modules are convexly combined using a hyperparameter $\gamma$ before being residually added to the input. Panel B: The overall SAN architecture that stacks multiple SAN layers without residual connections, and uses a standard readout function over nodes. SAN vs ESA: What is truly different in relation to this architecture is horizontal versus vertical combination of masked and self-attention. While it is possible to set $\gamma=0$ for some layers and $\gamma=1$ for others, this avenue has never been explored in SAN and it would still result in a slightly different encoder as the query and key projections are different for the input graph and its complement. The empirical analysis in [35] also provides no insights relative to vertical combination of masked and self-attention, nor the impact of attention pooling which is a differentiating factor and an important part of our architecture.

图 2: SAN架构的总体概览 [35]。A部分:输入图及其补图使用两个独立的注意力模块,共享值投影矩阵。这两个模块的输出通过超参数$\gamma$进行凸组合后,以残差方式添加到输入中。B部分:整体SAN架构由多个无残差连接的SAN层堆叠而成,并在节点上使用标准读出函数。SAN与ESA的区别:该架构的真正差异在于掩码注意力和自注意力的横向组合与纵向组合方式。虽然可以为某些层设置$\gamma=0$,其他层设置$\gamma=1$,但SAN中从未探索过这种方案,且由于输入图与其补图的查询和键投影不同,仍会导致编码器存在细微差异。[35]中的实证分析也未涉及掩码注意力与自注意力的纵向组合,或注意力池化的影响(这是我们架构的差异化要素和重要组成部分)。


Figure 3: A high-level overview of Edge Set Attention and its main building blocks. Panel A: An illustration of masked scaled dot product attention with edge inputs. Edge features derived from the edge set are processed by query, key, and value projections as in standard attention. An additive edge mask, derived as in Algorithm 1, is applied prior to the softmax operator. A mask value of 0 leaves the attention score unchanged, while a large negative value completely discards that attention score. Panel B: The Pooling by Multi-Head Attention Block (PMA). A learnable set of seed tensors $\mathbf{S}_{k}$ is randomly initialised and used as the query for a cross attention operation with the inputs. The result is further processed by Self Attention Blocks. Panel C: The Self Attention Block (SAB), illustrated here with a pre-layer-normalization architecture. The input is normalised, processed by self attention, then further normalised and processed by an MLP. The block uses residual connections. Panel D: Illustration of the Masked (Self) Attention Block (MAB). The only difference compared to SAB is the edge mask, which follows the computation outlined in Panel A. Panel E: An instantiation of the ESA model. Edge inputs are processed by various different attention blocks, here in the order MSMSP, corresponding to masked, self, masked, self, and pooling attention layers.

图 3: 边集注意力机制及其核心组件的高层概览。A部分: 带边输入的掩码缩放点积注意力示意图。从边集提取的边特征通过标准注意力机制中的查询、键和值投影进行处理。如算法1所述生成的加性边掩码会在softmax算子前应用。掩码值为0时保持注意力分数不变,而较大负值会完全丢弃该注意力分数。B部分: 多头注意力池化块(PMA)。一组可学习的种子张量$\mathbf{S}_{k}$被随机初始化,作为跨注意力操作的查询输入。结果会通过自注意力块进一步处理。C部分: 自注意力块(SAB),此处展示的是层归一化前置架构。输入经过归一化、自注意力处理,再通过MLP进行二次归一化和处理。该模块采用残差连接。D部分: 掩码(自)注意力块(MAB)示意图。与SAB的唯一区别在于边掩码,其计算流程如A部分所述。E部分: ESA模型的具体实现。边输入会经过多种不同的注意力块处理,此处展示的顺序是MSMSP,对应掩码、自注意力、掩码、自注意力和池化注意力层。

3.1 Masked Attention Modules

3.1 掩码注意力模块

We first introduce the basic notation and then give a formal description of the masked attention mechanism with a focus on the connectivity pattern specific to graphs. Following this, we provide algorithms for an efficient implementation of the masking operators using native tensor operations.

我们首先介绍基本符号,然后重点描述针对图结构特定连接模式的掩码注意力机制 (masked attention mechanism) 形式化定义。随后提供基于原生张量运算高效实现掩码算子的算法。

As in most graph learning settings, a graph is a tuple $\mathcal{G}=(\mathcal{N},\mathcal{E})$ where $\mathcal{N}$ represents the set of nodes (vertices), $\mathcal{E}\subseteq\mathcal{N}\times\mathcal{N}$ is the set of edges, and $N_{n}=|\mathcal{N}|$, $N_{e}=|\mathcal{E}|$. Nodes are associated with feature vectors $\mathbf{n}_{i}$ of dimension $d_{n}$ for all nodes $i\in\mathcal{N}$, and $d_{e}$-dimensional edge features $\mathbf{e}_{ij}$ for all edges $e\in\mathcal{E}$. The node features are collected as rows in a matrix $\mathbf{N}\in\mathbb{R}^{N_{n}\times d_{n}}$, and similarly for edge features into $\mathbf{E}\in\mathbb{R}^{N_{e}\times d_{e}}$. The graph connectivity information can be represented as an adjacency matrix $\mathbf{A}$, where $\mathbf{A}_{ij}=1$ if $(i,j)\in\mathcal{E}$ and $\mathbf{A}_{ij}=0$ otherwise. The edge list (edge index) representation is equivalent but more common in practice.

与大多数图学习设定相同,图被定义为一个元组 $\mathcal{G}=(\mathcal{N},\mathcal{E})$,其中 $\mathcal{N}$ 表示节点(顶点)集合,$\mathcal{E}\subseteq\mathcal{N}\times\mathcal{N}$ 是边集合,且 $N_{n}=|\mathcal{N}|$,$N_{e}=|\mathcal{E}|$。每个节点 $i\in\mathcal{N}$ 关联一个维度为 $d_{n}$ 的特征向量 $\mathbf{n}_{i}$,每条边 $e\in\mathcal{E}$ 关联一个 $d_{e}$ 维边特征 $\mathbf{e}_{ij}$。节点特征按行收集为矩阵 $\mathbf{N}\in\mathbb{R}^{N_{n}\times d_{n}}$,边特征同理收集为 $\mathbf{E}\in\mathbb{R}^{N_{e}\times d_{e}}$。图的连接信息可用邻接矩阵 $\mathbf{A}$ 表示,当 $(i,j)\in\mathcal{E}$ 时 $\mathbf{A}_{ij}=1$,否则 $\mathbf{A}_{ij}=0$。边列表(边索引)表示形式与之等价,但在实践中更为常见。
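为了让上述符号更直观,下面给出一个极简的 PyTorch 示意(非原文代码,节点数与特征维度均为随意假设),展示如何由边列表 edge_index 构造特征矩阵 $\mathbf{N}$、$\mathbf{E}$ 以及邻接矩阵 $\mathbf{A}$:

```python
import torch

# 玩具图:3 个节点,2 条有向边 (0->1, 1->2)
edge_index = torch.tensor([[0, 1],    # 源节点
                           [1, 2]])   # 目标节点
N = torch.randn(3, 4)                 # 节点特征矩阵,N_n = 3, d_n = 4
E = torch.randn(2, 3)                 # 边特征矩阵,N_e = 2, d_e = 3

# 由边列表构造稠密邻接矩阵 A
A = torch.zeros(3, 3, dtype=torch.bool)
A[edge_index[0], edge_index[1]] = True
```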

Algorithm 1: Edge masking in PyTorch Geometric (T is the transpose; helper functions are explained in SI 6).

算法 1: PyTorch Geometric中的边掩码实现(T表示转置;辅助函数说明见SI 6)

```
 1   from esa import consecutive, first_unique_index
 2   function edge_adjacency(batched_edge_index):
 3       N_e          ← batched_edge_index.size(1)
 4       source_nodes ← batched_edge_index[0]
 5       target_nodes ← batched_edge_index[1]
 6
 7       # unsqueeze and expand
 8       exp_src ← source_nodes.unsqueeze(1).expand((-1, N_e))
 9       exp_trg ← target_nodes.unsqueeze(1).expand((-1, N_e))
10
11       src_adj ← exp_src == T(exp_src)
12       trg_adj ← exp_trg == T(exp_trg)
13       cross   ← (exp_src == T(exp_trg)) logical_or
14                 (exp_trg == T(exp_src))
15
16       return src_adj + trg_adj + cross

17   function edge_mask(b_ei, b_map, B, L):
18       mask ← torch.full(size=(B, L, L), fill=False)
19       edge_to_graph ← b_map.index_select(0, b_ei[0, :])
20
21       edge_adj ← edge_adjacency(b_ei)
22       ei_to_original ← consecutive(
23           first_unique_index(edge_to_graph), b_ei.size(1))
24
25       edges ← edge_adj.nonzero()
26       graph_index ← edge_to_graph.index_select(0, edges[:, 0])
27       coord_1 ← ei_to_original.index_select(0, edges[:, 0])
28       coord_2 ← ei_to_original.index_select(0, edges[:, 1])
29
30       mask[graph_index, coord_1, coord_2] ← True
31       return ~mask
```

The main building block in ESA is masked scaled dot product attention. Panel A in Figure 3 illustrates this attention mechanism schematically. More formally, this attention mechanism is given by

ESA中的主要构建模块是掩码缩放点积注意力 (masked scaled dot product attention)。图 3 中的面板 A 以示意图形式展示了这种注意力机制。更正式地说,该注意力机制由以下公式给出:

$$
\mathrm{SDPA}(\mathbf{Q},\mathbf{K},\mathbf{V},\mathbf{M})=\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_{k}}}+\mathbf{M}\right)\mathbf{V}
$$

$$
\mathrm{SDPA}(\mathbf{Q},\mathbf{K},\mathbf{V},\mathbf{M})=\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_{k}}}+\mathbf{M}\right)\mathbf{V}
$$

where $\mathbf{Q}$ , $\mathbf{K}$ , and $\mathbf{V}$ are projections of the edge representations to queries, keys, and values, respectively. This is a minor extension of the standard scaled dot product attention via the additive mask M that can censor any of the pairwise attention scores, and $d_{k}$ is the key dimension. The generalisation to masked multihead attention follows the same steps as in the original transformer [22]. Below and in the specification of the overall architecture, we refer to this function as MultiHead $({\bf Q},{\bf K},{\bf V},{\bf M})$ .

其中 $\mathbf{Q}$、$\mathbf{K}$ 和 $\mathbf{V}$ 分别是边表示向查询(query)、键(key)和值(value)的投影。这是对标准缩放点积注意力机制的一个小扩展,通过可加性掩码M来屏蔽任意成对注意力分数,$d_{k}$ 是键维度。带掩码的多头注意力机制的泛化遵循原始Transformer [22] 的相同步骤。在下文及整体架构的规范中,我们将此函数称为MultiHead $({\bf Q},{\bf K},{\bf V},{\bf M})$。
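下面给出上式的一个极简 PyTorch 实现示意(非原文实现,形状与变量名均为说明性假设;掩码中用较大的负值如 -1e9 近似 $-\infty$,与图 3A 的描述一致):

```python
import math
import torch

def masked_sdpa(Q, K, V, M):
    # Q, K, V: (B, H, L, d_k),分别为查询、键、值投影;M: (B, 1, L, L) 加性掩码,
    # 0 表示允许注意,较大负值(如 -1e9)表示屏蔽该注意力分数
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) + M
    return torch.softmax(scores, dim=-1) @ V

# 示例:2 个图,每个填充到 5 条边,4 个注意力头,d_k = 8
B, H, L, d_k = 2, 4, 5, 8
Q, K, V = (torch.randn(B, H, L, d_k) for _ in range(3))
allowed = torch.rand(B, 1, L, L) > 0.5            # 布尔边邻接掩码(示意)
M = torch.zeros(B, 1, L, L).masked_fill(~allowed, -1e9)
out = masked_sdpa(Q, K, V, M)                     # (B, H, L, d_k)
```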

Masked self-attention for graphs can be seen as graph-structured attention, and an instance of a generalized attention mechanism with custom attention patterns [28]. More specifically, in ESA the attention pattern is given by an edge adjacency matrix rather than allowing for interactions between all set items. Crucially, the edge adjacency matrix can be efficiently computed both for a single graph and for batched graphs using exclusively tensor operations. The case for a single graph is covered in the left side of Algorithm 1 through the edge adjacency function. The first 3 lines of the function correspond to getting the number of edges, then separating the source and target nodes from the edge adjacency list (also called edge index), which is equivalent to the standard adjacency matrix of a graph. The source and target node tensors each have a dimension equal to the number of edges ( $N_{e}$ ). Lines 7-9 add an additional dimension and efficiently repeat (‘expand’) the existing tensors to shape $(N_{e},N_{e})$ without allocating new memory. Using the transpose, line 11 checks if the source nodes of any two edges are the same, and the same for target nodes on line 12. On line 13, cross connections are checked, where the source node of an edge is the target node of another, and vice versa. The operations of lines 11-14 result in boolean matrices of shape $(N_{e},N_{e})$ which are summed for the final returned edge adjacency matrix. The right panel of Algorithm 1 depicts the case where the input graph represents a batch of smaller graphs. This requires an additional batch mapping tensor that maps each node in the batched graph to its original graph, and carefully manipulating the indices to create the final edge adjacency mask of shape $(B,L,L)$ , where $L$ is the maximum number of edges in the batch.

图的掩码自注意力 (masked self-attention) 可视为图结构注意力,是采用自定义注意力模式的广义注意力机制实例 [28]。具体而言,在ESA中,注意力模式由边邻接矩阵而非所有集合项间的交互决定。关键在于,该边邻接矩阵可通过纯张量运算高效计算,适用于单图和批量图处理。算法1左侧展示了单图情况下的边邻接函数实现:前3行获取边数量,从边邻接列表(亦称边索引,等价于标准图邻接矩阵)分离源节点与目标节点张量,二者维度均等于边数 ($N_{e}$)。第7-9行通过扩展操作(不分配新内存)为张量新增维度并重塑为 $(N_{e},N_{e})$ 形状。第11行利用转置检查任意两条边的源节点是否相同,第12行同理处理目标节点。第13行验证交叉连接(某边的源节点为另一边的目标节点,反之亦然)。第11-14行操作生成 $(N_{e},N_{e})$ 布尔矩阵,经求和得到最终边邻接矩阵。算法1右侧描述批量图输入场景:需额外批次映射张量将节点映射至原图,并通过索引操作生成形状为 $(B,L,L)$ 的边邻接掩码($L$ 为批次中最大边数)。
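作为补充,下面用纯 PyTorch 重写算法 1 左侧的单图情形(示意代码,变量名沿用算法 1;按正文描述,四个布尔矩阵的逐元素逻辑或与原文所说的求和等价):

```python
import torch

def edge_adjacency(edge_index: torch.Tensor) -> torch.Tensor:
    # edge_index: (2, N_e),与 PyTorch Geometric 的边索引格式一致
    # 返回 (N_e, N_e) 布尔矩阵,两条边共享节点时为 True
    n_e = edge_index.size(1)
    src, trg = edge_index[0], edge_index[1]

    # unsqueeze + expand:把向量复制为 (N_e, N_e) 矩阵,
    # 使成对比较变成矩阵与其转置的逐元素比较
    exp_src = src.unsqueeze(1).expand(-1, n_e)
    exp_trg = trg.unsqueeze(1).expand(-1, n_e)

    src_adj = exp_src == exp_src.T                            # 两条边源节点相同
    trg_adj = exp_trg == exp_trg.T                            # 两条边目标节点相同
    cross = (exp_src == exp_trg.T) | (exp_trg == exp_src.T)   # 一条边的源是另一条边的目标
    return src_adj | trg_adj | cross

# 示例:路径图 0-1-2,两条边 (0,1) 和 (1,2) 共享节点 1,因此互相邻接
edge_index = torch.tensor([[0, 1], [1, 2]])
print(edge_adjacency(edge_index))
```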

Since in ESA attention is computed over edges, we chose to separate source and target node features for each edge, similarly to lines 4-5 of Algorithm 1, and concatenate them to the edge features:

由于在ESA中注意力是在边上计算的,我们选择将每条边的源节点和目标节点特征分离,类似于算法1的第4-5行,并将它们与边特征拼接起来:

$$
\mathbf{x}_{ij}=\mathbf{n}_{i}\parallel\mathbf{n}_{j}\parallel\mathbf{e}_{ij}
$$

$$
\mathbf{x}_{ij}=\mathbf{n}_{i}\parallel\mathbf{n}_{j}\parallel\mathbf{e}_{ij}
$$

for each edge $e_{ij}$. The resulting features $\mathbf{x}_{ij}$ are collected in a matrix $\mathbf{X}\in\mathbb{R}^{N_{e}\times(2d_{n}+d_{e})}$.

对于每条边 $e_{ij}$,最终得到的特征 $\mathbf{x}_{ij}$ 被收集到矩阵 $\mathbf{X}\in\mathbb{R}^{N_{e}\times(2d_{n}+d_{e})}$ 中。
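对应上式,下面是一个构造边级输入 $\mathbf{X}$ 的简单示意(非原文代码,特征维度为随意假设):

```python
import torch

def build_edge_tokens(node_feats, edge_feats, edge_index):
    # node_feats: (N_n, d_n), edge_feats: (N_e, d_e), edge_index: (2, N_e)
    # 每条边的表示 = 源节点特征 || 目标节点特征 || 边特征,得到 (N_e, 2*d_n + d_e)
    src, trg = edge_index[0], edge_index[1]
    return torch.cat([node_feats[src], node_feats[trg], edge_feats], dim=-1)

# 示例:沿用前面 3 节点、2 条边的玩具图
node_feats = torch.randn(3, 4)                  # d_n = 4
edge_feats = torch.randn(2, 3)                  # d_e = 3
edge_index = torch.tensor([[0, 1], [1, 2]])
X = build_edge_tokens(node_feats, edge_feats, edge_index)   # 形状 (2, 11)
```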

Having defined the mask generation process and the masked multihead attention function, we next define the modular blocks of ESA, starting with the Masked Self Attention Block (MAB):

在定义了掩码生成过程和掩码多头注意力函数后,我们接下来定义ESA的模块化构建块,从掩码自注意力块 (MAB) 开始:

$$
\begin{aligned}
\mathrm{MAB}(\mathbf{X},\mathbf{M}) &= \mathbf{H}+\mathrm{MLP}(\mathrm{LayerNorm}(\mathbf{H}))\\
\mathbf{H} &= \overline{\mathbf{X}}+\mathrm{MultiHead}(\overline{\mathbf{X}},\overline{\mathbf{X}},\overline{\mathbf{X}},\mathbf{M})\\
\overline{\mathbf{X}} &= \mathrm{LayerNorm}(\mathbf{X})
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{MAB}(\mathbf{X},\mathbf{M}) &= \mathbf{H}+\mathrm{MLP}(\mathrm{LayerNorm}(\mathbf{H}))\\
\mathbf{H} &= \overline{\mathbf{X}}+\mathrm{MultiHead}(\overline{\mathbf{X}},\overline{\mathbf{X}},\overline{\mathbf{X}},\mathbf{M})\\
\overline{\mathbf{X}} &= \mathrm{LayerNorm}(\mathbf{X})
\end{aligned}
$$

where the mask M is computed as in Algorithm 1 and has shape $B\times L\times L$ , with $B$ the batch size, and MLP is a multi-layer perceptron. A Self Attention Block (SAB) can be formally defined as:

其中掩码M按算法1计算得出,形状为$B\times L\times L$(B为批次大小),MLP为多层感知机。自注意力块(SAB)可形式化定义为:

$$
\mathrm{SAB}(\mathbf{X})=\mathrm{MAB}(\mathbf{X},\mathbf{0})
$$

$$
\mathrm{SAB}(\mathbf{X})=\mathrm{MAB}(\mathbf{X},\mathbf{0})
$$
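下面按上述 MAB/SAB 的定义给出一个简化的 PyTorch 示意(非原文实现;头数、MLP 宽度与 GELU 激活均为假设,掩码以加性浮点形式传入 F.scaled_dot_product_attention):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MAB(nn.Module):
    """Pre-LN 掩码(自)注意力块;当 mask=None 时即为 SAB(M = 0)。"""

    def __init__(self, dim, num_heads=8, mlp_ratio=2):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim))

    def forward(self, x, mask=None):
        # x: (B, L, dim);mask: (B, L, L) 加性浮点掩码或 None
        b, l, d = x.shape
        xn = self.norm1(x)                                          # X̄ = LayerNorm(X)
        q, k, v = self.qkv(xn).chunk(3, dim=-1)
        q, k, v = (t.view(b, l, self.num_heads, -1).transpose(1, 2) for t in (q, k, v))
        attn_mask = None if mask is None else mask.unsqueeze(1)     # 在头维上广播
        a = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        a = self.proj(a.transpose(1, 2).reshape(b, l, d))
        h = xn + a                                                  # H = X̄ + MultiHead(X̄, X̄, X̄, M)
        return h + self.mlp(self.norm2(h))                          # MAB(X, M)
```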

In principle, the masks can be arbitrary, and a promising avenue of research could be in designing new mask types.

原则上,掩码可以是任意的,设计新型掩码可能是一个有前景的研究方向。

Algorithm 2: Node masking in the PyTorch Geometric framework.

算法 2: PyTorch Geometric框架中的节点掩码

```python
import torch
from torch_geometric.utils import unbatch_edge_index

def node_mask(batched_edge_index, batch_map, B, M):
    # batched_edge_index 将所有的 edge_index 张量批量处理为单个张量
    # B, M 分别是批次大小和掩码大小
    mask = torch.full((B, M, M), False)
    graph_idx = batch_map.index_select(0, batched_edge_index[0, :])
    edge_index_list = unbatch_edge_index(batched_edge_index, batch_map)
    edge_index = torch.cat(edge_index_list, dim=1)
    mask[graph_idx, edge_index[0, :], edge_index[1, :]] = True
    return ~mask
```

For certain tasks such as node classification, we also defined an alternative version of ESA for nodes (NSA). The only major difference is the mask generation step. The single graph case is trivial as all the masking information is available in the edge index, so we provide the general batched case in Algorithm 2.

对于节点分类等特定任务,我们还定义了节点版ESA的替代方案(NSA)。主要区别仅在于掩码生成步骤。单图场景较为简单,因为所有掩码信息都存在于边索引中,因此我们在算法2中提供了通用的批处理案例。

3.2 ESA Architecture

3.2 ESA架构

Our architecture consists of two components: i) an encoder that interleaves masked and self-attention blocks to learn an effective representation of edges, and ii) a pooling block based on multi-head attention, inspired by the decoder component from the set transformer architecture [39]. The latter comes naturally when one considers graphs as sets of edges (or nodes), as done here. ESA is a purely attention-based architecture that leverages the scaled dot product mechanism proposed by Vaswani et al. [22] through masked- and standard self-attention blocks (MABs and SABs) and a pooling by multi-head attention (PMA) block.

我们的架构由两个组件组成:i) 一个通过交错掩码注意力块和自注意力块来学习边有效表示的编码器,以及 ii) 一个基于多头注意力的池化块,其灵感来自集合Transformer架构 [39] 的解码器组件。当我们将图视为边(或节点)的集合时,后者自然适用。ESA是一种纯基于注意力的架构,它通过掩码注意力块(MAB)、标准自注意力块(SAB)以及多头注意力池化块(PMA),利用了Vaswani等人 [22] 提出的缩放点积机制。

The encoder consisting of arbitrarily interleaved MAB and SAB blocks is given by:

由任意交错排列的 MAB 和 SAB 块组成的编码器表示为:

$$
\mathrm{Encoder}(\mathbf{X},\mathbf{M})=\mathrm{AB}_{1}\circ\cdots\circ\mathrm{AB}_{L}(\mathbf{X},\mathbf{M}),\qquad \mathrm{AB}_{i}\in\{\mathrm{MAB},\mathrm{SAB}\}
$$

$$
\mathrm{Encoder}(\mathbf{X},\mathbf{M})=\mathrm{AB}_{1}\circ\cdots\circ\mathrm{AB}_{L}(\mathbf{X},\mathbf{M}),\qquad \mathrm{AB}_{i}\in\{\mathrm{MAB},\mathrm{SAB}\}
$$

Here, ‘AB’ refers to an attention block that can be instantiated as an MAB or SAB.

此处,“AB”指代一个注意力块 (attention block),可实例化为MAB或SAB。
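交错结构可以写成如下简化示意(复用前文的 MAB 草图;布局字符串 'MSMS' 等为说明性假设,'M' 表示掩码块、'S' 表示自注意力块):

```python
import torch.nn as nn

class Encoder(nn.Module):
    """按给定布局纵向交错 MAB('M')与 SAB('S'),例如 layout='MSMS'。"""

    def __init__(self, dim, layout="MSMS", num_heads=8):
        super().__init__()
        self.layout = layout
        self.blocks = nn.ModuleList(MAB(dim, num_heads) for _ in layout)

    def forward(self, x, mask):
        for kind, block in zip(self.layout, self.blocks):
            x = block(x, mask if kind == "M" else None)   # SAB 即 M = 0 的 MAB
        return x
```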

The pooling module that is responsible for aggregating the processed edge representations into a graph-level representation is formally defined by:

负责将处理后的边表征聚合为图级表征的池化模块正式定义为:

$$
\begin{aligned}
\mathrm{PMA}_{k,p}(\mathbf{Z}) &= \mathrm{SAB}^{p}\left(\overline{\mathbf{S}}+\mathrm{MLP}(\overline{\mathbf{S}})\right)\\
\overline{\mathbf{S}} &= \mathrm{LayerNorm}\left(\mathrm{MultiHead}\left(\mathbf{S}_{k},\mathbf{Z},\mathbf{Z},\mathbf{0}\right)\right)
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{PMA}_{k,p}(\mathbf{Z}) &= \mathrm{SAB}^{p}\left(\overline{\mathbf{S}}+\mathrm{MLP}(\overline{\mathbf{S}})\right)\\
\overline{\mathbf{S}} &= \mathrm{LayerNorm}\left(\mathrm{MultiHead}\left(\mathbf{S}_{k},\mathbf{Z},\mathbf{Z},\mathbf{0}\right)\right)
\end{aligned}
$$

where $\mathbf{S}_{k}$ is a tensor of $k$ learnable seed vectors that are randomly initialised and $\mathrm{SAB}^{p}(\cdot)$ is the application of $p$ SABs. Technically, it suffices to set $k=1$ to output a single representation for the entire graph. However, we have empirically found it beneficial to set it to a small value, such as $k=32$. Moreover, this change allows self attention (SABs) to further process the $k$ resulting representations, which can be simply summed or averaged due to the small $k$. Contrary to classical readouts that aggregate directly over set items (i.e., nodes), pooling by multi-head attention performs the final aggregation over the embeddings of learnt seed vectors $\mathbf{S}_{k}$. While tasks involving node-level predictions require only the Encoder component, the predictive tasks involving graph-level representations require all the modules, both the encoder and pooling by multi-head attention. The architecture in the latter setting is formally given by

其中 $\mathbf{S}_{k}$ 是一个包含 $k$ 个可学习种子向量的张量,这些向量被随机初始化,而 $\mathrm{SAB}^{p}(\cdot)$ 表示应用 $p$ 个 SAB。技术上,设置 $k=1$ 足以输出整个图的单一表示。然而,我们通过实验发现将其设置为较小值(如 $k=32$)更为有利。此外,这一调整使得自注意力机制(SABs)能够进一步处理得到的 $k$ 个表示,由于 $k$ 值较小,这些表示可以直接求和或取平均。与直接在集合项(即节点)上进行聚合的经典读出方法不同,多头注意力池化是在学习到的种子向量 $\mathbf{S}_{k}$ 的嵌入表示上执行最终聚合。涉及节点级预测的任务仅需编码器组件,而涉及图级表示的任务则需要所有模块,包括编码器和多头注意力池化。后一种场景下的架构形式化定义为

$$
\mathbf{Z}_{\mathrm{out}}=\mathrm{PMA}_{k,p}(\mathrm{Encoder}(\mathbf{X},\mathbf{M})+\mathbf{X})
$$

$$
\mathbf{Z}_{\mathrm{out}}=\mathrm{PMA}_{k,p}(\mathrm{Encoder}(\mathbf{X},\mathbf{M})+\mathbf{X})
$$
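下面给出 PMA 以及整体 ESA 组合的一个简化示意(复用前文的 MAB/Encoder 草图;种子数 k、SAB 个数 p、用 nn.MultiheadAttention 实现交叉注意力、最后对 k 个种子取均值,均为示意性选择):

```python
import torch
import torch.nn as nn

class PMA(nn.Module):
    """多头注意力池化:k 个可学习种子对边表示做交叉注意力,再经 p 个 SAB。"""

    def __init__(self, dim, k=32, p=2, num_heads=8):
        super().__init__()
        self.seeds = nn.Parameter(torch.randn(k, dim))        # S_k,随机初始化
        self.cross = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 2 * dim), nn.GELU(), nn.Linear(2 * dim, dim))
        self.sabs = nn.ModuleList(MAB(dim, num_heads) for _ in range(p))

    def forward(self, z):
        # z: (B, L, dim),编码器输出的边表示
        s = self.seeds.unsqueeze(0).expand(z.size(0), -1, -1)  # (B, k, dim)
        s_bar, _ = self.cross(s, z, z)                         # MultiHead(S_k, Z, Z, 0)
        s_bar = self.norm(s_bar)
        out = s_bar + self.mlp(s_bar)
        for sab in self.sabs:                                  # SAB^p
            out = sab(out)
        return out.mean(dim=1)                                 # 对 k 个种子嵌入求均值得到图级表示

class ESA(nn.Module):
    def __init__(self, dim, layout="MSMS", num_heads=8, k=32, p=2):
        super().__init__()
        self.encoder = Encoder(dim, layout, num_heads)
        self.pma = PMA(dim, k, p, num_heads)

    def forward(self, x, mask):
        # Z_out = PMA_{k,p}(Encoder(X, M) + X)
        return self.pma(self.encoder(x, mask) + x)
```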

As optimal configurations are task specific, we do not explicitly fix all the architectural details. For example, it is possible to select between layer and batch normalisation, a pre-LN or post-LN architecture [40], or standard and gated MLPs, along with GLU variants, e.g., SwiGLU [41].

由于最优配置与任务相关,我们并未明确固定所有架构细节。例如,可以在层归一化与批量归一化之间选择,采用前置层归一化 (pre-LN) 或后置层归一化 (post-LN) 架构 [40],或选择标准MLP与门控MLP(如SwiGLU [41]等GLU变体)。

3.3 Time and Memory Scaling

3.3 时间和内存扩展

ESA is enabled by recent advances in efficient and exact attention, such as memory-efficient attention and Flash attention [42–45]. Flash attention does not yet support masking, but memory-efficient implementations with arbitrary masking capabilities exist in both PyTorch and xFormers [46, 47]. Theoretically, the memory complexity of memory-efficient approaches is $O(\sqrt{n})$, where $n$ is the sequence/set size, with the same time complexity as standard attention. Since all operations in ESA except the cross-attention in PMA are based on self-attention, our method benefits directly from all of these advances. Moreover, our cross-attention is performed between the full set of size $n$ and a small set of $k\leq32$ learnable seeds, such that the complexity of the operation is $O(kn)$ even for standard attention. Flash attention, which has linear memory complexity even for self-attention and is up to 25 times faster than standard attention, supports cross-attention and we use it for further uplifts. However, these implementations are not optimised for graph learning and we have found that the biggest bottleneck is not the attention computation, but rather storing the dense edge adjacency mask, which requires memory quadratic in the number of edges. An evident, but not yet available, optimisation could amount to storing the masks in a sparse tensor format. Moreover, even a single boolean requires an entire byte of storage in PyTorch, increasing the theoretical memory usage by 8 times. Further limitations and possible optimisations are discussed in SI 5. Despite some of these limitations, we have successfully trained ESA models for graphs with up to approximately 30,000 edges (e.g., dd in Table 6).

ESA得益于近年来高效精确注意力机制的进步,如内存高效注意力(memory-efficient attention)和Flash attention[42–45]。虽然Flash attention暂不支持掩码功能,但PyTorch和xFormers[46,47]已实现支持任意掩码操作的内存高效版本。理论上,内存高效方法的空间复杂度为$O(\sqrt{n})$(其中$n$表示序列/集合大小),时间复杂度则与标准注意力机制相同。由于ESA中除PMA模块的交叉注意力外均基于自注意力机制,我们的方法能直接受益于这些技术进步。此外,我们的交叉注意力操作在规模为$n$的完整集合与少量可学习种子($k\leq32$)之间进行,因此即便使用标准注意力,其复杂度也仅为$O(kn)$。具有线性内存复杂度的Flash attention(其自注意力计算速度可比标准注意力快达25倍)支持交叉注意力,我们借此实现性能提升。然而这些实现未针对图学习优化,我们发现最大瓶颈并非注意力计算,而是存储稠密边邻接掩码需要与边数量平方成正比的内存。一个明显但尚未实现的优化方案是采用稀疏张量格式存储掩码。此外,PyTorch中单个布尔值需占用整字节存储,导致理论内存使用量增加8倍。补充材料SI 5讨论了更多限制与优化可能。尽管存在这些限制,我们已成功训练出处理约30,000条边规模图数据(如表6中dd数据集)的ESA模型。
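作为说明,下面演示如何用 PyTorch 内置的 F.scaled_dot_product_attention 在带布尔边邻接掩码时计算注意力(示意代码;形状为随意假设,布尔掩码中 True 表示允许注意。是否命中内存高效内核取决于 PyTorch 版本与硬件,带掩码时通常无法走 Flash 内核,与正文讨论一致):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, L, d = 4, 8, 1024, 64

# 稠密布尔邻接掩码 (B, 1, L, L),在头维上广播;
# 每个布尔值占 1 字节,这正是正文提到的主要内存瓶颈
mask = torch.rand(B, 1, L, L, device=device) > 0.5

q = torch.randn(B, H, L, d, device=device)
k = torch.randn(B, H, L, d, device=device)
v = torch.randn(B, H, L, d, device=device)

# PyTorch 会根据输入自动选择可用的注意力内核(数学实现 / 内存高效实现)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)   # (B, H, L, d)
```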

Table 1: The table reports root mean squared error on qm9 (RMSE is the standard metric for quantum mechanics) and $\mathrm{R^{2}}$ for dockstring (dock) and MoleculeNet (mn), presented as mean $\pm$ standard deviation over 5 runs. The mean absolute error (MAE) is reported for pcqm4mv2 (MAE is the standard metric for this task) over a single run due to the size of the dataset. The lowest MAEs and RMSEs, and the highest $\mathrm{R^{2}}$ values are highlighted in bold. oom denotes out-of-memory errors. A complete table, including GCN and GIN, is provided in Supplementary Table 1.

表 1: 该表报告了qm9的均方根误差(RMSE是量子力学的标准指标)以及dockstring(dock)和Molecule Net(mn)的$\mathrm{{R^{2}}}$值,以5次运行的平均值$\pm$标准差呈现。由于数据集规模,pcqm4mv2报告了单次运行的平均绝对误差(MAE是该任务的标准指标)。最低的MAE和RMSE值以及最高的$\mathrm{R^{2}}$值以粗体标出。oom表示内存不足错误。完整表格(包括GCN和GIN)见补充表1。

目标 DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
μ ELUMO △E (R²) EHOMO 0.55 ± 0.01 0.44 ± 0.03 0.55 ± 0.01 0.48 ± 0.02 0.55 ± 0.01 0.46 ± 0.02 0.53 ± 0.01 0.48 ± 0.07 0.40 ± 0.63 ± 0.01 0.02 0.76 ± 0.02 0.45 ± 0.01 0.94 ± 0.17 0.60 ± 0.13 0.56 ± 0.00 0.40 ± 0.00
0.11 ± 0.00 0.11 ± 0.00 0.10 ± 0.00 0.10 ± 0.00 0.11 ± 0.00 0.12 ± 0.00 0.11 ± 0.00 0.10 ± 0.00
0.12 ± 0.00 0.12 ± 0.00 0.13 ± 0.00 0.11 ± 0.00 0.11 ± 0.00 0.13 ± 0.00 0.11 ± 0.00 0.11 ± 0.00
0.17 ± 0.00 0.16 ± 0.01 0.16 ± 0.00 0.14 ± 0.00 0.16 ± 0.00 0.18 ± 0.01 0.15 ± 0.01 0.15 ± 0.00
29.87 ± 0.34 30.91 ± 0.15 30.15 ± 0.36 28.50 ± 0.44 29.63 ± 0.35 31.54 ± 0.42 30.42 ± 1.19 28.33 ± 0.32
ZPVE 0.03 ± 0.00 0.03 ± 0.00 0.03 ± 0.00 0.03± 0.00 0.06 ± 0.03 0.03 ± 0.00 0.03 ± 0.01 0.03 ± 0.00
29.82 ± 6.32 27.82 ± 3.98 25.40 ± 3.62 31.64 ± 6.49 24.60 ± 4.70 15.48 ± 4.67 14.50 ± 4.34 4.78 ± 0.67
Uo 24.74 ± 6.31 30.91 ± 2.62 24.92 ± 4.18 25.04 ± :5.48 15.55 ± 8.23 12.45 ± 1.86 15.82 ± 4.26 6.80 ± 2.81
U H 34.02 ± 5.66 28.92 ± 4.11 23.42 ± 2.77 27.34 士 8.17 19.01 ± 4.65 11.42 ± 3.72 12.92 ± 4.29 6.02 ± 1.32
G 25.41 ± 4.86 31.49 ± 1.86 23.24 ± 2.54 22.31 士 3.45 31.20 ± 4.61 29.72 ± 8.79 13.08 ± 4.81 6.10 ± 1.33
CV 0.19 ± 0.01 0.19 ± 0.01 0.19 ± 0.00 0.18± 0.01 0.18 ± 0.02 0.18 ± 0.00 0.19 ± 0.03 0.16 ± 0.00
0.38 ± 0.02 0.39 ± 0.03 0.38 ± 0.02 0.35 ± 0.02 0.29± 0.00 0.30 ± 0.02 0.33 ± 0.05 0.24 ± 0.01
UATOM 0.40 ± 0.04 0.48 ± 0.07 0.40 ± 0.04 0.36± 0.02 0.30± 0.01 0.30 ± 0.00 0.33 ± 0.06 0.24 ± 0.01
HATOM GATOM 0.38 ± 0.04 0.38 ± 0.02 0.41 ± 0.06 0.37 士 0.03 0.30± 0.01 0.31 ± 0.01 0.36 ± 0.09 0.25 ± 0.00
A 0.37 ± 0.04 0.34 ± 0.02 0.33 ± 0.01 0.31 1.01 士 0.02 0.26 ± 0.01 0.27 ± 0.01 0.36 ± 0.08 0.22 ± 0.02
B 0.90 ± 0.06 0.21 ± 0.04 0.97 ± 0.12 0.21 ± 0.04 1.08 ± 0.16 0.26 ± 0.01 0.26 士 0.10 ±0.02 0.10 ± 64.88 ± 29.77 :0.03 3.82 ± 2.05 0.11 ± 0.01 1.42 ± 0.44 0.16 ± 0.05 0.75 ± 0.11
C 0.19 ± 0.05 0.27 ± 0.02 0.27 ± 0.01 0.28 ± 0.00 0.28 ± 0.01 0.12 ± :0.04 0.10 ± 0.02 0.12 ± 0.05 0.08 ± 0.01 0.05 ± 0.01
ESR2 F2 0.68 ± 0.00 0.67 ± 0.00 0.65 ± 0.00 0.70 ± 0.00 OOM 0.64 ± 0.01 0.68 ± 0.00 0.70 ± 0.00
KIT 0.89 ± 0.00 0.89 ± 0.00 0.89 ± 0.00 0.89 ± 0.00 0OM 0.87 ± 0.01 0.88 ± 0.00 0.89 ± 0.00
DOCK 0.83 ± 0.00 0.83 ± 0.00 0.83 ± 0.00 0.84 ± 0.00 OOM 0.80 ± 0.01 0.83 ± 0.00 0.84 ± 0.00
PARP1 0.92 ± 0.00 0.92 ± 0.00 0.92 ± 0.00 0.92 ± 0.00 OOM 0.91 ± 0.00 0.92 ± 0.01 0.93 ± 0.00
PGR 0.70 ± 0.00 0.68 ± 0.01 0.67 ± 0.01 0.72 ± 0.00 OOM 0.68 ± 0.01 0.70 ± 0.01 0.73 ± 0.00
FREESOLV 0.97 ± 0.00 0.96 ± 0.01 0.97 ± 0.01 0.95 ± 0.01 0.93 ± :0.00 0.93 ± 0.02 0.86 ± 0.03 0.98 ± 0.00
LIPO 0.81 ± 0.01 0.82 ± 0.01 0.82 ± 0.01 0.83 ± 0.01 0.61 ± :0.04 0.55 ± 0.02 0.79 ± 0.00 0.81 ± 0.01
NN ESOL 0.94 ± 0.01 0.93 ± 0.01 0.93 ± 0.00 0.94 ± 0.01 0.91 ± 0.02 0.89 ± 0.03 0.91 ± 0.00 0.94 ± 0.00
PCQM4MV2(↓) N/A N /A N /A N/A N /A N/ A N /A 0.0235

Table 2: The table reports mean absolute error (MAE) on the zinc dataset, presented as mean $\pm$ standard deviation over 5 runs. The best/lowest value is highlighted in bold. A complete table, including GCN, GIN, and DropGIN is provided in Supplementary Table 2.

表 2: 该表报告了锌数据集上的平均绝对误差 (MAE),以5次运行的平均值 $\pm$ 标准差表示。最佳/最低值以粗体标出。完整表格(包括GCN、GIN和DropGIN)见补充表2。

| 数据集 (↓) | GAT | GATv2 | PNA | Graphormer | TokenGT | GPS | ESA | ESA (PE) |
|---|---|---|---|---|---|---|---|---|
| ZINC | 0.078 ± 0.01 | 0.079 ± 0.00 | 0.057 ± 0.01 | 0.036 ± 0.00 | 0.047 ± 0.01 | 0.024 ± 0.01 | 0.027 ± 0.00 | 0.017 ± 0.00 |

4 Results

4 结果

We perform a comprehensive evaluation of ESA on 70 different tasks, including domains such as molecular property prediction, vision graphs, and social networks, as well as different aspects of representation learning on graphs, ranging from node-level tasks with homophily and heterophily graph types to modelling long-range dependencies, shortest paths, and 3D atomic systems. We quantify the performance of our approach relative to 6 GNN baselines: GCN, GAT, GATv2, PNA, GIN, and DropGIN (more expressive than 1-WL), and 3 graph transformer baselines: Graphormer, TokenGT, and GraphGPS. All the details on hyper-parameter tuning, rationale for the selected metrics, and selection of baselines can be found in SI 7.1 to SI 7.3. In the remainder of the section, we summarize our findings across molecular learning, mixed graph-level tasks, node-level tasks, and ablations on the interleaving operator along with insights on time and memory scaling.

我们对ESA在70种不同任务上进行了全面评估,涵盖分子属性预测、视觉图和社会网络等领域,以及图表示学习的多个方面,包括同配性和异配性图类型的节点级任务,以及长程依赖建模、最短路径和3D原子系统等任务。我们将该方法的性能与6种GNN基线(GCN、GAT、GATv2、PNA、GIN,以及比1-WL更具表达能力的DropGIN)和3种图Transformer基线(Graphormer、TokenGT和GraphGPS)进行了量化比较。关于超参数调优、所选指标的合理性及基线选择的详细信息,请参见SI 7.1至SI 7.3。在本节剩余部分,我们将总结在分子学习、混合图级任务、节点级任务以及交错算子消融实验中的发现,并探讨时间和内存扩展的见解。

4.1 Molecular Learning

4.1 分子学习

As learning on molecules has emerged as one of the most successful applications of graph learning, we present an in-depth evaluation including quantum mechanics, molecular docking, and various physical chemistry and biophysics benchmarks, as well as an exploration of learning on 3D atomic systems, transfer learning, and learning on large molecules of therapeutic relevance (peptides).

随着分子学习成为图学习最成功的应用之一,我们开展了深度评估,涵盖量子力学、分子对接、多种物理化学与生物物理学基准测试,并探索了3D原子系统学习、迁移学习以及对具有治疗相关性的生物大分子(如多肽)的学习。

QM9. We report results for all 19 qm9 [48] targets in Table 1, with GCN and GIN separately in Supplementary Table 1 due to space restrictions. We observe that on 15 out of 19 properties, ESA is the best performing model. The exceptions are the frontier orbital energies (HOMO and LUMO energy, HOMO-LUMO gap) and the dipole moment ( $\mu$ ), where PNA is slightly ahead of it. Other graph transformers are competitive relative to GNNs on many properties but vary in performance across tasks.

QM9. 我们在表1中报告了所有19个qm9 [48]目标的结果,由于篇幅限制,GCN和GIN的结果单独列在补充表1中。我们观察到,在19个性质中的15个上,ESA是表现最好的模型。例外情况是前沿轨道能量(HOMO和LUMO能量,HOMO-LUMO间隙)和偶极矩 ($\mu$),其中PNA略微领先。其他图Transformer (Transformer) 在许多性质上相对于GNN具有竞争力,但在不同任务中表现各异。

DOCKSTRING. dockstring [49] is a recent drug discovery data collection consisting of molecular docking scores for 260,155 small molecules and 5 high-quality targets from different protein families that were selected as a regression benchmark, with different levels of difficulty: parp1 (enzyme, easy), f2 (protease, easy to medium), kit (kinase, medium), esr2 (nuclear receptor, hard), and pgr (nuclear receptor, hard). We report results for the 5 targets in Table 1 (and Supplementary Table 1) and observe that ESA is the best performing method on 4 out of 5 tasks, with PNA slightly ahead on the medium-difficulty kit. TokenGT and GraphGPS do not generally match ESA or even PNA. Molecular docking scores also depend heavily on the 3D geometry, as discussed in the original paper [49], posing a difficult challenge for all methods. Interestingly, not only ESA but all tuned GNN baselines outperform the strongest method in the original manuscript (Attentive FP, a GNN based on attention [50]) despite using 20,000 fewer training molecules (which we leave out as a validation set). This illustrates the importance of evaluation relative to baselines with tuned hyper-parameters.

DOCKSTRING。dockstring [49] 是近期推出的药物发现数据集,包含260,155个小分子与5个来自不同蛋白质家族的高质量靶点的分子对接(docking)评分,这些靶点被选作回归基准并具有不同难度等级:parp1(酶类,简单)、f2(蛋白酶类,简单至中等)、kit(激酶类,中等)、esr2(核受体类,困难)和pgr(核受体类,困难)。我们在表1(及附表1)中报告了5个靶点的结果,观察到ESA在5项任务中有4项表现最佳,仅在中等难度的kit靶点上略逊于PNA。TokenGT和GraphGPS普遍未能达到ESA甚至PNA的水平。如原论文[49]所述,分子对接评分高度依赖3D几何结构,这对所有方法都构成了严峻挑战。值得注意的是,不仅ESA,所有经过调优的GNN基线模型都超越了原论文中最强方法(基于注意力机制的GNN模型Attentive FP [50]),尽管我们少用了20,000个训练分子(留作验证集)。这说明了相对于经过超参数调优的基线进行评估的重要性。

Table 3: The table reports Matthews correlation coefficient (MCC) for graph-level molecular classification tasks from MoleculeNet and the National Cancer Institute (nci), presented as mean $\pm$ standard deviation over 5 different runs. oom denotes out-of-memory errors. The highest mean values are highlighted in bold. A complete table, including GCN and GIN, is provided in Supplementary Table 4 with MCC as metric, and in Supplementary Table 5 with accuracy.

表 3: 该表报告了来自Molecule Net和美国国家癌症研究所(nci)的图级分子分类任务的马修斯相关系数(MCC),以5次不同运行的平均值±标准差呈现。oom表示内存不足错误。最高平均值以粗体标出。完整表格(包括GCN和GIN)见补充表4(MCC指标)和补充表5(准确率指标)。

| | Data (↑) | DropGIN | GAT | GATv2 | PNA | Graphormer | TokenGT | GPS | ESA |
|---|---|---|---|---|---|---|---|---|---|
| MN | BBBP | 0.68 ± 0.02 | 0.74 ± 0.01 | 0.73 ± 0.03 | 0.73 ± 0.03 | 0.55 ± 0.01 | 0.58 ± 0.07 | 0.70 ± 0.04 | 0.84 ± 0.01 |
| MN | BACE | 0.65 ± 0.03 | 0.63 ± 0.02 | 0.64 ± 0.03 | 0.64 ± 0.02 | 0.52 ± 0.02 | 0.58 ± 0.03 | 0.62 ± 0.03 | 0.72 ± 0.02 |
| MN | HIV | 0.46 ± 0.03 | 0.42 ± 0.06 | 0.34 ± 0.06 | 0.42 ± 0.04 | OOM | 0.46 ± 0.02 | 0.25 ± 0.21 | 0.53 ± 0.01 |
| NCI | NCI1 | 0.69 ± 0.03 | 0.70 ± 0.02 | 0.65 ± 0.03 | 0.70 ± 0.03 | 0.54 ± 0.02 | 0.53 ± 0.03 | 0.70 ± 0.03 | 0.75 ± 0.01 |
| NCI | NCI109 | 0.68 ± 0.02 | 0.66 ± 0.01 | 0.66 ± 0.01 | 0.67 ± 0.02 | 0.50 ± 0.02 | 0.45 ± 0.03 | 0.62 ± 0.01 | 0.70 ± 0.01 |

Table 4: The table reports mean absolute error (MAE) and average precision (AP) for two long-range molecular benchmarks involving peptides. The molecular graph of a peptide is much larger than that of a small drug-like molecule and this makes the tasks well-suited for long-range benchmarking. All the results except the ones for ESA and TokenGT are extracted from [29]. The number of layers used for pept-struct and pept-func, respectively, is given as $(\cdot/\cdot)$.

表 4: 该表报告了涉及肽的两个长程分子基准测试的平均绝对误差 (MAE) 和平均精度 (AP)。肽的分子图比类药小分子大得多,这使得这些任务非常适合长程基准测试。除 ESA 和 TokenGT 的结果外,所有结果均摘自 [29]。pept-struct 和 pept-func 的层数分别表示为 $(\cdot/\cdot)$。

| 数据集 | GCN (6/6) | GIN (10/8) | GPS (8/6) | TokenGT (10/10) | ESA (3/4) |
|---|---|---|---|---|---|
| PEPT-STR (MAE ↓) | 0.2460 ± 0.0007 | 0.2473 ± 0.0017 | 0.2509 ± 0.0014 | 0.2489 ± 0.0013 | 0.2453 ± 0.0003 |
| PEPT-FN (AP ↑) | 0.6860 ± 0.0050 | 0.6621 ± 0.0067 | 0.6534 ± 0.0091 | 0.6263 ± 0.0117 | 0.6863 ± 0.0044 |

MoleculeNet and NCI. We report results for a varied selection of three regression and three classification benchmarks from MoleculeNet, as well as two benchmarks from the National Cancer Institute (NCI) consisting of compounds screened for anti-cancer activity (Tables 1 and 3, and Supplementary Tables 1 and 4). We also report the accuracy in Supplementary Table 5. Except hiv, these datasets pose a challenge to graph transformers due to their small size ($<$5,000 compounds). With the exception of the Lipophilicity (lipo) dataset from MoleculeNet, we observe that ESA is the preferred method, and often by a significant margin, for example on bbbp and bace, despite their small size (2,039, respectively 1,513 total compounds before splitting). On the other hand, Graphormer and TokenGT perform poorly, possibly due to the small-sample nature of these tasks. GraphGPS is closer to the top performers but still underperforms compared to GAT(v2) and PNA. These results also show the importance of appropriate evaluation metrics, as the accuracy on hiv for all methods is above 97% (Supplementary Table 5), but the MCC (Table 3) is significantly lower.

Molecule Net与NCI。我们报告了来自Molecule Net的三个回归和三个分类基准的多样化选择结果,以及来自美国国家癌症研究所(NCI)的两个基准,这些基准包含筛选抗癌活性的化合物(表1和表3,及补充表1和表4)。补充表5中还报告了准确率。除hiv外,这些数据集由于规模较小( $<$ 5,000个化合物)对图Transformer提出了挑战。除Molecule Net中的Lipophilicity(lipo)数据集外,我们观察到ESA是首选方法,且通常优势显著,例如在bbbp和bace上(拆分前分别为2,039和1,513个化合物)。另一方面,Graphormer和TokenGT表现不佳,可能是由于这些任务的少样本特性。GraphGPS更接近顶级方法,但仍逊色于GAT(v2)和PNA。这些结果也显示了适当评估指标的重要性,因为所有方法在hiv上的准确率都超过97%(补充表5),但MCC(表3)明显更低。

PCQM4MV2. pcqm4mv2 is a quantum chemistry benchmark introduced as a competition through the Open Graph Benchmark Large-Scale Challenge (OGB-LSC) project [51]. It consists of 3,378,606 training molecules with the goal of predicting the DFT-calculated HOMO-LUMO energy gap from 2D molecular graphs. It has been widely used in the literature, especially to evaluate graph transformers [25, 26, 52]. Since the test splits are not public, we use the available validation set (73,545 molecules) as a test set, as is often done in the literature, and report results on it after training for 400 epochs. We report results from a single run, as is common in the field due to the large size of the dataset [25, 26, 52]. For the same reason, we do not run our baselines and instead chose to focus on the state-of-the-art results published on the official leaderboard∗. At the time of writing (November 2024), the best performing model achieves a validation set MAE of 0.0671 [53]. Here, ESA achieves a MAE of 0.0235, which is almost 3 times lower. It is worth noting that the top 3 methods for this dataset are all bespoke architectures designed for molecular learning, for example Uni-Mol+ [54], Transformer-M [55], and TGT [53]. In contrast, ESA is general purpose (does not use any information or technique specifically for molecular learning), uses only the 2D input graph, and does not use any positional or structural encodings.

PCQM4MV2。pcqm4mv2是由开放图基准大规模挑战赛(OGB-LSC)项目[51]推出的量子化学基准测试数据集。该数据集包含3,378,606个训练分子,目标是从二维分子图中预测DFT计算的HOMO-LUMO能隙。该数据集在学术界被广泛使用,尤其用于评估图Transformer模型[25, 26, 52]。由于测试集未公开,我们按照文献惯例使用现有验证集(73,545个分子)作为测试集,并在训练400轮后报告结果。由于数据集规模庞大,我们与领域内常见做法一致[25, 26, 52],仅报告单次运行结果。出于同样原因,我们没有重新运行基线模型,而是选择聚焦官方排行榜*公布的最新成果。截至本文撰写时(2024年11月),性能最佳的模型在验证集上达到0.0671的MAE[53]。而ESA模型取得了0.0235的MAE,误差降低了近3倍。值得注意的是,该数据集当前排名前三的方法都是专门针对分子学习设计的架构,例如Uni-Mol+[54]、Transformer-M[55]和TGT[53]。相比之下,ESA是通用架构(未使用任何分子学习专用信息或技术),仅使用二维输入图,且不依赖任何位置或结构编码。

ZINC. We report results on the full zinc dataset with 250,000 compounds (Table 2), which is commonly used for generative purposes [56, 57]. This is one of the only benchmarks where the graph transformer baselines (Graphormer, TokenGT, GraphGPS) convincingly outperformed strong GNN baselines. ESA, without positional encodings, slightly underperforms compared to GraphGPS, which uses random-walk structural encodings (RWSE). This type of encoding is known to be beneficial in molecular tasks, and especially for zinc [27, 58]. Thus, we also evaluated an ESA+RWSE model, which increased relative performance by almost 40%. While there is no leaderboard available for the full version of zinc, recently Ma et al. [31] evaluated their own method (GRIT) against 11 other baselines, including higher-order GNNs, with the best reported MAE being $\mathbf{0.023\pm0.001}$. This shows that ESA can already almost match state-of-the-art models without structural encodings, and significantly improves upon this result when augmented.

ZINC。我们在包含25万种化合物的完整ZINC数据集上报告结果(表2),该数据集通常用于生成目的[56,57]。这是少数几个图Transformer基线模型(Graphormer、TokenGT、GraphGPS)明显优于强GNN基线的基准之一。未使用位置编码的ESA略逊于采用随机游走结构编码(RWSE)的GraphGPS。已知此类编码对分子任务(尤其是ZINC)有益[27,58]。因此我们还评估了ESA+RWSE模型,其相对性能提升近40%。虽然完整版ZINC没有公开排行榜,但最近Ma等人[31]评估了其方法(GRIT)与11个其他基线(包括高阶GNN)的对比,最佳报告MAE为$\mathbf{0.023\pm0.001}$。这表明ESA即使不加结构编码也能接近最先进模型,增强后更能显著提升结果。

Table 5: A summary of the transfer learning performance on qm9 for HOMO and LUMO properties, presented as mean $\pm$ standard deviation over 5 different runs. The metric is the root mean squared error (RMSE). All the models use the 3D atomic coordinates and atom types as inputs and no other node or edge features. ‘Strat.’ stands for strategy and specifies the type of learning: GW only (no transfer learning), inductive, or transductive. The lowest values are highlighted in bold. A complete table, including GCN and GIN, is provided in Supplementary Table 3.

表 5: qm9数据集上HOMO和LUMO性质迁移学习性能汇总,结果为5次运行的平均值$\pm$标准差。评估指标为均方根误差(RMSE)。所有模型均使用3D原子坐标和原子类型作为输入,不使用其他节点或边特征。"Strat."表示策略,指定学习类型:仅GW(无迁移学习)、归纳式(inductive)或传导式(transductive)。最低值以粗体标出。完整表格(含GCN和GIN)见补充材料表3。

| Task | Strat. | DropGIN | GAT | GATv2 | PNA | Grph. | TokenGT | GPS | ESA |
|---|---|---|---|---|---|---|---|---|---|
| HOMO | GW | 0.162 ± 0.00 | 0.159 ± 0.00 | 0.157 ± 0.00 | 0.151 ± 0.00 | 0.179 ± 0.01 | 0.200 ± 0.01 | 0.162 ± 0.00 | 0.152 ± 0.00 |
| | Ind. | 0.136 ± 0.00 | 0.131 ± 0.00 | 0.133 ± 0.00 | 0.132 ± 0.00 | 0.134 ± 0.00 | 0.156 ± 0.00 | 0.151 ± 0.00 | 0.131 ± 0.00 |
| | Trans. | 0.126 ± 0.00 | 0.123 ± 0.00 | 0.124 ± 0.00 | 0.121 ± 0.00 | 0.125 ± 0.00 | 0.137 ± 0.00 | 0.147 ± 0.00 | 0.119 ± 0.00 |
| LUMO | GW | 0.180 ± 0.00 | 0.181 ± 0.00 | 0.178 ± 0.00 | 0.174 ± 0.00 | 0.190 ± 0.01 | 0.204 ± 0.01 | 0.178 ± 0.00 | 0.174 ± 0.00 |
| | Ind. | 0.159 ± 0.00 | 0.156 ± 0.00 | 0.157 ± 0.00 | 0.156 ± 0.00 | 0.151 ± 0.00 | 0.165 ± 0.00 | 0.167 ± 0.00 | 0.150 ± 0.00 |
| | Trans. | 0.157 ± 0.00 | 0.153 ± 0.00 | 0.153 ± 0.00 | 0.153 ± 0.00 | 0.147 ± 0.00 | 0.156 ± 0.00 | 0.169 ± 0.00 | 0.146 ± 0.00 |

Long-range peptide tasks. Graph learning with transformers has traditionally been evaluated on long-range graph benchmarks (LRGB) [59]. However, it was recently shown that simple graph neural networks outperform most attention-based methods [29]. We selected two long-range benchmarks involving peptide property prediction: peptides-struct and peptides-func. From the LRGB collection, these two stand out due to having the longest average shortest path ($20.89\pm9.79$ versus $10.74\pm0.51$ for the next) and the largest average diameter ($56.99\pm28.72$ versus $27.62\pm2.13$). We report ESA results against tuned GNNs and GraphGPS models from [59], as well as our own optimised TokenGT model (Table 4). Despite using at most half the number of layers of the other methods, ESA outperformed these baselines, matched the second-best model on the pept-struct leaderboard, and is within the top 5 for pept-func (as of November 2024).

长程肽任务。传统上,基于Transformer的图学习主要在长程图基准(LRGB) [59]上进行评估。但近期研究表明,简单的图神经网络(GNN)能超越大多数基于注意力机制的方法[29]。我们选取了两个涉及肽属性预测的长程基准:peptides-struct和peptides-func。在LRGB数据集中,这两个任务因具有最长的平均最短路径($20.89\pm9.79$,而次长仅为$10.74\pm0.51$)和最大的平均直径($56.99\pm28.72$,而次大为$27.62\pm2.13$)而突出。表4展示了ESA方法与[59]中调优的GNN、GraphGPS模型以及我们优化的TokenGT模型的对比结果。尽管ESA使用的层数仅为其他方法的一半或更少,其性能仍超越基线模型,并在pept-struct排行榜上与第二名持平,同时在pept-func任务中保持前五名(截至2024年11月)。
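For reference, the two statistics that make these benchmarks long-range (average shortest path length and diameter) can be computed directly from the graph structure with `networkx`; the sketch below uses a toy chain graph rather than the actual LRGB peptide data.

```python
# Compute the two long-range statistics quoted above on a toy chain graph.
import networkx as nx

g = nx.path_graph(10)  # a long chain standing in for a peptide graph
avg_sp = nx.average_shortest_path_length(g)
diameter = nx.diameter(g)
print(f"average shortest path = {avg_sp:.2f}, diameter = {diameter}")
```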

Learning on 3D atomic systems. We adapt a subset from the Open Catalyst Project (OCP) [60, 61] for evaluation, with the full steps in SI 9, including deriving edges from atom positions based on a distance cutoff, pre-processing specific to crystal systems, and encoding atomic distances using Gaussian basis functions. As a prototype, we compare NSA (MAE of $\mathbf{0.799\pm0.008}$) against Graphormer (MAE of $0.839\pm0.005$), one of the best models that have been used for the Open Catalyst Project [62]. These encouraging results on a subset of OCP motivated us to further study modelling 3D atomic systems through the lens of transfer learning.

学习3D原子系统。我们采用Open Catalyst Project (OCP) [60, 61]的子集进行评估,完整步骤见SI 9,包括基于截断值从原子位置推导边缘、针对晶体系统的特定预处理,以及使用高斯基函数编码原子距离。作为原型,我们将NSA (MAE为$\mathbf{0.799\pm0.008}$)与Graphormer ($0.839\pm0.005$)进行对比,后者是Open Catalyst Project [62]中使用的最佳模型之一。这些在OCP子集上取得的鼓舞性结果,促使我们通过迁移学习的视角进一步研究3D原子系统的建模。
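A minimal sketch of the two pre-processing steps mentioned above, deriving edges from atomic positions with a distance cutoff and expanding interatomic distances in Gaussian basis functions; the 6 Å cutoff and the 32 basis functions are illustrative choices, not the exact OCP settings (see SI 9 for those).

```python
import torch

def radius_graph_edges(pos: torch.Tensor, cutoff: float = 6.0):
    """Connect every pair of atoms closer than `cutoff` (excluding self-loops)."""
    dist = torch.cdist(pos, pos)                              # (N, N) pairwise distances
    mask = (dist < cutoff) & ~torch.eye(len(pos), dtype=torch.bool)
    src, dst = mask.nonzero(as_tuple=True)
    return torch.stack([src, dst]), dist[src, dst]

def gaussian_basis(d: torch.Tensor, num_basis: int = 32, cutoff: float = 6.0):
    """Expand scalar distances into `num_basis` Gaussian radial features."""
    centers = torch.linspace(0.0, cutoff, num_basis)
    width = centers[1] - centers[0]
    return torch.exp(-((d.unsqueeze(-1) - centers) ** 2) / (2 * width ** 2))

pos = torch.randn(8, 3) * 3.0                                 # toy atomic coordinates
edge_index, edge_dist = radius_graph_edges(pos)
edge_attr = gaussian_basis(edge_dist)                         # (num_edges, 32) edge features
```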

Transfer learning on frontier orbital energies. We follow the recipe recently outlined for drug discovery and quantum mechanics by [32] and leverage a recent, refined version of the qm9 HOMO and LUMO energies [63] that provides alternative DFT calculations and new calculations at the more accurate GW level of theory. As outlined by [32], transfer learning can occur transductively or inductively. The transfer learning scenario is important and is detailed in SI 4. We use a 25K/5K/10K train/validation/test split for the high-fidelity GW data, and we train separate low-fidelity DFT models on the entire dataset (transductive) or with the high-fidelity test set molecules removed (inductive). Since the HOMO and LUMO energies depend to a large extent on the molecular geometry, we use the 3D-aware version of ESA from the previous section, and adapt all of our baselines to the 3D setup. Our results are reported in Table 5. Without transfer learning (strategy ‘GW’ in Table 5), ESA and PNA are almost evenly matched, which is already an improvement since PNA was better for frontier orbital energies without 3D structures (Table 1), while the graph transformers perform poorly. With transfer learning, all methods improve significantly, but ESA outperforms all baselines for both HOMO and LUMO, in both transductive and inductive tasks.

前沿轨道能的迁移学习。我们遵循[32]近期针对药物发现和量子力学提出的方法,并利用qm9 HOMO和LUMO能级的最新优化版本[63],该版本提供了替代的DFT计算以及在更精确的GW理论层面的新计算结果。如[32]所述,迁移学习可通过传导式(transductive)或归纳式(inductive)进行。该迁移学习场景非常重要,详见SI 4。我们对高精度GW数据采用25K/5K/10K的训练/验证/测试集划分,并在完整数据集上训练独立的低精度DFT模型(传导式),或移除高精度测试集分子后训练(归纳式)。由于HOMO和LUMO能级很大程度上取决于分子几何结构,我们使用前文所述的3D感知版ESA,并将所有基线方法适配至3D架构。结果如表5所示:未采用迁移学习时(表5中"GW"策略),ESA与PNA表现接近,这已是进步——因为在无3D结构时PNA对前沿轨道能更具优势(表1),而图Transformer模型表现较差;采用迁移学习后所有方法均有显著提升,但ESA在HOMO和LUMO的传导式与归纳式任务中均优于所有基线方法。
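The sketch below illustrates how the two transfer-learning settings described above can be constructed: the low-fidelity DFT training pool either covers all molecules (transductive) or excludes the high-fidelity GW test molecules (inductive). The helper name and the use of plain molecule ids are hypothetical; only the 25K/5K/10K sizes come from the text.

```python
import random

def make_fidelity_splits(all_ids, seed=0):
    """Split molecule ids into the GW (high-fidelity) 25K/5K/10K sets and build
    the DFT (low-fidelity) training pools for both transfer settings."""
    rng = random.Random(seed)
    ids = list(all_ids)
    rng.shuffle(ids)
    gw_train, gw_val, gw_test = ids[:25_000], ids[25_000:30_000], ids[30_000:40_000]

    dft_transductive = set(ids)              # low-fidelity labels for every molecule
    dft_inductive = set(ids) - set(gw_test)  # GW test molecules never seen at low fidelity
    return gw_train, gw_val, gw_test, dft_transductive, dft_inductive
```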

4.2 Mixed Graph-level Tasks

4.2 混合图级别任务

Other than molecular learning, we examine a suite of graph-level benchmarks from various domains, including computer vision, bioinformatics, synthetic graphs, and social graphs. Our results are reported in Table 6 and Supplementary Table 7, where ESA generally outperforms all the baselines. In the context of state-of-the-art models from online leaderboards, of particular note are the accuracy results (provided in Supplementary Table 7) on the vision datasets, where ESA matches the best performing model on mnist [64] and is in the top 5 on cifar10 at the time of submission (November 2024), without positional or structural encodings or other helper mechanisms. Similarly, on the malnettiny dataset of function call graphs popularised by GraphGPS [27], we achieve the highest recorded accuracy.

除分子学习外,我们还测试了来自计算机视觉、生物信息学、合成图和社会图等多个领域的图级基准。结果如表6和附表7所示,其中ESA普遍优于所有基线方法。值得注意的是,在在线排行榜的当前最优模型背景下,ESA在视觉数据集上的准确率结果(见附表7)表现突出:在mnist数据集上匹配了性能最佳的模型[64],在提交时(2024年11月)的cifar10数据集上位列前五,且未使用位置编码、结构编码或其他辅助机制。同样,在GraphGPS[27]推广的函数调用图数据集malnettiny上,我们取得了迄今最高的记录准确率。

4.3 Node-level Benchmarks

4.3 节点级基准测试

Node-level tasks are an interesting challenge for our proposed approach, as the PMA module is not needed and node-level propagation is the most natural strategy for learning node representations. To this end, we prototyped an edge-to-node pooling module which would allow ESA to learn node embeddings, with good results. However, that approach does not currently scale to the millions of edges that some heterophilous graphs have. For these reasons, we revert to the simpler node-set attention (NSA) formulation. Table 7 summarises our results, indicating that NSA performs well on all 11 node-level benchmarks, including homophilous (citation), heterophilous, and shortest path tasks. We have adapted Graphormer and TokenGT for node classification as this functionality was not originally available, although they require integer node and edge features, which restricts their use on some datasets (denoted by n/a in Table 7). NSA achieves the highest MCC score on homophilous and heterophilous graphs, but GCN has the highest accuracy on cora (Supplementary Table 9). Some methods, for instance GraphSAGE [67], are designed to perform particularly well under heterophily; to account for that, we include GraphSAGE [67] and Graph Transformer [33], the two top-performing baselines from Platonov et al. [66], in our extended results in Supplementary Table 10. With the exception of the amazon ratings dataset, we outperform these two baselines by a noticeable margin. Our results on chameleon stand out in particular, as do those on the shortest path benchmarks, where other graph transformers are unable to learn and even PNA fails to be competitive.

节点级任务对我们提出的方法来说是一个有趣的挑战,因为不需要PMA模块,而节点级传播是最自然的节点表示学习策略。为此,我们确实尝试了边缘到节点的池化模块原型,使ESA能够学习节点嵌入,并取得了良好效果。但该方法目前无法扩展到某些异质图拥有的数百万条边。基于这些原因,我们回归到更简单的节点集注意力(NSA)方案。表7总结了我们的结果,表明NSA在所有11个节点级基准测试中表现良好,包括同质(引用)、异质和最短路径任务。我们调整了Graphormer和TokenGT用于节点分类(该功能原本不可用),但它们需要整数节点和边特征,这限制了在某些数据集上的使用(表7中标记为n/a)。NSA在同质图和异质图上取得了最高MCC分数,但GCN在cora数据集上准确率最高(补充表9)。某些方法(如GraphSAGE [67])专为在异质性下表现优异而设计,为此我们在补充表10的扩展结果中纳入了Platonov等人[66]的两个最佳基线方法GraphSAGE [67]和Graph Transformer [33]。除amazon评分数据集外,我们以明显优势超越了这两个基线。我们在chameleon数据集和最短路径基准上的表现尤为突出——其他图Transformer模型无法有效学习,甚至PNA也缺乏竞争力。
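Since all classification results here are reported as the Matthews correlation coefficient (MCC), a short sketch using scikit-learn's `matthews_corrcoef` on toy labels is included purely to make the metric concrete; it is not tied to the authors' evaluation code.

```python
# The evaluation metric used in Tables 6 and 7: the Matthews correlation
# coefficient, computed here with scikit-learn on toy multiclass predictions.
from sklearn.metrics import matthews_corrcoef

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(matthews_corrcoef(y_true, y_pred))  # in [-1, 1]; 1 means perfect agreement
```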

Table 6: The table reports Matthews correlation coefficient (MCC) for graph-level classification tasks from various domains, presented as mean $\pm$ standard deviation over 5 different runs. oom denotes out-of-memory errors, and n/a that the model is unavailable (e.g., because node/edge features are not integers, which Graphormer and TokenGT require). The poor performance of GraphGPS on malnettiny can be explained by our use of one-hot degrees as node features for datasets lacking pre-computed features, while GraphGPS originally used the more informative local degree profile [65]. The highest mean values are highlighted in bold. A complete table, including GCN and GIN, is provided in Supplementary Table 6 with MCC as metric, and in Supplementary Table 7 with accuracy.

表 6: 该表报告了不同领域图级分类任务的马修斯相关系数 (MCC),以5次运行的平均值 ± 标准差呈现。oom表示内存不足错误,n/a表示模型不可用(例如,Graphormer和TokenGT需要整数节点/边特征)。表6中GraphGPS在malnettiny上的较差表现可归因于我们对缺乏预计算特征的数据集使用独热编码度作为节点特征,而GraphGPS原本使用信息量更大的局部度分布 [65]。最高平均值以粗体标出。完整表格(包括GCN和GIN)在补充表6(以MCC为指标)和补充表7(以准确率为指标)中提供。

| Data | DropGIN | GAT | GATv2 | PNA | Graphormer | TokenGT | GPS | ESA |
|---|---|---|---|---|---|---|---|---|
| MALNETTINY | 0.90 ± 0.01 | 0.90 ± 0.01 | 0.90 ± 0.01 | 0.91 ± 0.01 | OOM | 0.78 ± 0.01 | 0.79 ± 0.01 | 0.93 ± 0.00 |
| MNIST | 0.97 ± 0.00 | 0.97 ± 0.00 | 0.98 ± 0.00 | 0.98 ± 0.00 | N/A | N/A | 0.98 ± 0.00 | 0.99 ± 0.00 |
| CIFAR10 | 0.61 ± 0.01 | 0.66 ± 0.01 | 0.66 ± 0.01 | 0.69 ± 0.01 | N/A | N/A | 0.71 ± 0.01 | 0.73 ± 0.00 |
| ENZYMES | 0.58 ± 0.04 | 0.75 ± 0.02 | 0.74 ± 0.02 | 0.68 ± 0.03 | N/A | N/A | 0.73 ± 0.05 | 0.75 ± 0.01 |
| PROTEINS | 0.46 ± 0.01 | 0.46 ± 0.04 | 0.49 ± 0.04 | 0.47 ± 0.07 | N/A | N/A | 0.44 ± 0.02 | 0.59 ± 0.02 |
| DD | 0.54 ± 0.07 | 0.47 ± 0.05 | 0.53 ± 0.03 | 0.56 ± 0.08 | OOM | 0.46 ± 0.05 | 0.60 ± 0.05 | 0.65 ± 0.03 |
| SYNTH | 1.00 ± 0.00 | 1.00 ± 0.00 | 1.00 ± 0.00 | 1.00 ± 0.00 | N/A | N/A | 1.00 ± 0.00 | 1.00 ± 0.00 |
| SYNTH NEW | 0.97 ± 0.05 | 0.76 ± 0.07 | 0.91 ± 0.07 | 1.00 ± 0.00 | N/A | N/A | 1.00 ± 0.00 | 1.00 ± 0.00 |
| SYNTHIE | 0.94 ± 0.02 | 0.70 ± 0.04 | 0.80 ± 0.04 | 0.88 ± 0.06 | N/A | N/A | 0.95 ± 0.02 | 0.95 ± 0.02 |
| IMDB-B | 0.61 ± 0.06 | 0.69 ± 0.04 | 0.60 ± 0.05 | 0.56 ± 0.07 | 0.56 | 0.61 ± 0.05 | 0.60 ± 0.05 | 0.74 ± 0.03 |
| IMDB-M | 0.12 ± 0.10 | 0.20 ± 0.05 | 0.20 ± 0.03 | 0.03 ± 0.07 | 0.22 | 0.20 ± 0.02 | 0.23 ± 0.02 | 0.25 ± 0.03 |
| TWITCH E. | 0.38 ± 0.01 | 0.37 ± 0.01 | 0.37 ± 0.01 | 0.08 ± 0.16 | 0.39 ± 0.00 | 0.39 ± 0.00 | 0.40 ± 0.00 | 0.40 ± 0.00 |
| RDT. THR. | 0.56 ± 0.01 | 0.53 ± 0.02 | 0.54 ± 0.02 | 0.11 ± 0.23 | 0.57 | 0.56 ± 0.00 | 0.57 ± 0.00 | 0.57 ± 0.00 |
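The Table 6 caption notes that, for datasets without pre-computed node features, one-hot degrees are used as node features. A minimal sketch of this featurisation with PyG's `OneHotDegree` transform is shown below; the maximum degree of 8 is an arbitrary placeholder, as the value depends on the dataset.

```python
# One-hot degree node features as a fallback when a dataset ships no features.
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import OneHotDegree

transform = OneHotDegree(max_degree=8)  # dataset-dependent placeholder

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
graph = transform(Data(edge_index=edge_index, num_nodes=3))
print(graph.x.shape)  # (3, 9): degree one-hot of size max_degree + 1
```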

Table 7: The table reports Matthews correlation coefficient (MCC) for 11 node-level classification tasks, presented as mean $\pm$ standard deviation over 5 different runs. The number of nodes for the shortest path (sp) benchmarks is given in parentheses (based on randomly-generated “infected” Erdős–Rényi (ER) graphs; details are provided in SI 8). For squirrel and chameleon, we used the filtered datasets introduced by Platonov et al. [66] to fix existing data leaks. The highest mean values are highlighted in bold. Additional heterophily results are provided in Supplementary Table 10. A complete table, including GIN and DropGIN, is provided in Supplementary Table 8 with MCC as the performance metric, and in Supplementary Table 9 with accuracy.

表 7: 该表报告了11个节点级分类任务的马修斯相关系数 (MCC),以5次不同运行的平均值 $\pm$ 标准差呈现。最短路径 (sp) 基准测试的节点数在括号中给出(基于随机生成的“感染”Erdős–Rényi (ER) 图;详细信息见SI 8)。对于squirrel和chameleon,我们使用了Platonov等人[66]引入的过滤数据集来修复现有的数据泄露问题。最高平均值以粗体标出。补充表10提供了额外的异质性结果。完整表格(包括GIN和DropGIN)在补充表8中提供(以MCC作为性能指标),在补充表9中提供(以准确率作为指标)。

| Data (↑) | GCN | GAT | GATv2 | PNA | Graphormer | TokenGT | GPS | NSA |
|---|---|---|---|---|---|---|---|---|
| PPI | 0.98 ± 0.00 | 0.99 ± 0.00 | 0.98 ± 0.01 | 0.99 ± 0.00 | N/A | N/A | N/A | 0.99 ± 0.00 |
| CITESEER | 0.61 ± 0.01 | 0.59 ± 0.03 | 0.61 ± 0.01 | 0.51 ± 0.03 | OOM | 0.38 ± 0.02 | 0.54 ± 0.01 | 0.63 ± 0.00 |
| CORA | 0.77 ± 0.01 | 0.75 ± 0.01 | 0.73 ± 0.01 | 0.64 ± 0.03 | OOM | 0.37 ± 0.18 | 0.64 ± 0.04 | 0.77 ± 0.00 |
| ROMAN E. | 0.47 ± 0.00 | 0.74 ± 0.01 | 0.76 ± 0.00 | 0.86 ± 0.00 | N/A | N/A | 0.84 ± 0.01 | 0.87 ± 0.00 |
| AMAZON R. | 0.18 ± 0.00 | 0.26 ± 0.01 | 0.25 ± 0.01 | 0.21 ± 0.02 | N/A | N/A | 0.11 ± 0.14 | 0.34 ± 0.01 |
| MINESWEEPER | 0.30 ± 0.00 | 0.48 ± 0.02 | 0.51 ± 0.01 | 0.62 ± 0.04 | N/A | N/A | 0.56 ± 0.01 | 0.69 ± 0.00 |
| TOLOKERS | 0.30 ± 0.01 | 0.38 ± 0.01 | 0.39 ± 0.00 | 0.35 ± 0.05 | N/A | N/A | 0.35 ± 0.02 | 0.43 ± 0.00 |
| SQUIRREL | 0.20 ± 0.01 | 0.24 ± 0.02 | 0.24 ± 0.01 | 0.22 ± 0.01 | 0.10 ± 0.09 | 0.18 ± 0.02 | 0.23 ± 0.02 | 0.29 ± 0.01 |
| CHAMELEON | 0.32 ± 0.01 | 0.24 ± 0.06 | 0.28 ± 0.03 | 0.27 ± 0.03 | 0.25 ± 0.04 | 0.26 ± 0.04 | 0.30 ± 0.08 | 0.39 ± 0.02 |
| SP ER (15K) | 0.22 ± 0.02 | 0.32 ± 0.00 | 0.32 ± 0.00 | 0.54 ± 0.09 | OOM | 0.06 ± 0.00 | 0.18 ± 0.04 | 0.92 ± 0.01 |
| SP ER (30K) | 0.09 ± 0.03 | 0.10 ± 0.06 | 0.10 ± 0.06 | 0.42 ± 0.05 | OOM | OOM | OOM | 0.87 ± 0.01 |

Table 8: Example of multiple ESA configurations and their impact on performance via three different kinds of benchmarks: a graph-level molecular task (bbbp), a graph-level vision task (mnist), and a node-level heterophilous task (chameleon). In the model configurations, ‘M’ denotes a MAB, ‘S’ a SAB (which can appear both before and after the PMA module), and ‘P’ is the PMA module. The performance metric is the Matthews correlation coefficient (MCC).

表 8: 多种ESA配置示例及其通过三类基准测试对性能的影响:图级分子任务(bbbp)、图级视觉任务(mnist)和节点级异质任务(chameleon)。模型配置中,‘M’表示MAB,‘S’表示SAB(可出现在PMA模块之前或之后),‘P’表示PMA模块。性能指标为马修斯相关系数(MCC)。

| Dataset | Model | MCC (↑) | Dataset | Model | MCC (↑) | Dataset | Model | MCC (↑) |
|---|---|---|---|---|---|---|---|---|
| BBBP | MMMMSPS | 0.845 | MNIST | SSMMMMMMMPS | 0.986 | CHAMELEON | MSMSMS | 0.422 |
| | MMMSPSS | 0.835 | | SMSMSMSMSMP | 0.983 | | SSMMSS | 0.384 |
| | SSSMMPS | 0.812 | | SSSMMMSSSP | 0.982 | | SMSMSM | 0.378 |
| | MMMMMP | 0.782 | | MSMSMSMSMSP | 0.980 | | MMMMM | 0.359 |
| | SMSMSP | 0.768 | | — | 0.980 | | SSMMMM | 0.351 |

4.4 Effects of Varying the Layer Order and Type

4.4 层序与类型变化的影响

In Table 8, we summarise the results of an ablation over the interleaving operator, varying the order and types of layers in our architecture. Smaller datasets perform well with 4 to 6 feature extraction layers, while larger datasets with more complex graphs, such as mnist, benefit from up to 10 layers. We have generally observed that the top configurations tend to include self-attention layers at the front, with masked attention layers in the middle and self-attention layers at the end, surrounding the PMA readout. Naive configurations such as all-masked layers or simply alternating masked and self-attention layers do not tend to be optimal for graph-level prediction tasks. This ablation demonstrates the importance of the vertical combination of masked and self-attention layers for the performance of our model.

在表8中,我们总结了关于架构中不同顺序和类型层的交错操作符的消融实验结果。较小数据集在4到6个特征提取层时表现良好,而具有更复杂图结构的大型数据集(如mnist)则受益于多达10层的设计。我们普遍观察到,最优配置往往在前端包含自注意力层,中间为掩码注意力层,末端再放置自注意力层,从而环绕PMA读出模块。纯掩码层或简单交替掩码与自注意力层的朴素配置通常不适用于图级预测任务。该消融实验证明了掩码层与自注意力层垂直组合对模型性能的重要性。
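To make the configuration strings of Table 8 concrete, the sketch below parses a string such as 'MMMMSPS' into an ordered block list ('M' = MAB, 'S' = SAB, 'P' = PMA readout); the block names and the `parse_config` helper are hypothetical stand-ins, not the authors' implementation.

```python
# Hypothetical parser for the configuration strings shown in Table 8.
from typing import List

BLOCKS = {"M": "MaskedAttentionBlock", "S": "SelfAttentionBlock", "P": "PMAReadout"}

def parse_config(config: str) -> List[str]:
    """Turn e.g. 'MMMMSPS' into an ordered list of block names.
    Node-level configurations (e.g. 'MSMSMS') simply omit the PMA readout."""
    return [BLOCKS[c] for c in config]

print(parse_config("MMMMSPS"))
```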


Figure 4: The elapsed time for training a single epoch on different datasets (in seconds), and the maximum allocated memory during this training epoch (GB). Different configurations are tested, varying the hidden dimension and the number of layers. qm9 has around 130K small molecules (a maximum of 29 nodes and 56 edges), dockstring has around 260K graphs that are around 6 times larger, and mnist has 70K graphs which are around 11-12 times larger than those of qm9. We use dummy integer features for TokenGT when benchmarking mnist. Graphormer runs out of memory for dockstring and mnist.

图 4: 不同数据集训练单轮耗时(秒)及该训练轮次最大内存占用(GB)。测试了不同隐藏维度和层数的配置。qm9包含约13万个小分子(最多29个节点和56条边),dockstring包含约26万个规模约为qm9 6倍的图,mnist包含7万个规模约为qm9 11-12倍的图。基准测试mnist时,TokenGT使用虚拟整数特征。Graphormer在dockstring和mnist上出现内存不足。

4.5 Time and Memory Scaling

4.5 时间和内存扩展

The theoretical properties of ESA are discussed in Methods, Section 3.3, where we also cover deep learning library limitations and possible optimisations (also see SI 5). Here, we empirically evaluate the time and memory scaling of ESA against all 9 baselines. For a fair evaluation, we implement Flash attention for TokenGT, as it was not originally supported. GraphGPS natively supports Flash attention, while Graphormer requires specific modifications to the attention matrix which are not currently supported.

ESA的理论特性在方法部分的第3.3节中讨论,我们还涵盖了深度学习库的局限性及可能的优化方案(另见SI 5)。本文通过实证评估了ESA与全部9个基线方法在时间和内存占用上的扩展性。为确保公平性,我们为TokenGT实现了Flash attention机制(该机制原版本未支持)。GraphGPS原生支持Flash attention,而Graphormer需要对注意力矩阵进行特定修改(当前版本尚未支持该功能)。
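As an illustration of the kind of drop-in memory-efficient attention referred to above, PyTorch's `scaled_dot_product_attention` dispatches to a FlashAttention kernel on supported GPUs and dtypes; the sketch below is a generic example with illustrative shapes, not the actual TokenGT patch.

```python
# Generic scaled dot-product attention call; on CUDA with float16/bfloat16
# inputs it can use the flash backend, on CPU it falls back to the math path.
import torch
import torch.nn.functional as F

batch, heads, tokens, head_dim = 4, 8, 512, 64
q = torch.randn(batch, heads, tokens, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # (4, 8, 512, 64)
```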

We report the training time for a single epoch and the maximum allocated memory during training for qm9 and mnist in Figure 4, and dockstring in Supplementary Figure 1. In terms of training time, GCN and GIN are consistently the fastest due to their simplicity and low number of parameters (also see Figure 6a). ESA is usually the next fastest, followed by other graph transformers. The strong GNN baselines, particularly PNA, and to an extent GATv2 and GAT, are among the slowest methods. In terms of memory, GCN and GIN are again the most efficient, followed by GraphGPS and TokenGT. ESA is only slightly behind, with PNA, DropGIN, GATv2, and GAT all being more memory intensive, particularly DropGIN.

我们在图4中报告了qm9和mnist的单轮训练时间及训练期间最大内存占用,dockstring数据则在补充图1中展示。就训练时间而言,由于结构简单且参数量少(另见图6a),GCN和GIN始终是最快的。ESA通常是第二快的,其次是其他图Transformer模型。强GNN基线方法(尤其是PNA)以及GATv2和GAT在某种程度上属于最慢的方法。内存方面,GCN和GIN同样最节省资源,其次是GraphGPS和TokenGT。ESA略逊一筹,而PNA、DropGIN、GATv2和GAT都更耗内存,其中DropGIN尤为明显。
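A sketch of how per-epoch training time and peak GPU memory of the kind plotted in Figure 4 can be measured; `model`, `loader`, `loss_fn`, and `optimizer` are placeholders, and the `.y` attribute follows the usual PyG batch convention.

```python
# Measure one epoch's wall-clock time and the peak GPU memory it allocates.
import time
import torch

def benchmark_epoch(model, loader, loss_fn, optimizer, device="cuda"):
    model.train()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(batch), batch.y)
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated(device) / 1e9
    return elapsed, peak_gb
```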

We also order all methods according to their achieved performance and illustrate their rank relative to the total time spent training and the maximum memory allocated during training (Figure 5). ESA occupies the top-left corner of the time plot, confirming its efficiency. The results are more spread out for memory; however, the performance is still remarkable considering that ESA works over edges, whose number increases rapidly relative to the number of nodes that dictates the scaling of all other methods.

我们还根据各方法实现的性能进行排序,并展示它们在总训练时间和训练期间最大内存分配上的相对排名(图5)。ESA在时间图中位于左上角,证实了其高效性。在内存方面结果分布较分散,但考虑到ESA基于边(其数量相对节点数快速增长,而节点数决定了其他所有方法的扩展性)运作,其性能仍然非常出色。

Finally, we report the number of parameters for all methods and their configurations (Figure 6a). Apart from GCN, GIN, and DropGIN, ESA has the lowest number of parameters. GAT and GATv2 have rapidly increasing parameter counts, likely due to the concatenation of attention heads, while PNA's count also grows significantly faster than ESA's. Lastly, we demonstrate and discuss one of the bottlenecks of Graphormer (Figure 6b).

最后,我们报告了所有方法及其配置的参数数量(图 6a)。除 GCN、GIN 和 DropGIN 外,ESA 的参数数量最低。GAT 和 GATv2 的参数数量增长迅速,可能是由于注意力头之间的拼接操作,而 PNA 的增长速度也明显快于 ESA。最后,我们展示并讨论了 Graphormer 的一个瓶颈(图 6b)。


Figure 5: All methods illustrated according to their achieved performance (rank) versus the total time spent training (minutes) in Panel (a) and the maximum allocated memory (GB) in Panel (b).

图 5: 所有方法根据其达到的性能(排名)与训练总耗时(分钟)在面板(a)中的对比,以及最大分配内存(GB)在面板(b)中的对比。


(a) The number of parameters (in millions) for all the methods and different configurations in terms of the hidden dimension and the number of layers. (b) The time and memory utilisation for a selection of Graphormer configurations. The varying parameters are the number of layers and the number of input graphs.

(a) 所有方法在不同隐藏维度和层数配置下的参数量(单位:百万)。(b) 部分 Graphormer 配置的时间和内存利用率,变量参数为层数和输入图数量。

Figure 6: The number of parameters for all methods in Panel (a) and an illustration of the time and memory bottleneck in Graphormer in Panel (b). For Graphormer, even training on a single graph from hiv takes over 10 seconds and slightly under 5GB. Increasing the number of graphs to 8 and setting the batch size to 8 to ensure parallel computation increases th