[论文翻译]基于端到端注意力机制的图学习方法


原文地址:https://arxiv.org/pdf/2402.10793v2


An end-to-end attention-based approach for learning on graphs

基于端到端注意力机制的图学习方法

Abstract

摘要

There has been a recent surge in transformer-based architectures for learning on graphs, mainly motivated by attention as an effective learning mechanism and the desire to supersede the handcrafted operators characteristic of message passing schemes. However, concerns have been raised over their empirical effectiveness, scalability, and the complexity of their pre-processing steps, especially in relation to much simpler graph neural networks that typically perform on par with them across a wide range of benchmarks. To tackle these shortcomings, we consider graphs as sets of edges and propose a purely attention-based approach consisting of an encoder and an attention pooling mechanism. The encoder vertically interleaves masked and vanilla self-attention modules to learn effective representations of edges, while allowing for tackling possible misspecifications in input graphs. Despite its simplicity, the approach outperforms fine-tuned message passing baselines and recently proposed transformer-based methods on more than 70 node and graph-level tasks, including challenging long-range benchmarks. Moreover, we demonstrate state-of-the-art performance across different tasks, ranging from molecular to vision graphs, and heterophilous node classification. The approach also outperforms graph neural networks and transformers in transfer learning settings, and scales much better than alternatives with a similar performance level or expressive power.

近年来,基于Transformer的图学习架构激增,主要驱动力在于注意力机制作为一种高效学习方式,以及取代手工设计消息传递算子的需求。然而,这些架构在实证效果、可扩展性和预处理复杂度等方面存在争议,尤其是在性能与更简单的图神经网络(GNN)相当的广泛基准测试中。为解决这些缺陷,我们将图视为边集合,提出了一种纯注意力架构,包含编码器和注意力池化机制。该编码器通过垂直堆叠掩码自注意力与常规自注意力模块,既能学习有效的边表征,又能处理输入图可能存在的错误定义。尽管结构简单,该方法在70多项节点级和图级任务(包括具有挑战性的长程基准测试)中超越了精细调优的消息传递基线和近期提出的Transformer方法。此外,我们在从分子图到视觉图、异配节点分类等不同任务中实现了最先进性能。在迁移学习场景下,该方法同样优于图神经网络和Transformer模型,且在同等性能或表达能力下展现出更优的扩展性。

1 Introduction

1 引言

We investigate empirically the potential of a purely attention-based approach for learning effective representations of graph structured data. Typically, learning on graphs is modelled as message passing – an iterative process that relies on a message function to aggregate information from a given node’s neighbourhood and an update function to incorporate the encoded message into the output representation of the node. The resulting graph neural networks (GNNs) typically stack multiple such layers to learn node representations based on vertex rooted sub-trees, essentially mimicking the one-dimensional Weisfeiler–Lehman (1-WL) graph isomorphism test [1, 2]. Variations of message passing have been applied effectively in different fields such as life sciences [3–8], electrical engineering [9], and weather prediction [10].

我们通过实证研究探讨了纯基于注意力(attention)的方法在学习图结构数据有效表征方面的潜力。通常,图学习被建模为消息传递(message passing)——这是一个依赖消息函数聚合给定节点邻域信息,并利用更新函数将编码信息整合到节点输出表征中的迭代过程。由此产生的图神经网络(GNNs)通常会堆叠多个此类层,基于顶点根子树学习节点表征,本质上是在模拟一维Weisfeiler-Lehman(1-WL)图同构测试[1,2]。消息传递的变体已成功应用于生命科学[3-8]、电气工程[9]和天气预报[10]等不同领域。

Despite the overall success and wide adoption of graph neural networks, several practical challenges have been identified over time. While the message passing framework is highly flexible, the design of new layers is a challenging research problem where improvements take years to achieve and often rely on hand-crafted operators. This is particularly the case for general purpose graph neural networks that do not exploit additional input modalities such as atomic coordinates. For instance, principal neighbourhood aggregation (PNA) is regarded as one of the most powerful message passing layers [11], but it is built using a collection of manually selected neighbourhood aggregation functions, requires a dataset degree histogram which must be pre-computed prior to learning, and further uses manually selected degree scaling. The nature of message passing also imposes certain limitations which have shaped the majority of the literature. One of the most prominent examples is the readout function used to combine node-level features into a single graph-level representation, which is required to be permutation invariant with respect to the node order. Thus, the default choice for graph neural networks and even graph transformers remains a simple, non-learnable function such as sum, mean, or max [12–14]. Limitations of this approach have been identified by Wagstaff et al. [15], who have shown that simple readout functions might require complex item embedding functions that are difficult to learn using standard neural networks. Additionally, graph neural networks have shown limitations in terms of over-smoothing [16], linked to node representations becoming similar with increased depth, and over-squashing [17] due to information compression through bottleneck edges. The proposed solutions typically take the form of message regularisation schemes [18–20].
However, there is generally no consensus on the right architectural choices for building effective deep message passing neural networks. Transfer learning and strategies like pre-training and fine-tuning are also less ubiquitous in graph neural networks because of modest or ambiguous benefits, as opposed to large language models [21].

尽管图神经网络(Graph Neural Networks)取得了整体成功并得到广泛应用,但随着时间的推移也暴露出若干实际挑战。虽然消息传递框架具有高度灵活性,但新层的设计仍是艰巨的研究难题,其改进往往需要数年时间且通常依赖手工设计的算子。这种情况在不利用原子坐标等额外输入模态的通用图神经网络中尤为明显。例如,主邻域聚合(PNA)被视为最强大的消息传递层之一[11],但它采用了一组人工选择的邻域聚合函数,需要预先计算数据集度直方图,并进一步使用人工选择的度缩放参数。消息传递的本质特性也带来了一些限制,这些限制塑造了该领域的大部分研究。最典型的例子是将节点级特征合并为图级表征的读出函数(readout function),该函数必须对节点顺序保持排列不变性。因此,图神经网络甚至图Transformer的默认选择仍是简单的不可学习函数,如求和、均值或最大值[12-14]。Wagstaff等人[15]指出这种方法的局限性,证明简单读出函数可能需要复杂的项嵌入函数,而标准神经网络难以学习这种函数。此外,图神经网络还存在过平滑(over-smoothing)[16]和过挤压(over-squashing)[17]的局限:前者随深度增加导致节点表征趋同,后者源于瓶颈边的信息压缩。现有解决方案通常采用消息正则化方案[18-20],但对于构建高效深度消息传递神经网络的架构选择仍缺乏共识。与大语言模型[21]不同,由于收益有限或效果不明确,迁移学习及预训练-微调等策略在图神经网络中的应用也相对较少。

The attention mechanism [22] is one of the main sources of innovation within graph learning, either by directly incorporating attention within message passing [23, 24], by formulating graph learning as a language processing task [25, 26], or by combining vanilla GNN layers with attention layers [12, 13, 27, 28]. However, several concerns have been raised regarding the performance, scalability, and complexity of such methods.

注意力机制 [22] 是图学习创新的主要来源之一,其实现方式包括:直接在消息传递中引入注意力 [23, 24]、将图学习建模为语言处理任务 [25, 26],或将传统 GNN 层与注意力层结合 [12, 13, 27, 28]。然而,这类方法在性能、可扩展性和复杂度方面仍存在争议。

Performance-wise, recent reports indicate that sophisticated graph transformers underperform compared to simple but tuned GNNs [29, 30]. This line of work highlights the importance of empirically evaluating new methods relative to strong baselines. Separately, recent graph transformers have focused on increasingly complex helper mechanisms, such as computationally expensive pre-processing and learning steps [25], various encodings (e.g., positional, structural, and relational) [25–27, 31], inclusion of virtual nodes and edges [25, 28], conversion of the problem to natural language processing [25, 26], and other non-trivial graph transformations [28, 31]. These complications can significantly increase the computational requirements, reducing the chance of being widely adopted and replacing GNNs.

性能方面,近期报告表明,复杂的图Transformer模型表现不如经过调优的简单图神经网络(GNN) [29, 30]。这项工作强调了新方法相对于强基线进行实证评估的重要性。另一方面,近期的图Transformer研究聚焦于日益复杂的辅助机制,例如计算成本高昂的预处理和学习步骤 [25]、多种编码方式(如位置编码、结构编码和关系编码)[25–27, 31]、引入虚拟节点和边 [25, 28]、将问题转化为自然语言处理任务 [25, 26] 以及其他非平凡的图变换操作 [28, 31]。这些复杂性会显著增加计算需求,降低其被广泛采用并取代GNN的可能性。

Motivated by the effectiveness of attention as a learning mechanism and recent advances in efficient and exact attention, we introduce an end-to-end attention-based architecture for learning on graphs that is simple to implement, scalable, and achieves state-of-the-art results. The proposed architecture considers graphs as sets of edges, leveraging an encoder that interleaves masked and self-attention mechanisms to learn effective representations. The attention-based pooling component mimics the functionality of a readout function and is responsible for aggregating the edge-level features into a permutation invariant graph-level representation. The masked attention mechanism allows for learning effective edge representations originating from the graph connectivity, and its vertical combination with self-attention layers allows for expanding on this information while having a strong prior. Masking can, thus, be seen as leveraging specified relational information, and its vertical combination with self-attention as a means to overcome possible misspecification of the input graph. The masking operator is injected into the pairwise attention weight matrix and allows only for attention between linked primitives. For a pair of edges, connectivity translates to having a shared node between them. We focus primarily on learning through edge sets due to empirically high performance, and refer to our architecture as edge-set attention (ESA). We also demonstrate that the overall architecture is effective for propagating information across nodes through dedicated node-level benchmarks. The ESA architecture is general purpose, in the sense that it only relies on the graph structure and possibly node and edge features, and it is not restricted to any particular domain.
Furthermore, ESA does not use positional, structural, relative, or similar encodings, it does not encode graph structures as tokens or other language (sequence) specific concepts, and it does not require any pre-computations.

受到注意力机制作为学习机制的有效性以及高效精确注意力最新进展的启发,我们提出了一种基于注意力的端到端图学习架构。该架构实现简单、可扩展性强,并能取得最先进的性能表现。我们提出的方法将图视为边集合,通过交替使用掩码注意力和自注意力机制的编码器来学习有效表征。基于注意力的池化组件模拟了读出函数的功能,负责将边级别特征聚合为排列不变的图级别表征。

掩码注意力机制能够基于图连通性学习有效的边表征,与自注意力层的纵向组合则能在保持强先验的前提下扩展信息。因此,掩码操作可视为利用特定关系信息的手段,其与自注意力的纵向结合能够克服输入图可能存在的错误设定。掩码算子被注入成对注意力权重矩阵,仅允许相连图元之间建立注意力连接。对于边对而言,连通性表现为两者共享一个节点。由于实证研究显示边集合学习具有卓越性能,我们主要聚焦于此方法,并将该架构命名为边集注意力(ESA)。通过专门的节点级基准测试,我们还证明了该架构在跨节点信息传播方面的有效性。

ESA架构具有通用性,仅依赖于图结构及可能的节点/边特征,不受特定领域限制。此外,ESA不使用位置编码、结构编码、相对编码等附加信息,不将图结构编码为token或其他语言(序列)特定概念,也无需任何预计算步骤。

Despite its apparent simplicity, ESA-based learning overwhelmingly outperforms strong and tuned GNN baselines and much more involved transformer-based models. Our evaluation is extensive, totalling 70 datasets and benchmarks from different domains such as quantum mechanics, molecular docking, physical chemistry, biophysics, bioinformatics, computer vision, social networks, functional call graphs, and synthetic graphs. At the node level, we include both homophilous and heterophilous graph tasks, as well as shortest path problems, as these require modelling long-range interactions. Beyond supervised learning tasks, we explore the potential for transfer learning in the context of drug discovery and quantum mechanics [32] and show that ESA is a viable transfer learning strategy compared to vanilla GNNs and graph transformers.

尽管看似简单,基于ESA(边集注意力,Edge Set Attention)的学习方法显著优于经过调优的GNN基线模型和更复杂的基于Transformer的模型。我们的评估范围广泛,涵盖来自量子力学、分子对接、物理化学、生物物理学、生物信息学、计算机视觉、社交网络、函数调用图和合成图等不同领域的70个数据集和基准测试。在节点级别,我们同时包含同配性(homophilous)和异配性(heterophilous)图任务,以及最短路径问题,因为这些任务需要对长程交互进行建模。除了监督学习任务外,我们还探索了在药物发现和量子力学[32]背景下迁移学习的潜力,并表明与普通GNN和图Transformer相比,ESA是一种可行的迁移学习策略。

2 Related Work

2 相关工作

An attention mechanism that mimics message passing and limits the attention computation to neighbouring nodes was first proposed in GAT [23]. We consider masking as an abstraction of the GAT attention operator that allows for building general purpose relational structures between items in a set (e.g., $k$-hop neighbourhoods or conformational masks extracted from 3D molecular structures). The attention mechanism in GAT is implemented as a single linear projection matrix that does not explicitly distinguish between keys, queries, and values as in standard dot product attention, followed by an additional linear layer applied to the concatenated representations of connected nodes, along with a non-linearity. This type of simplified attention has been labelled by subsequent work as static and was shown to have limited expressive power [24]. Brody et al. instead proposed dynamic attention, a simple reordering of operations in GAT, resulting in the more expressive GATv2 model [24], albeit at the price of doubling the parameter count and the corresponding memory consumption. A high-level overview of GAT in the context of masked attention is provided in Figure 1, along with the main differences to the proposed architecture. Similar to GAT, several adaptations of the original scaled dot product attention have been proposed for graphs [33, 34], where the focus was on defining an attention mechanism constrained by the node connectivity and replacing the positional encodings of the original transformer model with more appropriate graph alternatives, such as Laplacian eigenvectors. These approaches, while interesting and forward-looking, did not convincingly outperform simple GNNs. Building on this line of work, an architecture that can be seen as an instance of masked transformers has been proposed in SAN [35], illustrated in Figure 2. Attention coefficients there are defined as a convex combination of the scores (controlled by the hyperparameter $\gamma$) associated with the original graph and its complement. Min et al. [36] have also considered a masking mechanism for standard transformers. However, the graph structure itself is not used directly in the masking process, as the devised masks correspond to four types of prior interaction graphs (induced, similarity, cross neighbourhood, and complete sub-graphs), acting as an inductive bias. Furthermore, the method was not designed for general purpose graph learning as it relies on helper mechanisms such as neighbour sampling and a heterogeneous information network.

一种模仿消息传递并将注意力计算限制在相邻节点上的注意力机制最初在GAT [23]中被提出。我们将掩码视为GAT注意力算子的一种抽象,它允许在集合中的项目之间构建通用的关系结构(例如,$k$-跳邻域或从3D分子结构中提取的构象掩码)。GAT中的注意力机制被实现为单一的线性投影矩阵,不像标准点积注意力那样明确区分键、查询和值,并且在连接相连节点的表示后还有一个额外的线性层和非线性激活。这种简化的注意力被后续工作标记为静态的,并被证明表达能力有限[24]。Brody等人提出了动态注意力,这是对GAT中操作顺序的简单重新排列,从而产生了更具表达力的GATv2模型[24]——然而,代价是参数数量和相应的内存消耗翻倍。图1提供了在掩码注意力背景下GAT的高层次概述,以及与所提出架构的主要区别。与GAT类似,针对图结构提出了几种原始缩放点积注意力的改编版本[33, 34],重点是定义受节点连接性约束的注意力机制,并用更合适的图替代方案(如拉普拉斯特征向量)替换原始Transformer模型的位置编码。这些方法虽然有趣且具有前瞻性,但并未明显超越简单的GNN。基于这一系列工作,SAN [35]提出了一种可被视为掩码Transformer实例的架构,如图2所示。其中的注意力系数被定义为与原始图及其补图相关的分数(由超参数$\gamma$控制)的凸组合。Min等人[36]也考虑了标准Transformer的掩码机制。然而,图结构本身并未直接用于掩码过程,因为设计的掩码对应于四种先验交互图(诱导图、相似图、交叉邻域图和完全子图),作为归纳偏置。此外,该方法并非为通用图学习而设计,因为它依赖于邻居采样和异构信息网络等辅助机制。


Figure 1: A high level overview of the GAT message passing algorithm [23]. Panel A: A GAT layer receives node representations and the adjacency matrix as inputs. First a projection matrix is applied to all the nodes, which are then concatenated pairwise and passed through another linear layer, followed by a non-linearity. The final attention score is computed by the softmax function. Panel B: A GAT model stacks multiple GAT layers and uses a readout function over nodes to generate a graph-level representation. Residual connections are also illustrated as a modern enhancement of GNNs (not included in the original approach). GAT vs ESA: Our masked self-attention module relies on masking within classical scaled dot product attention which comes with different projection matrices for keys, queries, and values. Additionally, we wrap the masked scaled dot product attention with layer normalization and skip connections, the output of which is passed through an MLP. The latter is not done classically in GAT as part of its attention modules nor in the overall architecture. In contrast to GAT that operates over nodes, our masking operator is defined over sets of edges that are connected if they share a node in common. When it comes to the readout, GAT uses sum, mean or max and aggregates over nodes whereas our architecture relies on pooling by multi-head attention and aggregation over learned representations of seed vectors inherent to that module. In terms of the encoder that is responsible for learning representations of set items, ours interleaves masked and self-attention layers vertically whereas in GATs the stacked layers are entirely based on neighbourhood attention.

图 1: GAT消息传递算法的高层概览 [23]。A部分: GAT层接收节点表示和邻接矩阵作为输入。首先对所有节点应用投影矩阵,然后成对拼接并通过另一个线性层,接着经过非线性变换。最终注意力分数由softmax函数计算得出。B部分: GAT模型堆叠多个GAT层,并使用节点上的读出函数生成图级表示。残差连接作为GNN的现代增强技术也被展示(原始方法中未包含)。GAT与ESA对比: 我们的掩码自注意力模块基于经典缩放点积注意力中的掩码机制,并为键、查询和值使用不同的投影矩阵。此外,我们用层归一化和跳跃连接封装掩码缩放点积注意力,其输出通过MLP传递。这一设计既不是GAT注意力模块的经典做法,也不属于其整体架构。与作用于节点的GAT不同,我们的掩码算子定义在共享同一节点的边集合上。在读出阶段,GAT使用求和、均值或最大值对节点进行聚合,而我们的架构依赖于通过多头注意力池化,并聚合该模块固有种子向量的学习表示。就负责学习集合项表示的编码器而言,我们的方法垂直交错掩码层和自注意力层,而GAT的堆叠层完全基于邻域注意力。

Recent trends in learning on graphs are dominated by architectures based on standard self-attention layers as the only learning mechanism, with a significant amount of effort put into representing the graph structure exclusively through positional and structural encodings. Graphormer [25] is one of the most prominent approaches from this class. It comes with an involved and computation-heavy suite of pre-processing steps, comprising centrality, spatial, and edge encodings. For instance, spatial encodings rely on the Floyd–Warshall algorithm, which has cubic time complexity in the number of nodes and quadratic memory complexity. Also, the model employs a virtual mean-readout node that is connected to all other nodes in the graph. While Graphormer has originally been evaluated only on four datasets, the results were promising relative to GNNs and SAN. Another related approach is the Tokenized Graph Transformer (TokenGT) [26], which treats all the nodes and edges as independent tokens. To adapt sequence learning to the graph domain, TokenGT encodes the graph information using node identifiers derived from orthogonal random features or Laplacian eigenvectors, and learnable type identifiers for nodes and edges. TokenGT is provably more expressive than standard GNNs, and can approximate $k$-WL tests with the appropriate architecture (number of layers, adequate pooling, etc.). However, this theoretical guarantee holds only with the use of positional encodings, which typically break the permutation invariance over nodes that is required for consistent predictions over the same graph presented with a different node order. The main strength of TokenGT is its theoretical expressiveness, as it has only been evaluated on a single dataset where it did not outperform Graphormer.

近年来,图学习领域的主流架构主要基于标准自注意力(self-attention)层作为唯一学习机制,并通过大量工作将图结构信息仅通过位置编码和结构编码来表示。Graphormer [25]是该类方法中最具代表性的方案之一,它采用了一套复杂且计算量大的预处理流程,包括中心性编码、空间编码和边编码。例如,其空间编码依赖于Floyd-Warshall算法,该算法具有节点数量的三次方时间复杂度和二次方内存复杂度。此外,该模型还引入了一个与图中所有节点相连的虚拟平均读出节点。虽然Graphormer最初仅在四个数据集上进行评估,但其结果相比图神经网络(GNNs)和SAN展现出优势。另一相关工作是Token化图Transformer(TokenGT) [26],该方法将所有节点和边视为独立token。为适应图域序列学习,TokenGT采用基于正交随机特征或拉普拉斯特征向量的节点标识符,以及可学习的节点/边类型标识符来编码图信息。理论证明TokenGT比标准GNNs更具表达能力,并能通过适当架构(层数、池化方式等)逼近$k$-WL测试。但该理论保证仅在使用了通常会破坏节点排列不变性的位置编码时成立,而这种不变性是对同一图结构在不同节点顺序下保持预测一致性所必需的。TokenGT的主要优势在于理论表达能力,目前仅在单个数据集上评估且未超越Graphormer。

A route involving a hybrid between classical transformers and message passing has been pursued in the GraphGPS framework [37], which combines message passing and transformer layers. As in previous works, GraphGPS puts a large emphasis on different types of encodings, proposing and analysing positional and structural encodings, further divided into local, global, and relative encodings. Exphormer [28] is an evolution of GraphGPS that adds virtual global nodes and sparse attention based on expander graphs. While effective, such frameworks do still rely on message passing and are thus not purely attention-based solutions. Limitations include dependence on approximations (Performer [38] for both, expander graphs for Exphormer), and decreased performance when encodings (GraphGPS) or special nodes (Exphormer) are removed. Notably, Exphormer is the first approach from this class to consider custom attention patterns given by node neighbourhoods.

图GPS框架[37]探索了一条结合经典Transformer与消息传递的混合路径,该框架整合了消息传递层和Transformer层。与先前研究类似,图GPS重点关注多种编码类型,提出并分析了位置编码与结构编码(进一步细分为局部、全局和相对编码)。Exphormer[28]作为图GPS的改进版本,引入了基于扩展图的虚拟全局节点和稀疏注意力机制。尽管效果显著,这类框架仍依赖消息传递机制,并非纯粹的注意力解决方案。其局限性包括对近似方法的依赖(图GPS使用Performer[38],Exphormer采用扩展图),以及当移除编码(图GPS)或特殊节点(Exphormer)时性能下降。值得注意的是,Exphormer是该类方法中首个考虑节点邻域定制注意力模式的方案。

3 Methods

3 方法

In this section, we provide a high-level overview of our architecture and describe the masked attention module that allows for information propagation across edges as primitives. An alternative implementation for the more traditional node-based propagation is also presented. The section starts with a formal definition of the masked self-attention mechanism and then proceeds to describe our end-to-end attention-based approach for learning on graphs, involving the encoder and pooling components (illustrated in Figure 3).

在本节中,我们将从高层次概述架构,并描述以边为基元实现跨边信息传播的掩码注意力模块。同时提供了基于传统节点传播的替代实现方案。本节首先对掩码自注意力机制进行形式化定义,随后阐述我们基于注意力机制的图学习端到端方法,该方法包含编码器和池化组件(如图 3 所示)。


Figure 2: A high level overview of the SAN architecture [35]. Panel A: Two independent attention modules with a tied/shared value projection matrix are used for the input graph and its complement. The outputs of these two modules are convexly combined using a hyperparameter $\gamma$ before being residually added to the input. Panel B: The overall SAN architecture that stacks multiple SAN layers without residual connections, and uses a standard readout function over nodes. SAN vs ESA: What is truly different in relation to this architecture is horizontal versus vertical combination of masked and self-attention. While it is possible to set $\gamma=0$ for some layers and $\gamma=1$ for others, this avenue has never been explored in SAN and it would still result in a slightly different encoder as the query and key projections are different for the input graph and its complement. The empirical analysis in [35] also provides no insights relative to vertical combination of masked and self-attention, nor the impact of attention pooling which is a differentiating factor and an important part of our architecture.

图 2: SAN架构的总体概览 [35]。A部分:输入图及其补图使用两个独立的注意力模块,共享值投影矩阵。这两个模块的输出通过超参数$\gamma$进行凸组合后,以残差方式添加到输入中。B部分:整体SAN架构由多个无残差连接的SAN层堆叠而成,并在节点上使用标准读出函数。SAN与ESA的区别:该架构的真正差异在于掩码注意力和自注意力的横向组合与纵向组合方式。虽然可以为某些层设置$\gamma=0$,其他层设置$\gamma=1$,但SAN中从未探索过这种方案,且由于输入图与其补图的查询和键投影不同,仍会导致编码器存在细微差异。[35]中的实证分析也未涉及掩码注意力与自注意力的纵向组合,或注意力池化的影响(这是我们架构的差异化要素和重要组成部分)。


Figure 3: A high-level overview of Edge Set Attention and its main building blocks. Panel A: An illustration of masked scaled dot product attention with edge inputs. Edge features derived from the edge set are processed by query, key, and value projections as in standard attention. An additive edge mask, derived as in Algorithm 1, is applied prior to the softmax operator. A mask value of 0 leaves the attention score unchanged, while a large negative value completely discards that attention score. Panel B: The Pooling by Multi-Head Attention Block (PMA). A learnable set of seed tensors $\mathbf{S}_{k}$ is randomly initialised and used as the query for a cross attention operation with the inputs. The result is further processed by Self Attention Blocks. Panel C: The Self Attention Block (SAB), illustrated here with a pre-layer-normalization architecture. The input is normalised, processed by self attention, then further normalised and processed by an MLP. The block uses residual connections. Panel D: Illustration of the Masked (Self) Attention Block (MAB). The only difference compared to SAB is the edge mask, which follows the computation outlined in Panel A. Panel E: An instantiation of the ESA model. Edge inputs are processed by various attention blocks, here in the order MSMSP, corresponding to masked, self, masked, self, and pooling attention layers.

图 3: 边集注意力机制及其核心组件的高层概览。A部分: 带边输入的掩码缩放点积注意力示意图。从边集提取的边特征通过标准注意力机制中的查询、键和值投影进行处理。如算法1所述生成的加性边掩码会在softmax算子前应用。掩码值为0时保持注意力分数不变,而较大负值会完全丢弃该注意力分数。B部分: 多头注意力池化块(PMA)。一组可学习的种子张量$\mathbf{S}_{k}$被随机初始化,作为跨注意力操作的查询输入。结果会通过自注意力块进一步处理。C部分: 自注意力块(SAB),此处展示的是层归一化前置架构。输入经过归一化、自注意力处理,再通过MLP进行二次归一化和处理。该模块采用残差连接。D部分: 掩码(自)注意力块(MAB)示意图。与SAB的唯一区别在于边掩码,其计算流程如A部分所述。E部分: ESA模型的具体实现。边输入会经过多种不同的注意力块处理,此处展示的顺序是MSMSP,对应掩码、自注意力、掩码、自注意力和池化注意力层。
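As a concrete illustration of the PMA block in Panel B, the following minimal NumPy sketch (an assumption-laden simplification: a single attention head, no subsequent Self Attention Blocks, no MLP or normalization; all variable names are illustrative) shows how $k$ learnable seed vectors pool a variable-sized edge set into a fixed-size, permutation-invariant representation:

```python
import numpy as np

def cross_attention(Q, K, V):
    # scaled dot-product cross attention: softmax(Q K^T / sqrt(d_k)) V
    S = Q @ K.T / np.sqrt(K.shape[-1])
    S -= S.max(axis=-1, keepdims=True)      # numerical stability
    W = np.exp(S)
    W /= W.sum(axis=-1, keepdims=True)
    return W @ V

rng = np.random.default_rng(0)
k, d, N_e = 4, 8, 10
S_k = rng.standard_normal((k, d))           # learnable seeds (random init)
X = rng.standard_normal((N_e, d))           # edge representations

# seeds act as queries over the edge set; output size depends on k, not N_e
pooled = cross_attention(S_k, X, X)         # shape (k, d)
```

Because each pooled row is a softmax-weighted sum over the whole edge set, reordering the edges leaves the output unchanged, which is the permutation invariance a readout function requires.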

3.1 Masked Attention Modules

3.1 掩码注意力模块

We first introduce the basic notation and then give a formal description of the masked attention mechanism with a focus on the connectivity pattern specific to graphs. Following this, we provide algorithms for an efficient implementation of the masking operators using native tensor operations.

我们首先介绍基本符号,然后重点描述针对图结构特定连接模式的掩码注意力机制 (masked attention mechanism) 形式化定义。随后提供基于原生张量运算高效实现掩码算子的算法。

As in most graph learning settings, a graph is a tuple $\mathcal{G}=(\mathcal{N},\mathcal{E})$ where $\mathcal{N}$ represents the set of nodes (vertices), $\mathcal{E}\subseteq\mathcal{N}\times\mathcal{N}$ is the set of edges, and $N_{n}=|\mathcal{N}|$, $N_{e}=|\mathcal{E}|$. Each node $i\in\mathcal{N}$ is associated with a feature vector $\mathbf{n}_{i}$ of dimension $d_{n}$, and each edge $(i,j)\in\mathcal{E}$ with a $d_{e}$-dimensional feature vector $\mathbf{e}_{ij}$. The node features are collected as rows in a matrix $\mathbf{N}\in\mathbb{R}^{N_{n}\times d_{n}}$, and similarly the edge features into $\mathbf{E}\in\mathbb{R}^{N_{e}\times d_{e}}$. The graph connectivity information can be represented as an adjacency matrix $\mathbf{A}$, where $\mathbf{A}_{ij}=1$ if $(i,j)\in\mathcal{E}$ and $\mathbf{A}_{ij}=0$ otherwise. The edge list (edge index) representation is equivalent but more common in practice.

与大多数图学习设定相同,图被定义为一个元组 $\mathcal{G}=(\mathcal{N},\mathcal{E})$,其中 $\mathcal{N}$ 表示节点(顶点)集合,$\mathcal{E}\subseteq\mathcal{N}\times\mathcal{N}$ 是边集合,且 $N_{n}=|\mathcal{N}|$,$N_{e}=|\mathcal{E}|$。每个节点 $i\in\mathcal{N}$ 关联一个维度为 $d_{n}$ 的特征向量 $\mathbf{n}_{i}$,每条边 $(i,j)\in\mathcal{E}$ 关联一个 $d_{e}$ 维边特征 $\mathbf{e}_{ij}$。节点特征按行收集为矩阵 $\mathbf{N}\in\mathbb{R}^{N_{n}\times d_{n}}$,边特征同理收集为 $\mathbf{E}\in\mathbb{R}^{N_{e}\times d_{e}}$。图的连接信息可用邻接矩阵 $\mathbf{A}$ 表示,当 $(i,j)\in\mathcal{E}$ 时 $\mathbf{A}_{ij}=1$,否则 $\mathbf{A}_{ij}=0$。边列表(边索引)表示形式与之等价,但在实践中更为常见。
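To make the notation concrete, here is a small illustration (a hypothetical toy graph, not from the paper) of building the dense adjacency matrix $\mathbf{A}$ from the equivalent edge index representation:

```python
import numpy as np

N_n = 4
# edge index in COO form for the undirected path 0-1-2-3:
# row 0 holds source nodes, row 1 holds target nodes
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])

A = np.zeros((N_n, N_n), dtype=int)
A[edge_index[0], edge_index[1]] = 1   # A_ij = 1 iff (i, j) is an edge
```

For an undirected graph each edge is stored in both directions, so $\mathbf{A}$ comes out symmetric.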

Algorithm 1: Edge masking in PyTorch Geometric (T is the transpose; helper functions are explained in SI 6).

算法 1: PyTorch Geometric中的边掩码实现(T表示转置;辅助函数说明见SI 6)

```python
import torch
# consecutive and first_unique_index are helper functions from the paper's
# code (explained in SI 6)
from esa import consecutive, first_unique_index

def edge_adjacency(batched_edge_index):
    N_e = batched_edge_index.size(1)
    source_nodes = batched_edge_index[0]
    target_nodes = batched_edge_index[1]

    # unsqueeze and expand to shape (N_e, N_e) without allocating new memory
    exp_src = source_nodes.unsqueeze(1).expand((-1, N_e))
    exp_trg = target_nodes.unsqueeze(1).expand((-1, N_e))

    src_adj = exp_src == exp_src.T       # edges sharing a source node
    trg_adj = exp_trg == exp_trg.T       # edges sharing a target node
    cross = (exp_src == exp_trg.T).logical_or(exp_trg == exp_src.T)

    return src_adj | trg_adj | cross

def edge_mask(b_ei, b_map, B, L):
    mask = torch.full(size=(B, L, L), fill_value=False)
    edge_to_graph = b_map.index_select(0, b_ei[0, :])

    edge_adj = edge_adjacency(b_ei)
    ei_to_original = consecutive(
        first_unique_index(edge_to_graph), b_ei.size(1))

    edges = edge_adj.nonzero()
    graph_index = edge_to_graph.index_select(0, edges[:, 0])
    coord_1 = ei_to_original.index_select(0, edges[:, 0])
    coord_2 = ei_to_original.index_select(0, edges[:, 1])

    mask[graph_index, coord_1, coord_2] = True
    return ~mask
```

The main building block in ESA is masked scaled dot product attention. Panel A in Figure 3 illustrates this attention mechanism schematically. More formally, this attention mechanism is given by

ESA中的主要构建模块是掩码缩放点积注意力 (masked scaled dot product attention)。图 3 中的面板 A 以示意图形式展示了这种注意力机制。更正式地说,该注意力机制由以下公式给出:

$$
\mathrm{SDPA}(\mathbf{Q},\mathbf{K},\mathbf{V},\mathbf{M})=\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_{k}}}+\mathbf{M}\right)\mathbf{V}
$$


where $\mathbf{Q}$ , $\mathbf{K}$ , and $\mathbf{V}$ are projections of the edge representations to queries, keys, and values, respectively. This is a minor extension of the standard scaled dot product attention via the additive mask M that can censor any of the pairwise attention scores, and $d_{k}$ is the key dimension. The generalisation to masked multihead attention follows the same steps as in the original transformer [22]. Below and in the specification of the overall architecture, we refer to this function as MultiHead $({\bf Q},{\bf K},{\bf V},{\bf M})$ .

其中 $\mathbf{Q}$、$\mathbf{K}$ 和 $\mathbf{V}$ 分别是边表示向查询(query)、键(key)和值(value)的投影。这是对标准缩放点积注意力机制的一个小扩展,通过可加性掩码M来屏蔽任意成对注意力分数,$d_{k}$ 是键维度。带掩码的多头注意力机制的泛化遵循原始Transformer [22] 的相同步骤。在下文及整体架构的规范中,我们将此函数称为MultiHead $({\bf Q},{\bf K},{\bf V},{\bf M})$。
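The equation above can be sketched directly. The following minimal NumPy implementation (illustrative, single head; the actual model uses multi-head attention with learned projections) shows how the additive mask $\mathbf{M}$ censors pairwise attention scores:

```python
import numpy as np

def masked_sdpa(Q, K, V, M):
    """softmax(Q K^T / sqrt(d_k) + M) V, as in the SDPA equation above."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k) + M           # (N_e, N_e)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
N_e, d = 4, 8
Q, K, V = (rng.standard_normal((N_e, d)) for _ in range(3))

# additive mask: 0 keeps a score, a large negative value discards it
M = np.where(rng.random((N_e, N_e)) < 0.5, 0.0, -1e9)
np.fill_diagonal(M, 0.0)   # every edge may attend to itself
out = masked_sdpa(Q, K, V, M)
```

With a mask that only keeps the diagonal, each row of the output reduces to the corresponding row of $\mathbf{V}$, since the softmax collapses onto a single score.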

Masked self-attention for graphs can be seen as graph-structured attention, and an instance of a generalized attention mechanism with custom attention patterns [28]. More specifically, in ESA the attention pattern is given by an edge adjacency matrix rather than allowing for interactions between all set items. Crucially, the edge adjacency matrix can be efficiently computed both for a single graph and for batched graphs using exclusively tensor operations. The case for a single graph is covered in the left side of Algorithm 1 through the edge adjacency function. The first 3 lines of the function correspond to getting the number of edges, then separating the source and target nodes from the edge adjacency list (also called edge index), which is equivalent to the standard adjacency matrix of a graph. The source and target node tensors each have a dimension equal to the number of edges ( $N_{e}$ ). Lines 7-9 add an additional dimension and efficiently repeat (‘expand’) the existing tensors to shape $(N_{e},N_{e})$ without allocating new memory. Using the transpose, line 11 checks if the source nodes of any two edges are the same, and the same for target nodes on line 12. On line 13, cross connections are checked, where the source node of an edge is the target node of another, and vice versa. The operations of lines 11-14 result in boolean matrices of shape $(N_{e},N_{e})$ which are summed for the final returned edge adjacency matrix. The right panel of Algorithm 1 depicts the case where the input graph represents a batch of smaller graphs. This requires an additional batch mapping tensor that maps each node in the batched graph to its original graph, and carefully manipulating the indices to create the final edge adjacency mask of shape $(B,L,L)$ , where $L$ is the maximum number of edges in the batch.

图的掩码自注意力 (masked self-attention) 可视为图结构注意力,是采用自定义注意力模式的广义注意力机制实例 [28]。具体而言,在ESA中,注意力模式由边邻接矩阵而非所有集合项间的交互决定。关键在于,该边邻接矩阵可通过纯张量运算高效计算,适用于单图和批量图处理。算法1左侧展示了单图情况下的边邻接函数实现:前3行获取边数量,从边邻接列表(亦称边索引,等价于标准图邻接矩阵)分离源节点与目标节点张量,二者维度均等于边数 ($N_{e}$)。第7-9行通过扩展操作(不分配新内存)为张量新增维度并重塑为 $(N_{e},N_{e})$ 形状。第11行利用转置检查任意两条边的源节点是否相同,第12行同理处理目标节点。第13行验证交叉连接(某边的源节点为另一边的目标节点,反之亦然)。第11-14行操作生成 $(N_{e},N_{e})$ 布尔矩阵,经求和得到最终边邻接矩阵。算法1右侧描述批量图输入场景:需额外批次映射张量将节点映射至原图,并通过索引操作生成形状为 $(B,L,L)$ 的边邻接掩码($L$ 为批次中最大边数)。
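The core logic of the `edge_adjacency` function in Algorithm 1 can be sketched in NumPy on a toy path graph (illustrative only; the paper's implementation operates on PyTorch tensors and avoids new memory allocation via `expand`). Two edges are adjacent exactly when they share a node, whether as source, target, or crosswise:

```python
import numpy as np

edge_index = np.array([[0, 1, 2],
                       [1, 2, 3]])   # edges e0=(0,1), e1=(1,2), e2=(2,3)
src, trg = edge_index
N_e = src.shape[0]

# repeat to (N_e, N_e) so entry [i, j] compares edge i against edge j
exp_src = np.repeat(src[:, None], N_e, axis=1)
exp_trg = np.repeat(trg[:, None], N_e, axis=1)

src_adj = exp_src == exp_src.T                 # shared source node
trg_adj = exp_trg == exp_trg.T                 # shared target node
cross = (exp_src == exp_trg.T) | (exp_trg == exp_src.T)  # cross connections

edge_adj = src_adj | trg_adj | cross
```

Here e0 and e1 share node 1, and e1 and e2 share node 2, while e0 and e2 share no node, so `edge_adj[0, 2]` is `False`.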

Since attention in ESA is computed over edges, we chose to separate the source and target node features for each edge, similarly to lines 4-5 of Algorithm 1, and concatenate them with the edge features:

由于在ESA中注意力是在边上计算的,我们选择将每条边的源节点和目标节点特征分离,类似于算法1的第4-5行,并将它们与边特征拼接起来:

$$
\mathbf{x}_{ij}=\mathbf{n}_{i}\parallel\mathbf{n}_{j}\parallel\mathbf{e}_{ij}
$$


for each edge $e_{ij}$. The resulting features $\mathbf{x}_{ij}$ are collected in a matrix $\mathbf{X}\in\mathbb{R}^{N_{e}\times(2d_{n}+d_{e})}$.

对于每条边 $e_{i j}$,最终得到的特征 $\mathbf{x}{i j}$ 被收集到矩阵 $\mathbf{X}\in\mathbb{R}^{N_{e}\times(2d_{n}+d_{e})}$ 中。
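Assuming node and edge features are stored as dense tensors, this construction of $\mathbf{X}$ is a single gather-and-concatenate. A minimal sketch (function and variable names are illustrative assumptions):

```python
import torch

def build_edge_tokens(edge_index, node_feats, edge_feats):
    # node_feats: (N_n, d_n); edge_feats: (N_e, d_e); edge_index: (2, N_e)
    src, tgt = edge_index[0], edge_index[1]
    # x_ij = n_i || n_j || e_ij, collected into X of shape (N_e, 2*d_n + d_e)
    return torch.cat([node_feats[src], node_feats[tgt], edge_feats], dim=-1)
```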

Having defined the mask generation process and the masked multihead attention function, we next define the modular blocks of ESA, starting with the Masked Self Attention Block (MAB):

在定义了掩码生成过程和掩码多头注意力函数后,我们接下来定义ESA的模块化构建块,从掩码自注意力块 (MAB) 开始:

$$
\begin{aligned}
\mathrm{MAB}(\mathbf{X},\mathbf{M}) &= \mathbf{H}+\mathrm{MLP}(\mathrm{LayerNorm}(\mathbf{H})) \\
\mathbf{H} &= \overline{\mathbf{X}}+\mathrm{MultiHead}(\overline{\mathbf{X}},\overline{\mathbf{X}},\overline{\mathbf{X}},\mathbf{M}) \\
\overline{\mathbf{X}} &= \mathrm{LayerNorm}(\mathbf{X})
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{MAB}(\mathbf{X},\mathbf{M}) &= \mathbf{H}+\mathrm{MLP}(\mathrm{LayerNorm}(\mathbf{H})) \\
\mathbf{H} &= \overline{\mathbf{X}}+\mathrm{MultiHead}(\overline{\mathbf{X}},\overline{\mathbf{X}},\overline{\mathbf{X}},\mathbf{M}) \\
\overline{\mathbf{X}} &= \mathrm{LayerNorm}(\mathbf{X})
\end{aligned}
$$

where the mask $\mathbf{M}$ is computed as in Algorithm 1 and has shape $B\times L\times L$ , with $B$ the batch size and $L$ the maximum number of edges in the batch, and MLP is a multi-layer perceptron. A Self Attention Block (SAB) can be formally defined as:

其中掩码M按算法1计算得出,形状为$B\times L\times L$(B为批次大小),MLP为多层感知机。自注意力块(SAB)可形式化定义为:

$$
\mathrm{SAB}(\mathbf{X})=\mathrm{MAB}(\mathbf{X},\mathbf{0})
$$

$$
\mathrm{SAB}(\mathbf{X})=\mathrm{MAB}(\mathbf{X},\mathbf{0})
$$
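The MAB and SAB equations above can be sketched as a pre-LN transformer block in PyTorch. The hidden sizes, head count, and the use of `nn.MultiheadAttention` (whose boolean `attn_mask` marks blocked pairs with True) are illustrative assumptions, not details fixed by the paper:

```python
import torch
import torch.nn as nn

class MAB(nn.Module):
    """Pre-LN masked self-attention block following the equations above."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 2 * dim), nn.GELU(), nn.Linear(2 * dim, dim)
        )

    def forward(self, x, mask=None):
        # mask: boolean attention mask with True marking blocked pairs,
        # of shape (L, L) or (B * num_heads, L, L); None means unmasked
        x_bar = self.ln1(x)                                   # X_bar = LayerNorm(X)
        h = x_bar + self.attn(x_bar, x_bar, x_bar,
                              attn_mask=mask, need_weights=False)[0]
        return h + self.mlp(self.ln2(h))                      # MAB(X, M)

def sab(block, x):
    # SAB(X) = MAB(X, 0): the same block applied without a mask
    return block(x, mask=None)
```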

In principle, the masks can be arbitrary and a promising avenue of research could be the design of new mask types. For certain tasks such as node classification, we also defined an alternative version of ESA for nodes (NSA). The only major difference is the mask generation step. The single graph case is trivial as all the masking information is available in the edge index, so we provide the general batched case in Algorithm 2.

原则上,掩码可以是任意的,设计新型掩码类型可能是一个有前景的研究方向。对于节点分类等特定任务,我们还定义了节点版ESA的替代方案(NSA)。主要区别仅在于掩码生成步骤。单图场景较为简单,因为所有掩码信息都存在于边索引中,因此我们在算法2中提供了通用的批处理案例。

Algorithm 2: Node masking in the PyTorch Geometric framework.

算法 2: PyTorch Geometric框架中的节点掩码

```python
import torch
from torch_geometric.utils import unbatch_edge_index

def node_mask(batched_edge_index, batch_map, B, M):
    # batched_edge_index 将所有的 edge_index 张量批量处理为单个张量
    # batch_map 将每个节点映射到其所属的图;B, M 分别是批次大小和掩码大小
    mask = torch.full(size=(B, M, M), fill_value=False)
    graph_idx = batch_map.index_select(0, batched_edge_index[0, :])
    edge_index_list = unbatch_edge_index(batched_edge_index, batch_map)
    edge_index = torch.cat(edge_index_list, dim=1)
    mask[graph_idx, edge_index[0, :], edge_index[1, :]] = True
    return ~mask
```

3.2 ESA Architecture

3.2 ESA架构

Our architecture consists of two components: i) an encoder that interleaves masked and standard self-attention blocks to learn an effective representation of edges, and ii) a pooling block based on multi-head attention, inspired by the decoder component from the set transformer architecture [39]. The latter comes naturally when one considers graphs as sets of edges (or nodes), as done here. ESA is a purely attention-based architecture that leverages the scaled dot product mechanism proposed by Vaswani et al. [22] through masked and standard self-attention blocks (MABs and SABs) and a pooling by multi-head attention (PMA) block.

我们的架构由两个组件组成:$\imath$) 一个通过交错掩码注意力块和自注意力块来学习边有效表示的编码器,以及 ii) 一个基于多头注意力的池化块,其灵感来自集合Transformer架构 [39] 的解码器组件。当我们将图视为边(或节点)的集合时,后者自然适用。ESA是一种纯基于注意力的架构,它通过掩码注意力块(MAB)、标准自注意力块(SAB)以及多头注意力池化块(PMA),利用了Vaswani等人 [22] 提出的缩放点积机制。

The encoder consisting of arbitrarily interleaved MAB and SAB blocks is given by:

由任意交错排列的 MAB 和 SAB 块组成的编码器表示为:

$$
\mathrm{Encoder}(\mathbf{X},\mathbf{M})=\mathrm{AB}_{1}\circ\cdots\circ\mathrm{AB}_{t}(\mathbf{X},\mathbf{M}),\quad\mathrm{AB}_{i}\in\{\mathrm{MAB},\mathrm{SAB}\}
$$

$$
\mathrm{Encoder}(\mathbf{X},\mathbf{M})=\mathrm{AB}_{1}\circ\cdots\circ\mathrm{AB}_{t}(\mathbf{X},\mathbf{M}),\quad其中\;\mathrm{AB}_{i}\in\{\mathrm{MAB},\mathrm{SAB}\}
$$

Here, ‘AB’ refers to an attention block that can be instantiated as an MAB or SAB.

此处,“AB”指代一个注意力块 (attention block),可实例化为MAB或SAB。
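Schematically, the interleaved encoder is a fold over a list of attention blocks, where the mask is applied only at MAB positions and ignored at SAB positions (matching SAB(X) = MAB(X, 0)). A minimal sketch, with illustrative names:

```python
def encoder(blocks, kinds, x, mask):
    # blocks: callables taking (x, mask); kinds: "M" for MAB, "S" for SAB
    for block, kind in zip(blocks, kinds):
        x = block(x, mask if kind == "M" else None)
    return x
```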

The pooling module that is responsible for aggregating the processed edge representations into a graph-level representation is formally defined by:

负责将处理后的边表征聚合为图级表征的池化模块正式定义为:

$$
\begin{aligned}
\mathrm{PMA}_{k,p}(\mathbf{Z}) &= \mathrm{SAB}^{p}\left(\overline{\mathbf{S}}+\mathrm{MLP}(\overline{\mathbf{S}})\right) \\
\overline{\mathbf{S}} &= \mathrm{LayerNorm}\left(\mathrm{MultiHead}\left(\mathbf{S}_{k},\mathbf{Z},\mathbf{Z},\mathbf{0}\right)\right)
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{PMA}_{k,p}(\mathbf{Z}) &= \mathrm{SAB}^{p}\left(\overline{\mathbf{S}}+\mathrm{MLP}(\overline{\mathbf{S}})\right) \\
\overline{\mathbf{S}} &= \mathrm{LayerNorm}\left(\mathrm{MultiHead}\left(\mathbf{S}_{k},\mathbf{Z},\mathbf{Z},\mathbf{0}\right)\right)
\end{aligned}
$$

where $\mathbf{S}_{k}$ is a tensor of $k$ learnable seed vectors that are randomly initialised and $\mathrm{SAB}^{p}(\cdot)$ is the application of $p$ SABs. Technically, it suffices to set $k=1$ to output a single representation for the entire graph. However, we have empirically found it beneficial to set it to a small value, such as $k=32$ . Moreover, this change allows self-attention (SABs) to further process the $k$ resulting representations, which can be simply summed or averaged due to the small $k$ . Contrary to classical readouts that aggregate directly over set items (i.e., nodes), pooling by multi-head attention performs the final aggregation over the embeddings of the learnt seed vectors $\mathbf{S}_{k}$ . While tasks involving node-level predictions require only the encoder component, predictive tasks involving graph-level representations require all the modules, both the encoder and pooling by multi-head attention. The architecture in the latter setting is formally given by

其中 $\mathbf{S}_{k}$ 是一个包含 $k$ 个可学习种子向量的张量,这些向量被随机初始化,而 $\mathrm{SAB}^{p}(\cdot)$ 表示应用 $p$ 个 SAB。技术上,设置 $k=1$ 足以输出整个图的单一表示。然而,我们通过实验发现将其设置为较小值(如 $k=32$)更为有利。此外,这一调整使得自注意力机制(SABs)能够进一步处理得到的 $k$ 个表示,由于 $k$ 值较小,这些表示可以直接求和或取平均。与直接在集合项(即节点)上进行聚合的经典读出方法不同,多头注意力池化是在学习到的种子向量 $\mathbf{S}_{k}$ 的嵌入表示上执行最终聚合。涉及节点级预测的任务仅需编码器组件,而涉及图级表示的任务则需要所有模块,包括编码器和多头注意力池化。后一种场景下的架构形式化定义为

$$
\mathbf{Z}_{\mathrm{out}}=\mathrm{PMA}_{k,p}(\mathrm{Encoder}(\mathbf{X},\mathbf{M})+\mathbf{X})
$$

$$
\mathbf{Z}_{\mathrm{out}}=\mathrm{PMA}_{k,p}(\mathrm{Encoder}(\mathbf{X},\mathbf{M})+\mathbf{X})
$$
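A sketch of the pooling module under the PMA equations above, in which $k$ learnable seeds cross-attend to the $n$ edge embeddings so the operation costs $O(kn)$. The trailing $p$ SABs are elided and the $k$ seed outputs are simply averaged, as the text permits; all names and sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn

class PMA(nn.Module):
    """Pooling by multi-head attention over k learnable seed vectors."""
    def __init__(self, dim, k=32, num_heads=4):
        super().__init__()
        self.seeds = nn.Parameter(torch.randn(k, dim))  # S_k
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ln = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 2 * dim), nn.GELU(), nn.Linear(2 * dim, dim)
        )

    def forward(self, z):
        # z: (B, n, dim) edge embeddings produced by the encoder
        s = self.seeds.unsqueeze(0).expand(z.size(0), -1, -1)
        # S_bar = LayerNorm(MultiHead(S_k, Z, Z))
        s_bar = self.ln(self.attn(s, z, z, need_weights=False)[0])
        out = s_bar + self.mlp(s_bar)  # SAB^p would follow here in the paper
        return out.mean(dim=1)         # aggregate the k seed representations
```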

As optimal configurations are task specific, we do not explicitly fix all the architectural details. For example, it is possible to select between layer and batch normalisation, a pre-LN or post-LN architecture [40], or standard and gated MLPs, along with GLU variants, e.g., SwiGLU [41].

由于最优配置与任务相关,我们并未明确固定所有架构细节。例如,可以在层归一化与批量归一化之间选择,采用前置层归一化 (pre-LN) 或后置层归一化 (post-LN) 架构 [40],或选择标准MLP与门控MLP(如SwiGLU [41]等GLU变体)。

3.3 Time and Memory Scaling

3.3 时间和内存扩展

ESA is enabled by recent advances in efficient and exact attention, such as memory-efficient attention and Flash attention [42–45]. Flash attention does not yet support masking, but memory-efficient implementations with arbitrary masking capabilities exist in both PyTorch and xFormers [46, 47]. Theoretically, the memory complexity of memory-efficient approaches is $O(\sqrt{n})$ , where $n$ is the sequence/set size, with the same time complexity as standard attention. Since all operations in ESA except the cross-attention in PMA are based on self-attention, our method benefits directly from all of these advances. Moreover, our cross-attention is performed between the full set of size $n$ and a small set of $k\leq32$ learnable seeds, such that the complexity of the operation is $O(kn)$ even for standard attention. Flash attention, which has linear memory complexity even for self-attention and is up to 25 times faster than standard attention, supports cross-attention, and we use it for further uplifts. However, these implementations are not optimised for graph learning, and we have found that the biggest bottleneck is not the attention computation, but rather storing the dense edge adjacency mask, which requires memory quadratic in the number of edges. An evident, but not yet available, optimisation would be to store the masks in a sparse tensor format. Moreover, even a single boolean requires an entire byte of storage in PyTorch, increasing the theoretical memory usage by 8 times. Further limitations and possible optimisations are discussed in SI 5. Despite some of these limitations, we have successfully trained ESA models for graphs with up to approximately 30,000 edges (e.g., dd in Table 6).

ESA得益于近年来高效精确注意力机制的进步,如内存高效注意力(memory-efficient attention)和Flash attention[42–45]。虽然Flash attention暂不支持掩码功能,但PyTorch和xFormers[46,47]已实现支持任意掩码操作的内存高效版本。理论上,内存高效方法的空间复杂度为$O(\sqrt{n})$(其中$n$表示序列/集合大小),时间复杂度则与标准注意力机制相同。由于ESA中除PMA模块的交叉注意力外均基于自注意力机制,我们的方法能直接受益于这些技术进步。此外,我们的交叉注意力操作在规模为$n$的完整集合与少量可学习种子($k\leq32$)之间进行,因此即便使用标准注意力,其复杂度也仅为$O(kn)$。具有线性内存复杂度的Flash attention(其自注意力计算速度可比标准注意力快达25倍)支持交叉注意力,我们借此实现性能提升。然而这些实现未针对图学习优化,我们发现最大瓶颈并非注意力计算,而是存储稠密边邻接掩码需要与边数量平方成正比的内存。一个明显但尚未实现的优化方案是采用稀疏张量格式存储掩码。此外,PyTorch中单个布尔值需占用整字节存储,导致理论内存使用量增加8倍。补充材料SI 5讨论了更多限制与优化可能。尽管存在这些限制,我们已成功训练出处理约30,000条边规模图数据(如表6中dd数据集)的ESA模型。
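The byte-per-boolean point above is easy to verify: PyTorch stores a dense boolean mask at one byte per element, so an $(N_e, N_e)$ mask costs $N_e^2$ bytes, 8 times more than a bit-packed layout would:

```python
import torch

num_edges = 1_000
mask = torch.zeros(num_edges, num_edges, dtype=torch.bool)
# Each boolean occupies a full byte, so the dense mask needs N_e^2 bytes
bytes_used = mask.numel() * mask.element_size()
print(bytes_used)  # 1000000 bytes (~1 MB) for a graph with 1,000 edges
```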

Table 1: The table reports root mean squared error on qm9 (RMSE is the standard metric for quantum mechanics) and $\mathrm{R^{2}}$ for dockstring (dock) and MoleculeNet (mn), presented as mean $\pm$ standard deviation over 5 runs. The mean absolute error (MAE) is reported for pcqm4mv2 (MAE is the standard metric for this task) over a single run due to the size of the dataset. The lowest MAEs and RMSEs, and the highest $\mathrm{R^{2}}$ values are highlighted in bold. oom denotes out-of-memory errors. A complete table, including GCN and GIN, is provided in Supplementary Table 1.

表 1: 该表报告了qm9的均方根误差(RMSE是量子力学的标准指标)以及dockstring(dock)和MoleculeNet(mn)的$\mathrm{R^{2}}$值,以5次运行的平均值$\pm$标准差呈现。由于数据集规模,pcqm4mv2报告了单次运行的平均绝对误差(MAE是该任务的标准指标)。最低的MAE和RMSE值以及最高的$\mathrm{R^{2}}$值以粗体标出。oom表示内存不足错误。完整表格(包括GCN和GIN)见补充表1。

目标 DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
μ ELUMO △E (R²) EHOMO 0.55 ± 0.01 0.44 ± 0.03 0.55 ± 0.01 0.48 ± 0.02 0.55 ± 0.01 0.46 ± 0.02 0.53 ± 0.01 0.48 ± 0.07 0.40 ± 0.63 ± 0.01 0.02 0.76 ± 0.02 0.45 ± 0.01 0.94 ± 0.17 0.60 ± 0.13 0.56 ± 0.00 0.40 ± 0.00
0.11 ± 0.00 0.11 ± 0.00 0.10 ± 0.00 0.10 ± 0.00 0.11 ± 0.00 0.12 ± 0.00 0.11 ± 0.00 0.10 ± 0.00
0.12 ± 0.00 0.12 ± 0.00 0.13 ± 0.00 0.11 ± 0.00 0.11 ± 0.00 0.13 ± 0.00 0.11 ± 0.00 0.11 ± 0.00
0.17 ± 0.00 0.16 ± 0.01 0.16 ± 0.00 0.14 ± 0.00 0.16 ± 0.00 0.18 ± 0.01 0.15 ± 0.01 0.15 ± 0.00
29.87 ± 0.34 30.91 ± 0.15 30.15 ± 0.36 28.50 ± 0.44 29.63 ± 0.35 31.54 ± 0.42 30.42 ± 1.19 28.33 ± 0.32
ZPVE 0.03 ± 0.00 0.03 ± 0.00 0.03 ± 0.00 0.03± 0.00 0.06 ± 0.03 0.03 ± 0.00 0.03 ± 0.01 0.03 ± 0.00
29.82 ± 6.32 27.82 ± 3.98 25.40 ± 3.62 31.64 ± 6.49 24.60 ± 4.70 15.48 ± 4.67 14.50 ± 4.34 4.78 ± 0.67
Uo 24.74 ± 6.31 30.91 ± 2.62 24.92 ± 4.18 25.04 ± 5.48 15.55 ± 8.23 12.45 ± 1.86 15.82 ± 4.26 6.80 ± 2.81
U H 34.02 ± 5.66 28.92 ± 4.11 23.42 ± 2.77 27.34 ± 8.17 19.01 ± 4.65 11.42 ± 3.72 12.92 ± 4.29 6.02 ± 1.32
G 25.41 ± 4.86 31.49 ± 1.86 23.24 ± 2.54 22.31 ± 3.45 31.20 ± 4.61 29.72 ± 8.79 13.08 ± 4.81 6.10 ± 1.33
CV 0.19 ± 0.01 0.19 ± 0.01 0.19 ± 0.00 0.18± 0.01 0.18 ± 0.02 0.18 ± 0.00 0.19 ± 0.03 0.16 ± 0.00
0.38 ± 0.02 0.39 ± 0.03 0.38 ± 0.02 0.35 ± 0.02 0.29± 0.00 0.30 ± 0.02 0.33 ± 0.05 0.24 ± 0.01
UATOM 0.40 ± 0.04 0.48 ± 0.07 0.40 ± 0.04 0.36± 0.02 0.30± 0.01 0.30 ± 0.00 0.33 ± 0.06 0.24 ± 0.01
HATOM GATOM 0.38 ± 0.04 0.38 ± 0.02 0.41 ± 0.06 0.37 士 0.03 0.30± 0.01 0.31 ± 0.01 0.36 ± 0.09 0.25 ± 0.00
A 0.37 ± 0.04 0.34 ± 0.02 0.33 ± 0.01 0.31 1.01 士 0.02 0.26 ± 0.01 0.27 ± 0.01 0.36 ± 0.08 0.22 ± 0.02
B 0.90 ± 0.06 0.21 ± 0.04 0.97 ± 0.12 0.21 ± 0.04 1.08 ± 0.16 0.26 ± 0.01 0.26 士 0.10 ±0.02 0.10 ± 64.88 ± 29.77 :0.03 3.82 ± 2.05 0.11 ± 0.01 1.42 ± 0.44 0.16 ± 0.05 0.75 ± 0.11
C 0.19 ± 0.05 0.27 ± 0.02 0.27 ± 0.01 0.28 ± 0.00 0.28 ± 0.01 0.12 ± :0.04 0.10 ± 0.02 0.12 ± 0.05 0.08 ± 0.01 0.05 ± 0.01
ESR2 F2 0.68 ± 0.00 0.67 ± 0.00 0.65 ± 0.00 0.70 ± 0.00 OOM 0.64 ± 0.01 0.68 ± 0.00 0.70 ± 0.00
KIT 0.89 ± 0.00 0.89 ± 0.00 0.89 ± 0.00 0.89 ± 0.00 OOM 0.87 ± 0.01 0.88 ± 0.00 0.89 ± 0.00
DOCK 0.83 ± 0.00 0.83 ± 0.00 0.83 ± 0.00 0.84 ± 0.00 OOM 0.80 ± 0.01 0.83 ± 0.00 0.84 ± 0.00
PARP1 0.92 ± 0.00 0.92 ± 0.00 0.92 ± 0.00 0.92 ± 0.00 OOM 0.91 ± 0.00 0.92 ± 0.01 0.93 ± 0.00
PGR 0.70 ± 0.00 0.68 ± 0.01 0.67 ± 0.01 0.72 ± 0.00 OOM 0.68 ± 0.01 0.70 ± 0.01 0.73 ± 0.00
FREESOLV 0.97 ± 0.00 0.96 ± 0.01 0.97 ± 0.01 0.95 ± 0.01 0.93 ± :0.00 0.93 ± 0.02 0.86 ± 0.03 0.98 ± 0.00
LIPO 0.81 ± 0.01 0.82 ± 0.01 0.82 ± 0.01 0.83 ± 0.01 0.61 ± :0.04 0.55 ± 0.02 0.79 ± 0.00 0.81 ± 0.01
NN ESOL 0.94 ± 0.01 0.93 ± 0.01 0.93 ± 0.00 0.94 ± 0.01 0.91 ± 0.02 0.89 ± 0.03 0.91 ± 0.00 0.94 ± 0.00
PCQM4MV2(↓) N/A N /A N /A N/A N /A N/ A N /A 0.0235

Table 2: The table reports mean absolute error (MAE) on the zinc dataset, presented as mean $\pm$ standard deviation over 5 runs. The best/lowest value is highlighted in bold. A complete table, including GCN, GIN, and DropGIN is provided in Supplementary Table 2.

表 2: 该表报告了锌数据集上的平均绝对误差 (MAE),以5次运行的平均值 $\pm$ 标准差表示。最佳/最低值以粗体标出。完整表格(包括GCN、GIN和DropGIN)见补充表2。

| 数据集 (↓) | GAT | GATv2 | PNA | Graphormer | TokenGT | GPS | ESA | ESA (PE) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ZINC | 0.078 ± 0.01 | 0.079 ± 0.00 | 0.057 ± 0.01 | 0.036 ± 0.00 | 0.047 ± 0.01 | 0.024 ± 0.01 | 0.027 ± 0.00 | 0.017 ± 0.00 |

4 Results

4 结果

We perform a comprehensive evaluation of ESA on 70 different tasks, including domains such as molecular property prediction, vision graphs, and social networks, as well as different aspects of representation learning on graphs, ranging from node-level tasks on homophilous and heterophilous graph types to modelling long-range dependencies, shortest paths, and 3D atomic systems. We quantify the performance of our approach relative to 6 GNN baselines: GCN, GAT, GATv2, PNA, GIN, and DropGIN (more expressive than 1-WL), and 3 graph transformer baselines: Graphormer, TokenGT, and GraphGPS. All the details on hyper-parameter tuning, the rationale for the selected metrics, and the selection of baselines can be found in SI 7.1 to SI 7.3. In the remainder of the section, we summarize our findings across molecular learning, mixed graph-level tasks, node-level tasks, and ablations on the interleaving operator, along with insights on time and memory scaling.

我们对ESA在70种不同任务上进行了全面评估,涵盖分子属性预测、视觉图和社会网络等领域,以及图表示学习的多个方面,包括同配性和异配性图类型的节点级任务,以及长程依赖建模、最短路径和3D原子系统等任务。我们将该方法的性能与6种GNN基线(GCN、GAT、GATv2、PNA、GIN和DropGIN(比1-WL更具表达能力)以及3种图Transformer基线(Graphormer、TokenGT和GraphGPS)进行了量化比较。关于超参数调优、所选指标的合理性及基线选择的详细信息,请参见SI 7.1至SI 7.3。在本节剩余部分,我们将总结在分子学习、混合图级任务、节点级任务以及交错算子消融实验中的发现,并探讨时间和内存扩展的见解。

4.1 Molecular Learning

4.1 分子学习

As learning on molecules has emerged as one of the most successful applications of graph learning, we present an in-depth evaluation including quantum mechanics, molecular docking, and various physical chemistry and biophysics benchmarks, as well as an exploration of learning on 3D atomic systems, transfer learning, and learning on large molecules of therapeutic relevance (peptides).

随着分子学习成为图学习最成功的应用之一,我们开展了深度评估,涵盖量子力学、分子对接、多种物理化学与生物物理学基准测试,并探索了3D原子系统学习、迁移学习以及对具有治疗相关性的生物大分子(如多肽)的学习。

QM9. We report results for all 19 qm9 [48] targets in Table 1, with GCN and GIN separately in Supplementary Table 1 due to space restrictions. We observe that on 15 out of 19 properties, ESA is the best performing model. The exceptions are the frontier orbital energies (HOMO and LUMO energy, HOMO-LUMO gap) and the dipole moment ( $\mu$ ), where PNA is slightly ahead of it. Other graph transformers are competitive relative to GNNs on many properties but vary in performance across tasks.

QM9. 我们在表1中报告了所有19个qm9 [48]目标的结果,由于篇幅限制,GCN和GIN的结果单独列在补充表1中。我们观察到,在19个性质中的15个上,ESA是表现最好的模型。例外情况是前沿轨道能量(HOMO和LUMO能量,HOMO-LUMO间隙)和偶极矩 ($\mu$),其中PNA略微领先。其他图Transformer (Transformer) 在许多性质上相对于GNN具有竞争力,但在不同任务中表现各异。

DOCKSTRING. dockstring [49] is a recent drug discovery data collection consisting of molecular docking scores for 260,155 small molecules and 5 high-quality targets from different protein families that were selected as a regression benchmark, with different levels of difficulty: parp1 (enzyme, easy), f2 (protease, easy to medium), kit (kinase, medium), esr2 (nuclear receptor, hard), and pgr (nuclear receptor, hard). We report results for the 5 targets in Table 1 (and Supplementary Table 1) and observe that ESA is the best performing method on 4 out of 5 tasks, with PNA slightly ahead on the medium-difficulty kit. TokenGT and GraphGPS do not generally match ESA or even PNA. Molecular docking scores also depend heavily on the 3D geometry, as discussed in the original paper [49], posing a difficult challenge for all methods. Interestingly, not only ESA but all tuned GNN baselines outperform the strongest method in the original manuscript (Attentive FP, a GNN based on attention [50]) despite using 20,000 fewer training molecules (which we leave out as a validation set). This illustrates the importance of evaluation relative to baselines with tuned hyper-parameters.

DOCKSTRING。dockstring [49] 是近期推出的药物发现数据集,包含260,155个小分子与5个来自不同蛋白质家族的高质量靶点的分子对接(docking)评分,这些靶点被选作回归基准并具有不同难度等级:parp1(酶类,简单)、f2(蛋白酶类,简单至中等)、kit(激酶类,中等)、esr2(核受体类,困难)和pgr(核受体类,困难)。我们在表1(及附表1)中报告了5个靶点的结果,观察到ESA在5项任务中有4项表现最佳,仅在中等难度的kit靶点上略逊于PNA。TokenGT和GraphGPS普遍未能达到ESA甚至PNA的水平。如原论文[49]所述,分子对接评分高度依赖3D几何结构,这对所有方法都构成了严峻挑战。值得注意的是,不仅ESA,所有经过调优的GNN基线模型都超越了原论文中最强方法(基于注意力机制的GNN模型Attentive FP [50]),尽管我们少用了20,000个训练分子(留作验证集)。这说明了相对于经过超参数调优的基线进行评估的重要性。

Table 3: The table reports Matthews correlation coefficient (MCC) for graph-level molecular classification tasks from MoleculeNet and the National Cancer Institute (nci), presented as mean $\pm$ standard deviation over 5 different runs. oom denotes out-of-memory errors. The highest mean values are highlighted in bold. A complete table, including GCN and GIN, is provided in Supplementary Table 4 with MCC as metric, and in Supplementary Table 5 with accuracy.

表 3: 该表报告了来自MoleculeNet和美国国家癌症研究所(nci)的图级分子分类任务的马修斯相关系数(MCC),以5次不同运行的平均值±标准差呈现。oom表示内存不足错误。最高平均值以粗体标出。完整表格(包括GCN和GIN)见补充表4(MCC指标)和补充表5(准确率指标)。

| Data (↑) | | DropGIN | GAT | GATv2 | PNA | Graphormer | TokenGT | GPS | ESA |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| MN | BBBP | 0.68 ± 0.02 | 0.74 ± 0.01 | 0.73 ± 0.03 | 0.73 ± 0.03 | 0.55 ± 0.01 | 0.58 ± 0.07 | 0.70 ± 0.04 | 0.84 ± 0.01 |
| | BACE | 0.65 ± 0.03 | 0.63 ± 0.02 | 0.64 ± 0.03 | 0.64 ± 0.02 | 0.52 ± 0.02 | 0.58 ± 0.03 | 0.62 ± 0.03 | 0.72 ± 0.02 |
| | HIV | 0.46 ± 0.03 | 0.42 ± 0.06 | 0.34 ± 0.06 | 0.42 ± 0.04 | OOM | 0.46 ± 0.02 | 0.25 ± 0.21 | 0.53 ± 0.01 |
| NCI | NCI1 | 0.69 ± 0.03 | 0.70 ± 0.02 | 0.65 ± 0.03 | 0.70 ± 0.03 | 0.54 ± 0.02 | 0.53 ± 0.03 | 0.70 ± 0.03 | 0.75 ± 0.01 |
| | NCI109 | 0.68 ± 0.02 | 0.66 ± 0.01 | 0.66 ± 0.01 | 0.67 ± 0.02 | 0.50 ± 0.02 | 0.45 ± 0.03 | 0.62 ± 0.01 | 0.70 ± 0.01 |

Table 4: The table reports mean absolute error (MAE) and average precision (AP) for two long-range molecular benchmarks involving peptides. The molecular graph of a peptide is much larger than that of a small drug-like molecule, which makes the tasks well-suited for long-range benchmarking. All the results except those for ESA and TokenGT are extracted from [29]. The number of layers for pept-struct and pept-func, respectively, is given as $(\cdot/\cdot)$ .

表 4: 该表报告了涉及肽的两个长程分子基准测试的平均绝对误差 (MAE) 和平均精度 (AP)。肽的分子图比类药小分子大得多,这使得这些任务非常适合长程基准测试。除 ESA 和 TokenGT 的结果外,所有结果均摘自 [29]。pept-struct 和 pept-func 的层数分别表示为 $(\cdot/\cdot)$。

| 数据集 | GCN (6/6) | GIN (10/8) | GPS (8/6) | TokenGT (10/10) | ESA (3/4) |
| --- | --- | --- | --- | --- | --- |
| PEPT-STRUCT (MAE ↓) | 0.2460 ± 0.0007 | 0.2473 ± 0.0017 | 0.2509 ± 0.0014 | 0.2489 ± 0.0013 | 0.2453 ± 0.0003 |
| PEPT-FUNC (AP ↑) | 0.6860 ± 0.0050 | 0.6621 ± 0.0067 | 0.6534 ± 0.0091 | 0.6263 ± 0.0117 | 0.6863 ± 0.0044 |

MoleculeNet and NCI. We report results for a varied selection of three regression and three classification benchmarks from MoleculeNet, as well as two benchmarks from the National Cancer Institute (NCI) consisting of compounds screened for anti-cancer activity (Tables 1 and 3, and Supplementary Tables 1 and 4). We also report the accuracy in Supplementary Table 5. Except for hiv, these datasets pose a challenge to graph transformers due to their small size (<5,000 compounds). With the exception of the Lipophilicity (lipo) dataset from MoleculeNet, we observe that ESA is the preferred method, and often by a significant margin, for example on bbbp and bace, despite their small size (2,039 and 1,513 total compounds before splitting, respectively). On the other hand, Graphormer and TokenGT perform poorly, possibly due to the small-sample nature of these tasks. GraphGPS is closer to the top performers but still underperforms compared to GAT(v2) and PNA. These results also show the importance of appropriate evaluation metrics, as the accuracy on hiv for all methods is above 97% (Supplementary Table 5), but the MCC (Table 3) is significantly lower.

Molecule Net与NCI。我们报告了来自Molecule Net的三个回归和三个分类基准的多样化选择结果,以及来自美国国家癌症研究所(NCI)的两个基准,这些基准包含筛选抗癌活性的化合物(表1和表3,及补充表1和表4)。补充表5中还报告了准确率。除hiv外,这些数据集由于规模较小( $<$ 5,000个化合物)对图Transformer提出了挑战。除Molecule Net中的Lipophilicity(lipo)数据集外,我们观察到ESA是首选方法,且通常优势显著,例如在bbbp和bace上(拆分前分别为2,039和1,513个化合物)。另一方面,Graphormer和TokenGT表现不佳,可能是由于这些任务的少样本特性。GraphGPS更接近顶级方法,但仍逊色于GAT(v2)和PNA。这些结果也显示了适当评估指标的重要性,因为所有方法在hiv上的准确率都超过97%(补充表5),但MCC(表3)明显更低。

PCQM4MV2. pcqm4mv2 is a quantum chemistry benchmark introduced as a competition through the Open Graph Benchmark Large-Scale Challenge (OGB-LSC) project [51]. It consists of 3,378,606 training molecules with the goal of predicting the DFT-calculated HOMO-LUMO energy gap from 2D molecular graphs. It has been widely used in the literature, especially to evaluate graph transformers [25, 26, 52]. Since the test splits are not public, we use the available validation set (73,545 molecules) as a test set, as is often done in the literature, and report results on it after training for 400 epochs. We report results from a single run, as is common in the field due to the large size of the dataset [25, 26, 52]. For the same reason, we do not run our baselines and instead chose to focus on the state-of-the-art results published on the official leaderboard∗. At the time of writing (November 2024), the best performing model achieves a validation set MAE of 0.0671 [53]. Here, ESA achieves a MAE of 0.0235, which is almost 3 times lower. It is worth noting that the top 3 methods for this dataset are all bespoke architectures designed for molecular learning, for example Uni-Mol+ [54], Transformer-M [55], and TGT [53]. In contrast, ESA is general purpose (does not use any information or technique specifically for molecular learning), uses only the 2D input graph, and does not use any positional or structural encodings.

PCQM4MV2。pcqm4mv2是由开放图基准大规模挑战赛(OGB-LSC)项目[51]推出的量子化学基准测试数据集。该数据集包含3,378,606个训练分子,目标是从二维分子图中预测DFT计算的HOMO-LUMO能隙。该数据集在学术界被广泛使用,尤其用于评估图Transformer模型[25, 26, 52]。由于测试集未公开,我们按照文献惯例使用现有验证集(73,545个分子)作为测试集,并在训练400轮后报告结果。由于数据集规模庞大,我们与领域内常见做法一致[25, 26, 52],仅报告单次运行结果。出于同样原因,我们没有重新运行基线模型,而是选择聚焦官方排行榜*公布的最新成果。截至本文撰写时(2024年11月),性能最佳的模型在验证集上达到0.0671的MAE[53]。而ESA模型取得了0.0235的MAE,误差降低近3倍。值得注意的是,该数据集当前排名前三的方法都是专门针对分子学习设计的架构,例如Uni-Mol+[54]、Transformer-M[55]和TGT[53]。相比之下,ESA是通用架构(未使用任何分子学习专用信息或技术),仅使用二维输入图,且不依赖任何位置或结构编码。

ZINC. We report results on the full zinc dataset with 250,000 compounds (Table 2), which is commonly used for generative purposes [56, 57]. This is one of the few benchmarks where the graph transformer baselines (Graphormer, TokenGT, GraphGPS) convincingly outperformed strong GNN baselines. ESA, without positional encodings, slightly underperforms compared to GraphGPS, which uses random-walk structural encodings (RWSE). This type of encoding is known to be beneficial in molecular tasks, and especially for zinc [27, 58]. Thus, we also evaluated an ESA+RWSE model, which increased relative performance by almost 40%. While there is no leaderboard available for the full version of zinc, recently Ma et al. [31] evaluated their own method (GRIT) against 11 other baselines, including higher-order GNNs, with the best reported MAE being $\mathbf{0.023\pm0.001}$ . This shows that ESA can already almost match state-of-the-art models without structural encodings, and significantly improves upon this result when augmented.

ZINC。我们在包含25万种化合物的完整ZINC数据集上报告结果(表2),该数据集通常用于生成目的[56,57]。这是少数几个图Transformer基线模型(Graphormer、TokenGT、GraphGPS)明显优于强GNN基线的基准之一。未使用位置编码的ESA略逊于采用随机游走结构编码(RWSE)的GraphGPS。已知此类编码对分子任务(尤其是ZINC)有益[27,58]。因此我们还评估了ESA$^+$RWSE模型,其相对性能提升近40%。虽然完整版ZINC没有公开排行榜,但最近Ma等人[31]评估了其方法(GRIT)与11个其他基线(包括高阶GNN)的对比,最佳报告MAE为$\mathbf{0.023\pm0.001}$。这表明ESA即使不加结构编码也能接近最先进模型,增强后更能显著提升结果。

Table 5: A summary of the transfer learning performance on qm9 for HOMO and LUMO properties, presented as mean $\pm$ standard deviation over 5 different runs. The metric is the root mean squared error (RMSE). All the models use the 3D atomic coordinates and atom types as inputs and no other node or edge features. ‘Strat.’ stands for strategy and specifies the type of learning: GW only (no transfer learning), inductive, or transductive. The lowest values are highlighted in bold. A complete table, including GCN and GIN, is provided in Supplementary Table 3.

表 5: qm9数据集上HOMO和LUMO性质迁移学习性能汇总,结果为5次运行的平均值$\pm$标准差。评估指标为均方根误差(RMSE)。所有模型均使用3D原子坐标和原子类型作为输入,不使用其他节点或边特征。"Strat."表示策略,指定学习类型:仅GW(无迁移学习)、归纳式(inductive)或传导式(transductive)。最低值以粗体标出。完整表格(含GCN和GIN)见补充材料表3。

| 任务 | Strat. | DropGIN | GAT | GATv2 | PNA | Grph. | TokenGT | GPS | ESA |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| HOMO | GW | 0.162 ± 0.00 | 0.159 ± 0.00 | 0.157 ± 0.00 | 0.151 ± 0.00 | 0.179 ± 0.01 | 0.200 ± 0.01 | 0.162 ± 0.00 | 0.152 ± 0.00 |
| | Ind. | 0.136 ± 0.00 | 0.131 ± 0.00 | 0.133 ± 0.00 | 0.132 ± 0.00 | 0.134 ± 0.00 | 0.156 ± 0.00 | 0.151 ± 0.00 | 0.131 ± 0.00 |
| | Trans. | 0.126 ± 0.00 | 0.123 ± 0.00 | 0.124 ± 0.00 | 0.121 ± 0.00 | 0.125 ± 0.00 | 0.137 ± 0.00 | 0.147 ± 0.00 | 0.119 ± 0.00 |
| LUMO | GW | 0.180 ± 0.00 | 0.181 ± 0.00 | 0.178 ± 0.00 | 0.174 ± 0.00 | 0.190 ± 0.01 | 0.204 ± 0.01 | 0.178 ± 0.00 | 0.174 ± 0.00 |
| | Ind. | 0.159 ± 0.00 | 0.156 ± 0.00 | 0.157 ± 0.00 | 0.156 ± 0.00 | 0.151 ± 0.00 | 0.165 ± 0.00 | 0.167 ± 0.00 | 0.150 ± 0.00 |
| | Trans. | 0.157 ± 0.00 | 0.153 ± 0.00 | 0.153 ± 0.00 | 0.153 ± 0.00 | 0.147 ± 0.00 | 0.156 ± 0.00 | 0.169 ± 0.00 | 0.146 ± 0.00 |

Long-range peptide tasks. Graph learning with transformers has traditionally been evaluated on long-range graph benchmarks (LRGB) [59]. However, it was recently shown that simple graph neural networks outperform most attention-based methods [29]. We selected two long-range benchmarks involving peptide property prediction: peptides-struct and peptides-func. From the LRGB collection, these two stand out due to having the longest average shortest path ($20.89\pm9.79$ versus $10.74\pm0.51$ for the next) and the largest average diameter ($56.99\pm28.72$ versus $27.62\pm2.13$). We report ESA results against tuned GNNs and GraphGPS models from [59], as well as our own optimised TokenGT model (Table 4). Despite using only half the number of layers of other methods or less, ESA outperformed these baselines and matched the second model on the pept-struct leaderboard, and is within the top 5 for pept-func (as of November 2024).

长程肽任务。传统上,基于Transformer的图学习主要在长程图基准(LRGB) [59]上进行评估。但近期研究表明,简单的图神经网络(GNN)能超越大多数基于注意力机制的方法[29]。我们选取了两个涉及肽属性预测的长程基准:peptides-struct和peptides-func。在LRGB数据集中,这两个任务因具有最长的平均最短路径($20.89\pm9.79$,而次长仅为$10.74\pm0.51$)和最大的平均直径($56.99\pm28.72$,而次大为$27.62\pm2.13$)而突出。表4展示了ESA方法与[59]中调优的GNN、GraphGPS模型以及我们优化的TokenGT模型的对比结果。尽管ESA使用的层数仅为其他方法的一半或更少,其性能仍超越基线模型,并在pept-struct排行榜上与第二名持平,同时在pept-func任务中保持前五名(截至2024年11月)。

Learning on 3D atomic systems. We adapt a subset from the Open Catalyst Project (OCP) [60, 61] for evaluation, with the full steps in SI 9, including deriving edges from atom positions based on a cutoff, pre-processing specific to crystal systems, and encoding atomic distances using Gaussian basis functions. As a prototype, we compare NSA (MAE of $\mathbf{0.799\pm0.008}$) against Graphormer (MAE of $0.839\pm0.005$), one of the best models that have been used for the Open Catalyst Project [62]. These encouraging results on a subset of OCP motivated us to further study modelling 3D atomic systems through the lens of transfer learning.

学习3D原子系统。我们采用Open Catalyst Project (OCP) [60, 61]的子集进行评估,完整步骤见SI 9,包括基于截断值从原子位置推导边缘、针对晶体系统的特定预处理,以及使用高斯基函数编码原子距离。作为原型,我们将NSA (MAE为$\mathbf{0.799\pm0.008}$)与Graphormer ($0.839\pm0.005$)进行对比,后者是Open Catalyst Project [62]中使用的最佳模型之一。这些在OCP子集上取得的鼓舞性结果,促使我们通过迁移学习的视角进一步研究3D原子系统的建模。

Transfer learning on frontier orbital energies. We follow the recipe recently outlined for drug discovery and quantum mechanics by [32] and leverage a recent, refined version of the qm9 HOMO and LUMO energies [63] that provides alternative DFT calculations and new calculations at the more accurate GW level of theory. As outlined by [32], transfer learning can occur transductively or inductively. The transfer learning scenario is important and is detailed in SI 4. We use a 25K/5K/10K train/validation/test split for the high-fidelity GW data, and we train separate low-fidelity DFT models on the entire dataset (transductive) or with the high-fidelity test set molecules removed (inductive). Since the HOMO and LUMO energies depend to a large extent on the molecular geometry, we use the 3D-aware version of ESA from the previous section, and adapt all of our baselines to the 3D setup. Our results are reported in Table 5. Without transfer learning (strategy ‘GW’ in Table 5), ESA and PNA are almost evenly matched, which is already an improvement since PNA was better for frontier orbital energies without 3D structures (Table 1), while the graph transformers perform poorly. Employing transfer learning, all the methods improve significantly, but ESA outperforms all baselines for both HOMO and LUMO, in both transductive and inductive tasks.

前沿轨道能的迁移学习。我们遵循[32]近期针对药物发现和量子力学提出的方法,并利用qm9 HOMO和LUMO能级的最新优化版本[63],该版本提供了替代的DFT计算以及在更精确的GW理论层面的新计算结果。如[32]所述,迁移学习可通过传导式(transductive)或归纳式(inductive)进行。该迁移学习场景非常重要,详见SI 4。我们对高精度GW数据采用25K/5K/10K的训练/验证/测试集划分,并在完整数据集上训练独立的低精度DFT模型(传导式),或移除高精度测试集分子后训练(归纳式)。由于HOMO和LUMO能级很大程度上取决于分子几何结构,我们使用前文所述的3D感知版ESA,并将所有基线方法适配至3D架构。结果如表5所示:未采用迁移学习时(表5中"GW"策略),ESA与PNA表现接近,这已是进步——因为在无3D结构时PNA对前沿轨道能更具优势(表1),而图Transformer模型表现较差;采用迁移学习后所有方法均有显著提升,但ESA在HOMO和LUMO的传导式与归纳式任务中均优于所有基线方法。

4.2 Mixed Graph-level Tasks

4.2 混合图级别任务

Other than molecular learning, we examine a suite of graph-level benchmarks from various domains, including computer vision, bioinformatics, synthetic graphs, and social graphs. Our results are reported in Table 6 and Supplementary Table 7, where ESA generally outperforms all the baselines. In the context of state-of-the-art models from online leaderboards, of particular note are the accuracy results (provided in Supplementary Table 7) on the vision datasets, where on mnist ESA matches the best performing model [64], and is in the top 5 on cifar10 at the time of submission (November 2024), without positional or structural encodings or other helper mechanisms. Similarly, on the malnettiny dataset of function call graphs popularised by GraphGPS [27], we achieve the highest recorded accuracy.

除分子学习外,我们还测试了来自计算机视觉、生物信息学、合成图和社会图等多个领域的图级基准。结果如表6和附表7所示,其中ESA普遍优于所有基线方法。值得注意的是,在在线排行榜的当前最优模型背景下,ESA在视觉数据集上的准确率结果(见附表7)表现突出:在mnist数据集上匹配了性能最佳的模型[64],在提交时(2024年11月)的cifar10数据集上位列前五,且未使用位置编码、结构编码或其他辅助机制。同样,在GraphGPS[27]推广的函数调用图数据集malnettiny上,我们取得了迄今最高的记录准确率。

4.3 Node-level Benchmarks

4.3 节点级基准测试

Node-level tasks are an interesting challenge for our proposed approach, as the PMA module is not needed and node-level propagation is the most natural strategy for learning node representations. To this end, we did prototype an edge-to-node pooling module which would allow ESA to learn node embeddings, with good results. However, the approach does not currently scale to the millions of edges that some heterophilous graphs have. For these reasons, we revert to the simpler node-set attention (NSA) formulation. Table 7 summarizes our results, indicating that NSA performs well on all 11 node-level benchmarks, including homophilous (citation), heterophilous, and shortest path tasks. We have adapted Graphormer and TokenGT for node classification as this functionality was not originally available, although they require integer node and edge features, which restricts their use on some datasets (denoted by n/a in Table 7). NSA achieves the highest MCC score on homophilous and heterophilous graphs, but GCN has the highest accuracy on cora (Supplementary Table 9). Some methods, for instance GraphSAGE [67], are designed to perform particularly well under heterophily, and to account for that we include GraphSAGE [67] and Graph Transformer [33], the two top performing baselines from Platonov et al. [66], in our extended results in Supplementary Table 10. With the exception of the amazon ratings dataset, we outperform the two baselines by a noticeable margin. Our results on chameleon stand out in particular, as do those on the shortest path benchmarks, where other graph transformers are unable to learn and even PNA fails to be competitive.

节点级任务对我们提出的方法来说是一个有趣的挑战,因为不需要PMA模块,而节点级传播是最自然的节点表示学习策略。为此,我们确实尝试了边缘到节点的池化模块原型,使ESA能够学习节点嵌入,并取得了良好效果。但该方法目前无法扩展到某些异质图拥有的数百万条边。基于这些原因,我们回归到更简单的节点集注意力(NSA)方案。表7总结了我们的结果,表明NSA在所有11个节点级基准测试中表现良好,包括同质(引用)、异质和最短路径任务。我们调整了Graphormer和TokenGT用于节点分类(该功能原本不可用),但它们需要整数节点和边特征,这限制了在某些数据集上的使用(表7中标记为n/a)。NSA在同质图和异质图上取得了最高MCC分数,但GCN在cora数据集上准确率最高(补充表9)。某些方法(如GraphSAGE [67])专为在异质性下表现优异而设计,为此我们在补充表10的扩展结果中纳入了Platonov等人[66]的两个最佳基线方法GraphSAGE [67]和Graph Transformer [33]。除amazon评分数据集外,我们以明显优势超越了这两个基线。我们在chameleon数据集和最短路径基准上的表现尤为突出——其他图Transformer模型无法有效学习,甚至PNA也缺乏竞争力。
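The shortest-path benchmarks above use node labels derived from hop distances on random graphs (SI 8 gives the exact construction). As a rough illustration of the task format (assuming a single "infected" source node and hop-distance class labels; this is a reading of the setup for illustration, not the released generator), the labels can be produced with an Erdős–Rényi sampler and one BFS:

```python
import random
from collections import deque

def er_graph(n, p, seed=0):
    """Erdős–Rényi G(n, p): each possible undirected edge is kept with probability p."""
    rng = random.Random(seed)
    adj = {i: [] for i in range(n)}
    for i in range(n):
        for j in range(i + 1, n):
            if rng.random() < p:
                adj[i].append(j)
                adj[j].append(i)
    return adj

def bfs_distances(adj, source):
    """Hop distance from `source` to every node; unreachable nodes get -1."""
    dist = {v: -1 for v in adj}
    dist[source] = 0
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if dist[w] == -1:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist

adj = er_graph(100, 0.05, seed=42)
labels = bfs_distances(adj, source=0)  # per-node shortest-path labels
```

A model solving the task must effectively propagate information over as many hops as the graph's diameter, which is what makes these benchmarks hard for fixed-depth message passing.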

Table 6: The table reports Matthews correlation coefficient (MCC) for graph-level classification tasks from various domains, presented as mean $\pm$ standard deviation over 5 different runs. OOM denotes out-of-memory errors, and N/A that the model is unavailable (e.g., node/edge features are not integers, which are required for Graphormer and TokenGT). The poor performance of GraphGPS on malnettiny can be explained by our use of one-hot degrees as node features for datasets lacking pre-computed features, while GraphGPS originally used the more informative local degree profile [65]. The highest mean values are highlighted in bold. A complete table, including GCN and GIN, is provided in Supplementary Table 6 with MCC as metric, and in Supplementary Table 7 with accuracy.

表 6: 该表报告了不同领域图级分类任务的马修斯相关系数 (MCC),以5次运行的平均值 ± 标准差呈现。OOM表示内存不足错误,N/A表示模型不可用(例如,Graphormer和TokenGT需要整数节点/边特征)。GraphGPS在malnettiny上的较差表现可归因于我们对缺乏预计算特征的数据集使用独热编码度作为节点特征,而GraphGPS原本使用信息量更大的局部度分布 [65]。最高平均值以粗体标出。完整表格(包括GCN和GIN)在补充表6(以MCC为指标)和补充表7(以准确率为指标)中提供。

Data DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
MALNETTINY 0.90 ± 0.01 0.90 ± 0.01 0.90 ± 0.01 0.91 ± 0.01 OOM 0.78 ± 0.01 0.79 ± 0.01 0.93 ± 0.00
MNIST 0.97 ± 0.00 0.97 ± 0.00 0.98 ± 0.00 0.98 ± 0.00 N/A N/A 0.98 ± 0.00 0.99 ± 0.00
CIFAR10 0.61 ± 0.01 0.66 ± 0.01 0.66 ± 0.01 0.69 ± 0.01 N/A N/A 0.71 ± 0.01 0.73 ± 0.00
ENZYMES 0.58 ± 0.04 0.75 ± 0.02 0.74 ± 0.02 0.68 ± 0.03 N/A N/A 0.73 ± 0.05 0.75 ± 0.01
PROTEINS 0.46 ± 0.01 0.46 ± 0.04 0.49 ± 0.04 0.47 ± 0.07 N/A N/A 0.44 ± 0.02 0.59 ± 0.02
DD 0.54 ± 0.07 0.47 ± 0.05 0.53 ± 0.03 0.56 ± 0.08 OOM 0.46 ± 0.05 0.60 ± 0.05 0.65 ± 0.03
SYNTH 1.00 ± 0.00 1.00 ± 0.00 1.00 ± 0.00 1.00 ± 0.00 N/A N/A 1.00 ± 0.00 1.00 ± 0.00
SYNTH-N 0.97 ± 0.05 0.76 ± 0.07 0.91 ± 0.07 1.00 ± 0.00 N/A N/A 1.00 ± 0.00 1.00 ± 0.00
SYNTHIE 0.94 ± 0.02 0.70 ± 0.04 0.80 ± 0.04 0.88 ± 0.06 N/A N/A 0.95 ± 0.02 0.95 ± 0.02
IMDB-B 0.61 ± 0.06 0.69 ± 0.04 0.60 ± 0.05 0.56 ± 0.07 0.56 ± 0.05 0.61 ± 0.05 0.60 ± 0.05 0.74 ± 0.03
IMDB-M 0.12 ± 0.10 0.20 ± 0.05 0.20 ± 0.03 0.03 ± 0.07 0.22 ± 0.02 0.20 ± 0.02 0.23 ± 0.02 0.25 ± 0.03
TWITCH E. 0.38 ± 0.01 0.37 ± 0.01 0.37 ± 0.01 0.08 ± 0.16 0.39 ± 0.00 0.39 ± 0.00 0.40 ± 0.00 0.40 ± 0.00
RDT. THR. 0.56 ± 0.01 0.53 ± 0.02 0.54 ± 0.02 0.11 ± 0.23 0.57 ± 0.00 0.56 ± 0.00 0.57 ± 0.00 0.57 ± 0.00
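The classification tables in this section report the Matthews correlation coefficient. For the binary case it is the standard confusion-matrix formula; a minimal reference implementation (multi-class tasks, such as imdb-m, would use the generalised multi-class form instead):

```python
import math

def mcc(y_true, y_pred):
    """Binary Matthews correlation coefficient from the confusion matrix.
    Returns 0.0 when any marginal is empty (the usual convention)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom
```

Unlike accuracy, MCC stays near 0 for a classifier that always predicts the majority class, which is why it is the preferred headline metric on the imbalanced datasets here.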

Table 7: The table reports Matthews correlation coefficient (MCC) for 11 node-level classification tasks, presented as mean $\pm$ standard deviation over 5 different runs. The number of nodes for the shortest path (sp) benchmarks is given in parentheses (based on randomly-generated “infected” Erdős–Rényi (ER) graphs; details are provided in SI 8). For squirrel and chameleon, we used the filtered datasets introduced by Platonov et al. [66] to fix existing data leaks. The highest mean values are highlighted in bold. Additional heterophily results are provided in Supplementary Table 10. A complete table, including GIN and DropGIN, is provided in Supplementary Table 8 with MCC as the performance metric, and in Supplementary Table 9 with accuracy.

表 7: 该表报告了11个节点级分类任务的马修斯相关系数 (MCC),以5次不同运行的平均值 $\pm$ 标准差呈现。最短路径 (sp) 基准测试的节点数在括号中给出(基于随机生成的“感染”Erdős–Rényi (ER) 图;详细信息见SI 8)。对于squirrel和chameleon,我们使用了Platonov等人[66]引入的过滤数据集来修复现有的数据泄露问题。最高平均值以粗体标出。补充表10提供了额外的异质性结果。完整表格(包括GIN和DropGIN)在补充表8中提供(以MCC作为性能指标),在补充表9中提供(以准确率作为指标)。

Data (↑) GCN GAT GATv2 PNA Graphormer TokenGT GPS NSA
PPI 0.98 ± 0.00 0.99 ± 0.00 0.98 ± 0.01 0.99 ± 0.00 N/A N/A N/A 0.99 ± 0.00
CITESEER 0.61 ± 0.01 0.59 ± 0.03 0.61 ± 0.01 0.51 ± 0.03 OOM 0.38 ± 0.02 0.54 ± 0.01 0.63 ± 0.00
CORA 0.77 ± 0.01 0.75 ± 0.01 0.73 ± 0.01 0.64 ± 0.03 OOM 0.37 ± 0.18 0.64 ± 0.04 0.77 ± 0.00
ROMAN E. 0.47 ± 0.00 0.74 ± 0.01 0.76 ± 0.00 0.86 ± 0.00 N/A N/A 0.84 ± 0.01 0.87 ± 0.00
AMAZON R. 0.18 ± 0.00 0.26 ± 0.01 0.25 ± 0.01 0.21 ± 0.02 N/A N/A 0.11 ± 0.14 0.34 ± 0.01
MINESWEEPER 0.30 ± 0.00 0.48 ± 0.02 0.51 ± 0.01 0.62 ± 0.04 N/A N/A 0.56 ± 0.01 0.69 ± 0.00
TOLOKERS 0.30 ± 0.01 0.38 ± 0.01 0.39 ± 0.00 0.35 ± 0.05 N/A N/A 0.35 ± 0.02 0.43 ± 0.00
SQUIRREL 0.20 ± 0.01 0.24 ± 0.02 0.24 ± 0.01 0.22 ± 0.01 0.10 ± 0.09 0.18 ± 0.02 0.23 ± 0.02 0.29 ± 0.01
CHAMELEON 0.32 ± 0.01 0.24 ± 0.06 0.28 ± 0.03 0.27 ± 0.03 0.25 ± 0.04 0.26 ± 0.04 0.30 ± 0.08 0.39 ± 0.02
ER (15K) 0.22 ± 0.02 0.32 ± 0.00 0.32 ± 0.00 0.54 ± 0.09 OOM 0.06 ± 0.00 0.18 ± 0.04 0.92 ± 0.01
ER (30K) 0.09 ± 0.03 0.10 ± 0.06 0.10 ± 0.06 0.42 ± 0.05 OOM OOM OOM 0.87 ± 0.01

Table 8: Example of multiple ESA configurations and their impact on performance via three different kinds of benchmarks: a graph-level molecular task (bbbp), a graph-level vision task (mnist), and a node-level heterophilous task (chameleon). In the model configurations, ‘M’ denotes a MAB, ‘S’ a SAB (which may appear before or after the PMA module), and ‘P’ the PMA module. The performance metric is the Matthews correlation coefficient (MCC).

表 8: 多种ESA配置示例及其通过三类基准测试对性能的影响:图级分子任务(bbbp)、图级视觉任务(mnist)和节点级异质任务(chameleon)。模型配置中,'M'表示MAB,'S'表示SAB(可出现在PMA模块之前或之后),'P'表示PMA模块。性能指标为马修斯相关系数(MCC)。

Dataset Model MCC (↑) Dataset Model MCC (↑) Dataset Model MCC (↑)
BBBP MMMMSPS 0.845 MNIST SSMMMMMMMPS 0.986 CHAMELEON MSMSMS 0.422
BBBP MMMSPSS 0.835 MNIST SMSMSMSMSMP 0.983 CHAMELEON SSMMSS 0.384
BBBP SSSMMPS 0.812 MNIST SSSMMMSSSP 0.982 CHAMELEON SMSMSM 0.378
BBBP MMMMMP 0.782 MNIST MSMSMSMSMSP 0.980 CHAMELEON MMMMM 0.359
BBBP SMSMSP 0.768 MNIST 0.980 CHAMELEON SSMMMM 0.351

4.4 Effects of Varying the Layer Order and Type

4.4 层序与类型变化的影响

In Table 8, we summarise the results of an ablation over the interleaving operator, varying the order and type of layers in our architecture. Smaller datasets perform well with 4 to 6 feature extraction layers, while larger datasets with more complex graphs, like mnist, benefit from up to 10 layers. We have generally observed that the top configurations tend to place self-attention layers at the front, masked attention layers in the middle, and self-attention layers at the end, surrounding the PMA readout. Naive configurations, such as all-masked layers or simply alternating masked and self-attention layers, do not tend to be optimal for graph-level prediction tasks. This ablation demonstrates the importance of the vertical combination of masked and self-attention layers for the performance of our model.

在表8中,我们总结了关于架构中不同顺序和类型层的交错操作符的消融实验结果。较小数据集在4到6个特征提取层时表现良好,而具有更复杂图结构的大型数据集(如mnist)则受益于多达10层的设计。我们普遍观察到,最优配置往往在前端包含自注意力层,中间为掩码注意力层,末端再放置自注意力层,从而环绕PMA读出模块。纯掩码层或简单交替掩码与自注意力层的朴素配置通常不适用于图级预测任务。该消融实验证明了掩码层与自注意力层垂直组合对模型性能的重要性。
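The configuration strings in Table 8 can be read as a layer schedule. A minimal sketch of how such a string maps to an ordered module stack (the function and the validity checks are illustrative, not the released implementation):

```python
def parse_schedule(config):
    """Map a configuration string such as 'MMMMSPS' to a layer schedule:
    'M' -> masked attention block (MAB), 'S' -> self-attention block (SAB),
    'P' -> the PMA pooling module (graph-level readout)."""
    names = {"M": "MAB", "S": "SAB", "P": "PMA"}
    if any(c not in names for c in config):
        raise ValueError(f"unknown layer code in {config!r}")
    if config.count("P") > 1:
        raise ValueError("at most one PMA readout is expected")
    return [names[c] for c in config]

# graph-level schedule: four MABs, one SAB, PMA readout, one SAB after pooling
assert parse_schedule("MMMMSPS") == ["MAB", "MAB", "MAB", "MAB", "SAB", "PMA", "SAB"]
```

Node-level schedules such as "MSMSMS" simply omit the 'P', since no set-level readout is needed.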


Figure 4: The elapsed time for training a single epoch for different datasets (in seconds), and the maximum allocated memory during this training epoch (GB). Different configurations are tested, while varying hidden dimensions and number of layers. qm9 has around 130K small molecules (a maximum of 29 nodes and 56 edges), dockstring has around 260K graphs that are around 6 times larger, and finally mnist has 70K graphs which are around 11-12 times larger than qm9. We use dummy integer features for TokenGT when benchmarking mnist. Graphormer runs out of memory for dockstring and mnist.

图 4: 不同数据集训练单轮耗时(秒)及该训练轮次最大内存占用(GB)。测试了不同隐藏维度和层数的配置。qm9包含约13万个小分子(最多29个节点和56条边),dockstring包含约26万个规模约为qm9 6倍的图,mnist包含7万个规模约为qm9 11-12倍的图。基准测试mnist时,TokenGT使用虚拟整数特征。Graphormer在dockstring和mnist上出现内存不足。

4.5 Time and Memory Scaling

4.5 时间和内存扩展

The theoretical properties of ESA are discussed in Methods, Section 3.3, where we also cover deep learning library limitations and possible optimisations (also see SI 5). Here, we empirically evaluate the time and memory scaling of ESA against all 9 baselines. For a fair evaluation, we implement Flash attention for TokenGT, as it was not originally supported. GraphGPS natively supports Flash attention, while Graphormer requires specific modifications to the attention matrix which are not currently supported.

ESA的理论特性在方法部分的第3.3节中讨论,我们还涵盖了深度学习库的局限性及可能的优化方案(另见SI 5)。本文通过实证评估了ESA与全部9个基线方法在时间和内存占用上的扩展性。为确保公平性,我们为TokenGT实现了Flash attention机制(该机制原版本未支持)。GraphGPS原生支持Flash attention,而Graphormer需要对注意力矩阵进行特定修改(当前版本尚未支持该功能)。

We report the training time for a single epoch and the maximum allocated memory during training for qm9 and mnist in Figure 4, and dockstring in Supplementary Figure 1. In terms of training time, GCN and GIN are consistently the fastest due to their simplicity and low number of parameters (also see Figure 6a). ESA is usually the next fastest, followed by other graph transformers. The strong GNN baselines, particularly PNA, and to an extent GATv2 and GAT, are among the slowest methods. In terms of memory, GCN and GIN are again the most efficient, followed by GraphGPS and TokenGT. ESA is only slightly behind, with PNA, DropGIN, GATv2, and GAT all being more memory intensive, particularly DropGIN.

我们在图4中报告了qm9和mnist的单轮训练时间及训练期间最大内存占用,dockstring数据则在补充图1中展示。就训练时间而言,由于结构简单且参数量少(另见图6a),GCN和GIN始终是最快的。ESA通常是第二快的,其次是其他图Transformer模型。强GNN基线方法(尤其是PNA)以及GATv2和GAT在某种程度上属于最慢的方法。内存方面,GCN和GIN同样最节省资源,其次是GraphGPS和TokenGT。ESA略逊一筹,而PNA、DropGIN、GATv2和GAT都更耗内存,其中DropGIN尤为明显。
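The measurements above follow the usual pattern: reset the peak-memory counter, time one pass over the batches, then read the peak. On GPU this is typically done with torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated; a CPU-only analogue of the same measurement loop, using only the standard library, looks like:

```python
import time
import tracemalloc

def measure_epoch(step_fn, batches):
    """Run one 'epoch' over `batches`, returning (elapsed seconds, peak bytes).
    `step_fn` stands in for a forward/backward/optimiser step."""
    tracemalloc.start()
    start = time.perf_counter()
    for batch in batches:
        step_fn(batch)
    elapsed = time.perf_counter() - start
    _, peak = tracemalloc.get_traced_memory()  # (current, peak) since start()
    tracemalloc.stop()
    return elapsed, peak

# toy stand-in: each "step" allocates a list proportional to the batch size
elapsed, peak = measure_epoch(lambda b: [0.0] * (1000 * b), [8, 8, 8, 8])
```

Reporting the peak rather than the final allocation matters, since attention buffers are freed between batches but still determine whether a configuration fits in memory.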

We also order all methods according to their achieved performance and illustrate their rank relative to the total time spent training and the maximum memory allocated during training (Figure 5). ESA occupies the top-left corner of the time plot, confirming its efficiency. The results are more spread out for memory; however, the performance is still remarkable considering that ESA works over edges, whose number grows rapidly relative to the number of nodes that dictates the scaling of all other methods.

我们还根据各方法实现的性能进行排序,并展示它们在总训练时间和训练期间最大内存分配上的相对排名(图5)。ESA在时间图中位于左上角,证实了其高效性。在内存方面结果分布较分散,但考虑到ESA基于边(其数量相对节点数快速增长,而节点数决定了其他所有方法的扩展性)运作,其性能仍然非常出色。

Finally, we report the number of parameters for all methods and their configurations (Figure 6a). Apart from GCN, GIN, and DropGIN, ESA has the lowest number of parameters. GAT and GATv2 have rapidly increasing parameter counts, likely due to the concatenation that happens between attention heads, while PNA also increases significantly quicker than ESA. Lastly, we demonstrate and discuss one of the bottlenecks

最后,我们报告了所有方法及其配置的参数数量(图 6a)。除 GCN、GIN 和 DropGIN 外,ESA 的参数数量最低。GAT 和 GATv2 的参数数量增长迅速,可能是由于注意力头之间的拼接操作,而 PNA 的增长速度也明显快于 ESA。最后,我们展示并讨论了其中一个瓶颈


Figure 5: All methods illustrated according to their achieved performance (rank) versus the total time spent training (minutes) in Panel (a) and the maximum allocated memory (GB) in Panel (b).

图 5: 所有方法根据其达到的性能(排名)与训练总耗时(分钟)在面板(a)中的对比,以及最大分配内存(GB)在面板(b)中的对比。


(a) The number of parameters (in millions) for all methods and configurations, varying the hidden dimension and the number of layers. (b) The time and memory utilisation for a selection of Graphormer configurations, varying the number of layers and the number of input graphs.

(a) 所有方法及其配置的参数量(单位:百万),随隐藏维度和层数变化。(b) 部分 Graphormer 配置的时间和内存利用率,变量为层数和输入图的数量。

Figure 6: The number of parameters for all methods in Panel (a) and an illustration of the time and memory bottleneck in Graphormer in Panel (b). For Graphormer, even training on a single graph from hiv takes over 10 seconds and slightly under 5GB. Increasing the number of graphs to 8 and setting the batch size to 8 to ensure parallel computation increases the running time to around 80 seconds and slightly under 30GB. Extrapolating these numbers for a batch size of 8 to the entire dataset results in a training time of around 5 days for a single epoch.

图 6: 面板 (a) 中所有方法的参数量,以及面板 (b) 中 Graphormer 的时间和内存瓶颈说明。对于 Graphormer,即使在 hiv 数据集上训练单个图也需要超过 10 秒和略低于 5GB 内存。将图数量增加到 8 并设置批次大小为 8 以确保并行计算时,运行时间增至约 80 秒且内存占用略低于 30GB。若将批次大小 8 外推至整个数据集,单个训练周期将耗时约 5 天。

in Graphormer for the dataset hiv (Figure 6b). This lack of efficiency is likely caused by the edge features computed by Graphormer, which depend quadratically on the number of nodes in the graph, and on the maximum shortest path in the graph.

在图神经网络Graphormer处理hiv数据集时出现了效率不足的情况(图6b)。这种效率缺失很可能源于Graphormer计算的边特征(edge features),这些特征的计算复杂度与图中节点数量的平方成正比,同时受图中最长最短路径(maximum shortest path)的影响。
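The quadratic dependence is easy to see: Graphormer's spatial encoding requires an all-pairs shortest-path distance matrix, which already has n² entries before any attention is computed. A small sketch of that pre-processing cost (plain BFS per node, as one would do for unweighted graphs):

```python
from collections import deque

def all_pairs_spd(adj):
    """All-pairs shortest-path distances via one BFS per node.
    The result has n*n entries, so memory alone grows quadratically in n,
    and the maximum distance bounds the size of the distance embedding table."""
    n = len(adj)
    dist = [[-1] * n for _ in range(n)]
    for s in range(n):
        dist[s][s] = 0
        q = deque([s])
        while q:
            u = q.popleft()
            for w in adj[u]:
                if dist[s][w] == -1:
                    dist[s][w] = dist[s][u] + 1
                    q.append(w)
    return dist

# a 4-node path graph 0-1-2-3: the matrix has 16 entries, max distance 3
d = all_pairs_spd([[1], [0, 2], [1, 3], [2]])
```

Graphormer's edge encoding additionally aggregates edge features along each of these shortest paths, which is why both time and memory in Figure 6b blow up with graph size.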

5 Discussion

5 讨论

We presented an end-to-end attention-based approach for learning on graphs and demonstrated its effectiveness relative to tuned graph neural networks and recently proposed graph transformer architectures. The approach is free of expensive pre-processing steps and of the numerous additional components designed to improve the inductive bias or expressive power (e.g., shortest paths, centrality, spatial and structural encodings, virtual nodes as readouts, expander graphs, transformations via interaction graphs). This leaves plenty of room for further fine-tuning and task-specific improvements. The interleaving operator inherent to ESA allows for a vertical combination of masked and self-attention modules for learning effective token (i.e., edge or node) representations, leveraging the relational information specified by input graphs while allowing the model to expand on this prior structure via self-attention. The recently released Flex Attention feature in PyTorch allows for further extensions via rich masking operators informed by graph structure, while leveraging the advancements in exact and efficient attention that have enabled our work.

我们提出了一种基于注意力机制的端到端图学习方法,并证明了其相对于调优图神经网络和近期提出的图Transformer架构的有效性。该方法无需昂贵的预处理步骤,也不依赖众多旨在提升归纳偏置或表达能力的附加组件(如最短路径、中心性、空间与结构编码、作为读出器的虚拟节点、扩展图、通过交互图进行变换等)。这显示出巨大的潜力,并为后续细化和任务特定改进提供了充足空间。ESA固有的交错操作符能够垂直整合掩码注意力与自注意力模块,从而学习有效的Token(即边或节点)表征,既利用了输入图指定的关系信息,又允许通过自注意力扩展这一先验结构。PyTorch最新发布的Flex Attention功能支持通过图结构驱动的丰富掩码操作符进行扩展,同时利用精确高效注意力机制的进展——这些技术突破正是本项工作的基础。
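As an illustration of a graph-informed masking operator over edge tokens, consider the rule that two edges may attend to each other only if they share an endpoint (one plausible rule for exposition; the exact mask used by ESA is defined in Methods):

```python
def edge_attention_mask(edges):
    """Boolean mask over edge tokens: position (i, j) is True when
    edges i and j share an endpoint (illustrative masking rule)."""
    m = len(edges)
    mask = [[False] * m for _ in range(m)]
    for i, (a, b) in enumerate(edges):
        for j, (c, d) in enumerate(edges):
            mask[i][j] = len({a, b} & {c, d}) > 0
    return mask

# path graph 0-1-2-3 as an edge list; edge (0,1) touches (1,2) but not (2,3)
edges = [(0, 1), (1, 2), (2, 3)]
mask = edge_attention_mask(edges)
```

In a Flex Attention setting, the same predicate would be supplied as a mask function rather than materialised as an m-by-m matrix, keeping the memory footprint of the masked layers low.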

Our comprehensive evaluation shows that the proposed approach consistently outperforms strong message passing baselines and recently proposed transformer-based approaches for learning on graphs. The takeaway from our extensive study is that the proposed approach is well suited as a simple and yet extremely effective starting point for learning on graphs. Moreover, the approach has favourable computational complexity and scales better than strong GNNs and some graph transformers. The approach also does well in transfer learning settings, possibly paving the way for more research on foundational models for drug discovery, where transfer learning problems arise frequently in property prediction for expensive high-fidelity experiments.

我们的全面评估表明,所提出的方法始终优于强大的消息传递基线方法和近期提出的基于Transformer的图学习方法。通过广泛研究得出的结论是,该方法非常适合作为图学习中简单而极其有效的起点。此外,该方法具有良好的计算复杂度,比强大的图神经网络(GNN)和一些图Transformer更具扩展性。该方法在迁移学习场景中也表现良好,可能为药物发现领域的基础模型研究开辟新途径——在该领域中,高成本高保真实验的属性预测经常遇到此类问题。

References

参考文献

SI 1 Additional benchmarking results

SI 1 补充基准测试结果

Supplementary Table 1: The root mean squared error on QM9 (RMSE is the standard for quantum mechanics) and the $\mathrm{R^{2}}$ for DOCKSTRING (DOCK) and MoleculeNet (MN), presented as mean $\pm$ standard deviation over 5 runs, and including GCN and GIN. The mean absolute error (MAE) is reported for PCQM4MV2 over a single run due to the size of the dataset and the field standards. The lowest MAEs and RMSEs and highest $\mathrm{R^{2}}$ values are highlighted in bold. OOM denotes out-of-memory errors.

附表 1: QM9 的均方根误差 (RMSE 是量子力学的标准) 以及 DOCKSTRING (DOCK) 和 MoleculeNet (MN) 的 $\mathrm{R^{2}}$ 值,以 5 次运行的平均值 $\pm$ 标准差表示,包括 GCN 和 GIN。由于数据集规模和领域标准,PCQM4MV2 的平均绝对误差 (MAE) 仅报告单次运行结果。最低 MAE 和 RMSE 以及最高 $\mathrm{R^{2}}$ 值以粗体标出。OOM 表示内存不足错误。

Target GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
μ 0.629 ± 0.005 0.550 ± 0.011 0.552 ± 0.010 0.552 ± 0.011 0.545 ± 0.008 0.535 ± 0.008 0.627 ± 0.014 0.755 ± 0.025 0.945 ± 0.174 0.564 ± 0.004
α 0.525 ± 0.051 0.438 ± 0.020 0.445 ± 0.025 0.479 ± 0.020 0.460 ± 0.018 0.476 ± 0.070 0.404 ± 0.017 0.447 ± 0.014 0.605 ± 0.130 0.398 ± 0.003
ε_HOMO 0.128 ± 0.005 0.104 ± 0.001 0.106 ± 0.002 0.107 ± 0.002 0.104 ± 0.001 0.099 ± 0.001 0.107 ± 0.002 0.123 ± 0.003 0.113 ± 0.003 0.103 ± 0.003
ε_LUMO 0.136 ± 0.002 0.120 ± 0.004 0.123 ± 0.004 0.124 ± 0.003 0.127 ± 0.001 0.109 ± 0.001 0.112 ± 0.001 0.130 ± 0.003 0.108 ± 0.002 0.114 ± 0.001
Δε 0.184 ± 0.006 0.164 ± 0.004 0.166 ± 0.004 0.163 ± 0.006 0.163 ± 0.003 0.139 ± 0.001 0.163 ± 0.004 0.177 ± 0.006 0.154 ± 0.006 0.152 ± 0.001
⟨R²⟩ 33.913 ± 0.922 29.925 ± 0.302 29.870 ± 0.343 30.911 ± 0.152 30.149 ± 0.360 28.503 ± 0.444 29.628 ± 0.351 31.540 ± 0.422 30.421 ± 1.192 28.328 ± 0.321
ZPVE 0.035 ± 0.001 0.032 ± 0.001 0.034 ± 0.002 0.033 ± 0.001 0.032 ± 0.001 0.032 ± 0.003 0.059 ± 0.031 0.030 ± 0.001 0.033 ± 0.006 0.026 ± 0.001
U₀ 30.971 ± 6.472 42.024 ± 17.272 29.817 ± 6.322 27.823 ± 3.981 25.405 ± 3.617 31.644 ± 6.493 24.600 ± 4.699 15.477 ± 4.675 14.496 ± 4.342 4.777 ± 0.671
U 25.276 ± 4.471 25.098 ± 5.972 24.741 ± 6.311 30.914 ± 2.620 24.919 ± 4.183 25.038 ± 5.479 15.546 ± 8.228 12.449 ± 1.864 15.820 ± 4.262 6.799 ± 2.809
H 30.924 ± 5.369 24.746 ± 2.314 34.019 ± 5.663 28.924 ± 4.108 23.422 ± 2.767 27.338 ± 8.167 19.006 ± 4.645 11.418 ± 3.716 12.923 ± 4.287 6.018 ± 1.324
G 26.138 ± 10.078 26.942 ± 2.910 25.412 ± 4.863 31.494 ± 1.857 23.240 ± 2.535 22.308 ± 3.453 31.198 ± 4.614 29.721 ± 8.790 13.081 ± 4.806 6.104 ± 1.334
c_v 0.221 ± 0.014 0.187 ± 0.008 0.195 ± 0.012 0.191 ± 0.005 0.190 ± 0.002 0.179 ± 0.009 0.177 ± 0.023 0.180 ± 0.004 0.190 ± 0.027 0.158 ± 0.001
U₀ (ATOM) 0.371 ± 0.017 0.391 ± 0.035 0.378 ± 0.022 0.392 ± 0.035 0.384 ± 0.021 0.353 ± 0.018 0.289 ± 0.004 0.304 ± 0.021 0.330 ± 0.051 0.241 ± 0.006
U (ATOM) 0.390 ± 0.029 0.377 ± 0.029 0.404 ± 0.044 0.483 ± 0.072 0.397 ± 0.036 0.361 ± 0.020 0.302 ± 0.014 0.302 ± 0.004 0.334 ± 0.058 0.243 ± 0.005
H (ATOM) 0.396 ± 0.025 0.370 ± 0.026 0.376 ± 0.038 0.383 ± 0.021 0.406 ± 0.062 0.373 ± 0.029 0.301 ± 0.012 0.312 ± 0.014 0.363 ± 0.089 0.245 ± 0.004
G (ATOM) 0.368 ± 0.008 0.385 ± 0.021 0.372 ± 0.036 0.342 ± 0.018 0.329 ± 0.009 0.314 ± 0.021 0.261 ± 0.005 0.273 ± 0.006 0.360 ± 0.075 0.225 ± 0.015
A 0.979 ± 0.049 1.317 ± 0.386 0.904 ± 0.059 0.972 ± 0.122 1.078 ± 0.161 1.007 ± 0.101 64.877 ± 29.771 3.823 ± 2.052 1.422 ± 0.437 0.746 ± 0.106
B 0.295 ± 0.004 0.196 ± 0.047 0.194 ± 0.050 0.211 ± 0.038 0.264 ± 0.010 0.256 ± 0.024 0.102 ± 0.028 0.109 ± 0.013 0.158 ± 0.052 0.079 ± 0.011
C 0.284 ± 0.001 0.167 ± 0.062 0.269 ± 0.016 0.266 ± 0.006 0.276 ± 0.004 0.277 ± 0.012 0.115 ± 0.037 0.097 ± 0.025 0.124 ± 0.046 0.050 ± 0.012
ESR2 0.642 ± 0.003 0.668 ± 0.003 0.675 ± 0.003 0.666 ± 0.002 0.655 ± 0.004 0.696 ± 0.002 OOM 0.641 ± 0.008 0.676 ± 0.002 0.697 ± 0.001
F2 0.878 ± 0.001 0.887 ± 0.002 0.886 ± 0.001 0.886 ± 0.001 0.885 ± 0.002 0.891 ± 0.002 OOM 0.872 ± 0.006 0.879 ± 0.004 0.891 ± 0.000
KIT 0.814 ± 0.002 0.833 ± 0.001 0.835 ± 0.002 0.833 ± 0.000 0.826 ± 0.001 0.843 ± 0.001 OOM 0.800 ± 0.009 0.832 ± 0.001 0.841 ± 0.001
PARP1 0.912 ± 0.001 0.922 ± 0.001 0.920 ± 0.002 0.921 ± 0.001 0.919 ± 0.001 0.924 ± 0.001 OOM 0.907 ± 0.005 0.915 ± 0.005 0.925 ± 0.000
PGR 0.658 ± 0.004 0.696 ± 0.001 0.702 ± 0.002 0.681 ± 0.005 0.666 ± 0.006 0.717 ± 0.003 OOM 0.684 ± 0.010 0.703 ± 0.009 0.725 ± 0.003
FREESOLV 0.957 ± 0.008 0.964 ± 0.007 0.972 ± 0.005 0.959 ± 0.009 0.970 ± 0.007 0.951 ± 0.008 0.927 ± 0.005 0.930 ± 0.016 0.861 ± 0.032 0.977 ± 0.001
LIPO 0.800 ± 0.007 0.819 ± 0.006 0.809 ± 0.007 0.820 ± 0.012 0.821 ± 0.008 0.830 ± 0.006 0.607 ± 0.043 0.545 ± 0.022 0.790 ± 0.004 0.809 ± 0.007
ESOL 0.936 ± 0.005 0.938 ± 0.010 0.935 ± 0.011 0.930 ± 0.006 0.928 ± 0.005 0.942 ± 0.006 0.908 ± 0.018 0.892 ± 0.032 0.911 ± 0.003 0.944 ± 0.002

Supplementary Table 2: The mean absolute error (MAE) on ZINC, presented as mean $\pm$ standard deviation over 5 runs. The lowest values are highlighted in bold.

补充表 2: ZINC 上的平均绝对误差 (MAE) ,以 5 次运行的平均值 $\pm$ 标准差表示。最低值用粗体标出。

Dataset (↓) GCN GIN GAT GATv2 PNA Graphormer TokenGT GPS ESA ESA (PE)
ZINC 0.152 ± 0.02 0.068 ± 0.00 0.078 ± 0.01 0.079 ± 0.00 0.057 ± 0.01 0.036 ± 0.00 0.047 ± 0.01 0.024 ± 0.01 0.027 ± 0.00 0.017 ± 0.00

Supplementary Table 3: Transfer learning performance (RMSE) on QM9 for HOMO and LUMO, presented as mean $\pm$ standard deviation over 5 different runs, including GCN and GIN. All models use the 3D atomic coordinates and atom types as inputs and no other node or edge features. ‘Strat.’ stands for strategy and specifies the type of learning: GW only (no transfer learning), inductive, or transductive. The lowest values are highlighted in bold.

附表 3: QM9 数据集上 HOMO 和 LUMO 的迁移学习性能 (RMSE),以 5 次独立实验的平均值 $\pm$ 标准差表示,包含 GCN 和 GIN 模型。所有模型均使用 3D 原子坐标和原子类型作为输入,不包含其他节点或边特征。Strat. 表示策略 (strategy),指定学习类型:GW only (无迁移学习)、inductive (归纳式) 或 transductive (直推式)。最低值以粗体标出。

Target Strat. GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
GW 0.171 ± 0.004 0.162 ± 0.002 0.162 ± 0.002 0.159 ± 0.003 0.157 ± 0.002 0.151 ± 0.001 0.179 ± 0.009 0.200 ± 0.008 0.162 ± 0.002 0.152 ± 0.003
HOMO Ind. 0.143 ± 0.004 0.138 ± 0.001 0.136 ± 0.000 0.131 ± 0.001 0.133 ± 0.003 0.132 ± 0.002 0.134 ± 0.000 0.156 ± 0.000 0.151 ± 0.002 0.131 ± 0.000
Trans. 0.131 ± 0.001 0.125 ± 0.001 0.126 ± 0.002 0.123 ± 0.001 0.124 ± 0.001 0.121 ± 0.001 0.125 ± 0.001 0.137 ± 0.000 0.147 ± 0.002 0.119 ± 0.000
GW 0.181 ± 0.002 0.180 ± 0.002 0.180 ± 0.002 0.181 ± 0.002 0.178 ± 0.002 0.174 ± 0.190 ± 0.006 0.204 ± 0.006
LUMO Ind. 0.161 ± 0.001 0.159 ± 0.001 0.159 ± 0.001 0.156 ± 0.001 0.157 ± 0.001 0.004 0.156 ± 0.002 0.151 ± 0.001 0.178 ± 0.002 0.174 ± 0.001
Trans. 0.159 ± 0.002 0.155 ± 0.001 0.157 ± 0.001 0.153 ± 0.001 0.153 ± 0.001 0.153 ± 0.001 0.147 ± 0.001 0.165 ± 0.000 0.156 ± 0.000 0.167 ± 0.001 0.169 ± 0.001 0.150 ± 0.001 0.146 ± 0.000

Supplementary Table 4: Matthews correlation coefficient (MCC) for graph-level molecular classification tasks - MoleculeNet and National Cancer Institute (NCI), presented as mean $\pm$ standard deviation over 5 different runs, and including GCN and GIN. OOM denotes out-of-memory errors. The highest mean values are highlighted in bold.

表 4: 图级分子分类任务 (Molecule Net 和美国国家癌症研究所 Nci) 的马修斯相关系数 (MCC) - 以5次运行的平均值±标准差呈现,包含GCN和GIN。OoM表示内存不足错误。最高平均值以粗体标出。

Data (↑) GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
BBBP 0.674±0.034 0.704±0.028 0.685±0.017 0.744±0.012 0.728±0.032 0.731±0.028 0.552±0.012 0.578±0.065 0.705±0.044 0.835±0.014
BACE 0.631±0.028 0.646±0.013 0.654±0.034 0.632±0.018 0.645±0.026 0.638±0.017 0.522±0.020 0.578±0.033 0.618±0.032 0.721±0.019
HIV 0.448±0.035 0.408±0.060 0.458±0.028 0.421±0.061 0.337±0.059 0.417±0.045 OOM 0.455±0.017 0.247±0.211 0.533±0.012
NCI1 0.682±0.013 0.694±0.017 0.686±0.027 0.701±0.018 0.646±0.029 0.697±0.025 0.540±0.025 0.532±0.034 0.697±0.027 0.755±0.012
NCI109 0.665±0.024 0.684±0.015 0.681±0.021 0.658±0.008 0.664±0.015 0.670±0.018 0.504±0.022 0.453±0.029 0.623±0.014 0.700±0.010

Supplementary Table 5: Accuracy for graph-level molecular classification tasks - MoleculeNet and National Cancer Institute (NCI), presented as mean $\pm$ standard deviation over 5 different runs, and including GCN and GIN. OOM denotes out-of-memory errors. The highest mean values are highlighted in bold.

表 5: 图级别分子分类任务准确率 - MoleculeNet和美国国家癌症研究所(NCI)数据集,结果以5次运行的平均值±标准差表示,包含GCN和GIN方法。OoM表示内存不足错误。最高均值以加粗显示。

Data (↑) GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
BBBP 0.871 ± 0.012 0.882 ± 0.010 0.875 ± 0.006 0.898 ± 0.004 0.892 ± 0.012 0.893 ± 0.010 0.826 ± 0.005 0.832 ± 0.022 0.883 ± 0.017 0.932 ± 0.007
BACE 0.813 ± 0.014 0.820 ± 0.007 0.824 ± 0.017 0.813 ± 0.009 0.820 ± 0.012 0.817 ± 0.009 0.758 ± 0.011 0.786 ± 0.015 0.804 ± 0.017 0.858 ± 0.010
HIV 0.974 ± 0.001 0.973 ± 0.001 0.974 ± 0.001 0.973 ± 0.002 0.971 ± 0.001 0.973 ± 0.001 OOM 0.973 ± 0.001 0.971 ± 0.003 0.976 ± 0.001
NCI1 0.842 ± 0.006 0.848 ± 0.008 0.843 ± 0.014 0.851 ± 0.010 0.824 ± 0.015 0.850 ± 0.012 0.770 ± 0.012 0.767 ± 0.018 0.850 ± 0.014 0.878 ± 0.006
NCI109 0.831 ± 0.011 0.842 ± 0.007 0.840 ± 0.010 0.826 ± 0.005 0.831 ± 0.007 0.834 ± 0.009 0.749 ± 0.011 0.721 ± 0.017 0.809 ± 0.006 0.850 ± 0.005

Supplementary Table 6: Matthews correlation coefficient (MCC) for graph-level classification tasks from various domains, presented as mean $\pm$ standard deviation over 5 different runs, and including GCN and GIN. OOM denotes out-of-memory errors, and N/A that the model is unavailable (e.g., node/edge features are not integers, which are required for Graphormer and TokenGT). The highest mean values are highlighted in bold.

附表 6: 不同领域图级分类任务的马修斯相关系数 (MCC),以5次运行的平均值±标准差表示,包括GCN和GIN。OOM表示内存不足错误,N/A表示模型不可用(例如节点/边特征不是整数,而Graphormer和TokenGT需要整数特征)。最高平均值以粗体标出。

Dataset (↑) GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
MALNETTINY 0.896 ± 0.006 0.903 ± 0.005 0.902 ± 0.006 0.899 ± 0.006 0.902 ± 0.007 0.915 ± 0.008 OOM 0.777 ± 0.008 0.795 ± 0.011 0.931 ± 0.001
MNIST 0.957 ± 0.002 0.965 ± 0.004 0.970 ± 0.001 0.972 ± 0.002 0.979 ± 0.002 0.978 ± 0.003 N/A N/A 0.980 ± 0.001 0.986 ± 0.000
CIFAR10 0.621 ± 0.003 0.614 ± 0.006 0.614 ± 0.009 0.659 ± 0.009 0.662 ± 0.011 0.686 ± 0.005 N/A N/A 0.708 ± 0.005 0.727 ± 0.003
ENZYMES 0.695 ± 0.048 0.632 ± 0.046 0.576 ± 0.039 0.748 ± 0.020 0.744 ± 0.023 0.684 ± 0.034 N/A N/A 0.734 ± 0.045 0.751 ± 0.009
PROTEINS 0.419 ± 0.037 0.421 ± 0.036 0.459 ± 0.005 0.463 ± 0.036 0.490 ± 0.042 0.467 ± 0.067 N/A N/A 0.443 ± 0.022 0.589 ± 0.017
DD 0.546 ± 0.058 0.539 ± 0.032 0.537 ± 0.070 0.465 ± 0.049 0.530 ± 0.034 0.559 ± 0.080 OOM 0.459 ± 0.049 0.605 ± 0.041 0.652 ± 0.030
SYNTH 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 N/A N/A 1.000 ± 0.000 1.000 ± 0.000
SYNTH-N 0.699 ± 0.029 0.910 ± 0.030 0.975 ± 0.051 0.761 ± 0.067 0.909 ± 0.067 1.000 ± 0.000 N/A N/A 1.000 ± 0.000 1.000 ± 0.000
SYNTHIE 0.935 ± 0.054 0.930 ± 0.045 0.942 ± 0.024 0.701 ± 0.045 0.798 ± 0.038 0.879 ± 0.058 N/A N/A 0.951 ± 0.019 0.947 ± 0.016
IMDB-B 0.603 ± 0.056 0.537 ± 0.199 0.608 ± 0.061 0.688 ± 0.037 0.600 ± 0.049 0.563 ± 0.067 0.565 ± 0.045 0.606 ± 0.052 0.598 ± 0.045 0.738 ± 0.026
IMDB-M 0.216 ± 0.044 0.119 ± 0.079 0.117 ± 0.099 0.203 ± 0.052 0.202 ± 0.035 0.034 ± 0.068 0.222 ± 0.024 0.202 ± 0.023 0.232 ± 0.021 0.247 ± 0.034
TWITCH E. 0.386 ± 0.006 0.358 ± 0.023 0.379 ± 0.011 0.373 ± 0.011 0.371 ± 0.011 0.078 ± 0.159 0.387 ± 0.003 0.393 ± 0.001 0.395 ± 0.001 0.398 ± 0.000
REDDIT THR. 0.556 ± 0.007 0.556 ± 0.011 0.556 ± 0.007 0.533 ± 0.021 0.536 ± 0.023 0.113 ± 0.227 0.567 ± 0.003 0.564 ± 0.001 0.568 ± 0.003 0.568 ± 0.002

Supplementary Table 7: Accuracy for graph-level classification tasks from various domains, presented as mean $\pm$ standard deviation over 5 different runs, and including GCN and GIN. OOM denotes out-of-memory errors, and N/A that the model is unavailable (e.g., node/edge features are not integers, which are required for Graphormer and TokenGT). The highest mean values are highlighted in bold.

附表 7: 不同领域图级分类任务的准确率,以5次运行的平均值 $\pm$ 标准差表示,包含 GCN 和 GIN。OOM 表示内存不足错误,N/A 表示模型不可用 (例如 Graphormer 和 TokenGT 所需的节点/边特征不是整数)。最高平均值以粗体标出。

Dataset (↑) GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
MALNETTINY 0.916 ± 0.005 0.922 ± 0.004 0.921 ± 0.005 0.918 ± 0.005 0.921 ± 0.006 0.931 ± 0.006 OOM 0.820 ± 0.007 0.835 ± 0.009 0.944 ± 0.001
MNIST 0.961 ± 0.001 0.969 ± 0.004 0.973 ± 0.001 0.975 ± 0.002 0.981 ± 0.001 0.980 ± 0.003 N/A N/A 0.982 ± 0.001 0.988 ± 0.000
CIFAR10 0.659 ± 0.003 0.652 ± 0.005 0.652 ± 0.008 0.693 ± 0.008 0.695 ± 0.010 0.717 ± 0.005 N/A N/A 0.737 ± 0.005 0.754 ± 0.002
ENZYMES 0.735 ± 0.039 0.683 ± 0.037 0.651 ± 0.037 0.786 ± 0.014 0.780 ± 0.019 0.730 ± 0.022 N/A N/A 0.777 ± 0.039 0.794 ± 0.015
PROTEINS 0.755 ± 0.015 0.755 ± 0.017 0.768 ± 0.000 0.768 ± 0.015 0.777 ± 0.020 0.777 ± 0.029 N/A N/A 0.768 ± 0.011 0.827 ± 0.007
DD 0.782 ± 0.031 0.773 ± 0.020 0.782 ± 0.033 0.731 ± 0.031 0.760 ± 0.020 0.790 ± 0.039 OOM 0.739 ± 0.030 0.808 ± 0.017 0.835 ± 0.016
SYNTH 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 1.000 ± 0.000 N/A N/A 1.000 ± 0.000 1.000 ± 0.000
SYNTH N. 0.847 ± 0.016 0.953 ± 0.016 0.987 ± 0.027 0.880 ± 0.034 0.953 ± 0.034 1.000 ± 0.000 N/A N/A 1.000 ± 0.000 1.000 ± 0.000
SYNTHIE 0.956 ± 0.039 0.955 ± 0.033 0.963 ± 0.018 0.776 ± 0.037 0.855 ± 0.031 0.918 ± 0.040 N/A N/A 0.958 ± 0.014 0.963 ± 0.012
IMDB-B 0.802 ± 0.028 0.766 ± 0.097 0.804 ± 0.030 0.842 ± 0.018 0.800 ± 0.024 0.780 ± 0.034 0.780 ± 0.023 0.802 ± 0.026 0.794 ± 0.024 0.868 ± 0.013
IMDB-M 0.476 ± 0.030 0.411 ± 0.052 0.410 ± 0.064 0.470 ± 0.033 0.469 ± 0.024 0.356 ± 0.046 0.484 ± 0.015 0.470 ± 0.015 0.476 ± 0.013 0.487 ± 0.011
TWITCHE. 0.697 ± 0.003 0.682 ± 0.012 0.694 ± 0.005 0.690 ± 0.005 0.690 ± 0.005 0.506 ± 0.098 0.698 ± 0.001 0.701 ± 0.001 0.702 ± 0.000 0.703 ± 0.000
REDDIT THR. 0.778 ± 0.003 0.776 ± 0.005 0.777 ± 0.003 0.765 ± 0.010 0.767 ± 0.011 0.545 ± 0.118 0.782 ± 0.001 0.780 ± 0.001 0.783 ± 0.001 0.782 ± 0.001

Supplementary Table 8: Matthews correlation coefficient (MCC) for 11 node-level classification tasks, presented as mean $\pm$ standard deviation over 5 different runs, including GIN and DropGIN. The number of nodes for the shortest path (sp) benchmarks is given in parentheses (based on randomly-generated “infected” Erdős–Rényi (ER) graphs; see SI 8 for details). Additional heterophily results are provided in Supplementary Table 10. The highest mean values are highlighted in bold.

附表 8: 11个节点级分类任务的马修斯相关系数 (MCC),以5次不同运行的平均值 $\pm$ 标准差表示,包括GIN和DropGIN。最短路径(sp)基准的节点数在括号中给出(基于随机生成的“感染”Erdős–Rényi (ER)图;详见SI 8)。其他异质性结果见附表10。最高平均值以粗体标出。

Data (↑) GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
PPI 0.979 ± 0.002 0.905 ± 0.066 0.785 ± 0.141 0.990 ± 0.000 0.982 ± 0.005 0.990 ± 0.000 N/A N/A N/A 0.989 ± 0.001
CITESEER 0.608 ± 0.008 0.587 ± 0.010 0.324 ± 0.017 0.587 ± 0.030 0.613 ± 0.006 0.511 ± 0.033 OOM 0.384 ± 0.022 0.538 ± 0.012 0.632 ± 0.005
CORA 0.767 ± 0.008 0.727 ± 0.010 0.490 ± 0.033 0.748 ± 0.014 0.727 ± 0.010 0.637 ± 0.029 OOM 0.366 ± 0.185 0.643 ± 0.039 0.768 ± 0.005
ROMAN EMP. 0.470 ± 0.004 0.791 ± 0.003 0.796 ± 0.002 0.741 ± 0.010 0.763 ± 0.004 0.855 ± 0.002 N/A N/A 0.837 ± 0.014 0.869 ± 0.002
AMAZON R. 0.179 ± 0.003 0.262 ± 0.015 0.228 ± 0.018 0.256 ± 0.009 0.253 ± 0.013 0.206 ± 0.018 N/A N/A 0.114 ± 0.139 0.336 ± 0.006
MINESWEEPER 0.303 ± 0.002 0.563 ± 0.028 0.455 ± 0.162 0.477 ± 0.018 0.512 ± 0.010 0.624 ± 0.043 N/A N/A 0.564 ± 0.011 0.688 ± 0.001
TOLOKERS 0.299 ± 0.011 0.299 ± 0.065 0.340 ± 0.012 0.384 ± 0.009 0.386 ± 0.005 0.350 ± 0.048 N/A N/A 0.351 ± 0.021 0.427 ± 0.004
SQUIRREL 0.197 ± 0.011 0.194 ± 0.026 0.230 ± 0.019 0.237 ± 0.017 0.238 ± 0.013 0.224 ± 0.011 0.104 ± 0.085 0.175 ± 0.017 0.228 ± 0.021 0.289 ± 0.006
CHAMELEON 0.324 ± 0.006 0.272 ± 0.024 0.287 ± 0.026 0.240 ± 0.056 0.281 ± 0.030 0.267 ± 0.030 0.253 ± 0.038 0.260 ± 0.045 0.298 ± 0.081 0.387 ± 0.018
ER (15K) 0.216 ± 0.016 0.387 ± 0.067 0.244 ± 0.049 0.315 ± 0.001 0.316 ± 0.000 0.543 ± 0.086 OOM 0.065 ± 0.000 0.178 ± 0.040 0.915 ± 0.007
ER (30K) 0.094 ± 0.035 0.330 ± 0.021 0.292 ± 0.022 0.102 ± 0.064 0.102 ± 0.058 0.423 ± 0.049 OOM OOM OOM 0.872 ± 0.008

Supplementary Table 9: Accuracy for 11 node-level classification tasks, presented as mean $\pm$ standard deviation over 5 different runs, including GIN and DropGIN. The number of nodes for the shortest path (sp) benchmarks is given in parentheses (based on randomly-generated “infected” Erdős–Rényi (ER) graphs; see SI 8 for details). Additional heterophily results are provided in Supplementary Table 10. The highest mean values are highlighted in bold.

附表 9: 11个节点级分类任务的准确率,以5次运行的平均值$\pm$标准差表示,包括GIN和DropGIN。最短路径(sp)基准的节点数在括号中给出(基于随机生成的“感染”Erdős–Rényi (ER)图;详见SI 8)。补充表10提供了额外的异质性结果。最高平均值以粗体标出。

数据集 (↑) GCN GIN DropGIN GAT GATv2 PNA Graphormer TokenGT GPS ESA
PPI 0.991 ± 0.001 0.960 ± 0.027 0.913 ± 0.056 0.996 ± 0.000 0.992 ± 0.002 0.996 ± 6.480 N/A N/A 0.995 ± 0.000
CITESEER 0.648 ± 0.008 0.626 ± 0.008 0.426 ± 0.016 0.625 ± 0.018 0.649 ± 0.005 0.557 ± 0.027 OOM 0.614 ± 0.011 0.651 ± 0.008
CORA 0.822 ± 0.005 0.786 ± 0.012 0.595 ± 0.029 0.803 ± 0.014 0.785 ± 0.007 0.716 ± 0.019 OOM 0.820 ± 0.004 0.820 ± 0.004
ROMAN EMP. 0.462 ± 0.002 0.755 ± 0.005 0.764 ± 0.003 0.703 ± 0.016 0.721 ± 0.009 0.831 ± 0.004 N/A 0.850 ± 0.007 0.850 ± 0.007
AMAZON R. 0.284 ± 0.003 0.363 ± 0.013 0.324 ± 0.013 0.358 ± 0.013 0.352 ± 0.017 0.311 ± 0.016 N/A 0.445 ± 0.016 0.445 ± 0.016
MINESWEEPER 0.605 ± 0.001 0.764 ± 0.022 0.702 ± 0.092 0.713 ± 0.016 0.729 ± 0.007 0.801 ± 0.033 N/A 0.852 ± 0.003 0.852 ± 0.003
TOLOKERS 0.595 ± 0.006 0.601 ± 0.034 0.628 ± 0.011 0.649 ± 0.007 0.642 ± 0.005 0.649 ± 0.022 N/A 0.714 ± 0.011 0.714 ± 0.011
SQUIRREL 0.318 ± 0.010 0.318 ± 0.025 0.346 ± 0.024 0.351 ± 0.010 0.353 ± 0.013 0.338 ± 0.018 0.294 ± 0.017 0.409 ± 0.008 0.409 ± 0.008
CHAMELEON 0.436 ± 0.004 0.399 ± 0.020 0.414 ± 0.019 0.363 ± 0.042 0.400 ± 0.023 0.395 ± 0.021 0.381 ± 0.034 0.478 ± 0.015 0.478 ± 0.015
ER (15K) 0.227 ± 0.020 0.351 ± 0.116 0.154 ± 0.086 0.361 ± 0.003 0.364 ± 0.000 0.528 ± 0.073 OOM 0.886 ± 0.011 0.886 ± 0.011
ER (30K) 0.159 ± 0.020 0.189 ± 0.103 0.124 ± 0.078 0.246 ± 0.130 0.237 ± 0.121 0.442 ± 0.087 OOM 0.760 ± 0.040 0.760 ± 0.040

Supplementary Table 10: Matthews correlation coefficient (MCC) for the heterophilous node-level classification tasks and two additional baselines, presented as mean $\pm$ standard deviation over 5 different runs.

补充表 10: 异质节点级分类任务的马修斯相关系数 (MCC) 及两个额外基线方法的结果,以5次运行的平均值 $\pm$ 标准差表示。

数据集 (↑) GraphSAGE GT
ROMAN EMPIRE 0.83 ± 0.00 0.84 ± 0.00
AMAZON RATINGS 0.36 ± 0.00 0.34 ± 0.01
MINESWEEPER 0.65 ± 0.01 0.57 ± 0.04
TOLOKERS 0.23 ± 0.02 0.38 ± 0.01
SQUIRREL FILTERED 0.14 ± 0.03 0.18 ± 0.01
CHAMELEON FILTERED 0.27 ± 0.03 0.25 ± 0.04

SI 2 Additional time and memory results

SI 2 额外时间和内存结果

Supplementary Figure 1: The elapsed time for training a single epoch on the dockstring dataset (in seconds), and the maximum allocated memory during this training epoch (in GB).

Supplementary Figure 1: 在dockstring数据集上单轮训练耗时(单位:秒)及该训练轮次的最大内存分配量(单位:GB)。

SI 3 Sourcing and licensing

SI 3 数据来源与授权

Most datasets are sourced from the PyTorch Geometric library (MIT license). Datasets with different sources include: dockstring (Apache 2.0, https://github.com/dockstring/dockstring), the heterophily datasets (MIT license, https://github.com/yandex-research/heterophilous-graphs), the peptide datasets from the Long Range Graph Benchmark project (MIT license, https://github.com/vijaydwivedi75/lrgb), the GW frontier orbital energies for qm9 (CC BY 4.0 license, https://doi.org/10.6084/m9.figshare.21610077.v1), and the Open Catalyst Project (CC BY 4.0 license for the datasets, MIT license for the Python package, https://github.com/Open-Catalyst-Project/Open-Catalyst-Dataset).

大多数数据集来源于PyTorch Geometric库(MIT许可证)。不同来源的数据集包括:dockstring(Apache 2.0, https://github.com/dockstring/dockstring)、异质性数据集(MIT许可证, https://github.com/yandex-research/heterophilous-graphs)、长程图基准项目中的肽类数据集(MIT许可证, https://github.com/vijaydwivedi75/lrgb)、qm9的GW前沿轨道能(CC BY 4.0许可证, https://doi.org/10.6084/m9.figshare.21610077.v1)以及Open Catalyst项目(数据集使用CC BY 4.0许可证,Python包使用MIT许可证, https://github.com/Open-Catalyst-Project/Open-Catalyst-Dataset)。

All GNN implementations used in this work are sourced from the PyTorch Geometric library (MIT license). The Graphormer and TokenGT implementations are sourced from the Hugging Face project (Apache 2.0 license). The Graphormer implementation used for the 3D modelling task is sourced from the official GitHub repository (MIT license, https://github.com/microsoft/Graphormer). GraphGPS also uses the MIT license. PyTorch uses the BSD-3 license.

本研究中使用的所有GNN实现均来自PyTorch Geometric库(MIT许可证)。Graphormer和TokenGT实现源自Hugging Face项目(Apache 2.0许可证)。用于3D建模任务的Graphormer实现来自官方GitHub仓库(MIT许可证,https://github.com/microsoft/Graphormer)。GraphGPS同样采用MIT许可证。PyTorch使用BSD-3许可证。

SI 4 Transfer learning setup

SI 4 迁移学习设置

The transfer learning setup consists of randomly selected training, validation, and test sets of 25K, 5K, and 10K molecules, respectively, with GW calculations (from the total of 133,885). This setup mimics the low amounts of high-quality/fidelity data available in drug discovery and quantum simulation projects. In the transductive case, the entire dataset with DFT targets is used for pre-training (including the 10K test-set compounds, but only with DFT-level measurements), while in the inductive setting the 10K set is completely excluded. Here, we perform transfer learning by pre-training a model on the DFT target for a fixed number of epochs (150) and then fine-tuning it on the subset of 25K GW calculations. In the transductive case, pre-training occurs on the full set of 133K DFT calculations, while in the inductive case the DFT test-set values are removed (note that the evaluation is done on the test-set GW measurements).

迁移学习的设置包括随机选取的25K、5K和10K分子(来自总计133,885个分子)分别作为训练集、验证集和测试集,这些分子均带有GW计算数据。这一设置模拟了药物发现和量子模拟项目中可用的少量高质量/高保真数据。在传导式(transductive)情况下,整个带有DFT目标的数据集用于预训练(包括10K测试集化合物,但仅使用DFT级别测量),而在归纳式(inductive)设置中,10K测试集被完全排除。此处,我们通过在DFT目标上预训练模型固定周期数(150轮),然后在25K GW计算子集上进行微调来实现迁移学习。在传导式情况下,预训练基于全部133K DFT计算数据,而在归纳式情况下则移除了DFT测试集值(注意评估是在测试集的GW测量上进行的)。
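The random split described above can be sketched as follows. This is an illustrative, index-based sketch (the function name and seed are our own; the actual molecule objects and GW/DFT targets are assumed to live elsewhere):

```python
import random

def make_transfer_splits(n_total=133885, n_train=25000, n_val=5000,
                         n_test=10000, seed=0):
    """Randomly partition dataset indices into train/val/test subsets,
    mirroring the 25K/5K/10K transfer-learning split described above."""
    idx = list(range(n_total))
    random.Random(seed).shuffle(idx)
    train = idx[:n_train]
    val = idx[n_train:n_train + n_val]
    test = idx[n_train + n_val:n_train + n_val + n_test]
    return train, val, test
```

In the transductive variant, the remaining indices would still be seen during DFT pre-training; in the inductive variant, the test indices are withheld entirely.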

SI 5 Limitations

SI 5 局限性

In terms of limitations, we highlight that the available libraries are not optimised for masking or custom attention patterns. This is most evident for very dense graphs (tens of thousands of edges or more). Memory-efficient and Flash attention are available natively in PyTorch [46] starting from version 2.0, as well as in the xFormers library [47]. More specifically, we have tested at least 5 different implementations of ESA: (1) leveraging the MultiheadAttention module from PyTorch, (2) leveraging the MultiHeadDispatch module from xFormers, (3) a manual implementation of multi-head attention relying on PyTorch's scaled_dot_product_attention function, (4) a manual implementation of multi-head attention relying on xFormers' memory-efficient attention, and (5) a naive implementation. Options (1)-(4) can all make use of efficient and fast implementations. However, we have observed performance differences between the 4 implementations, as well as compared to the naive implementation. This behaviour is likely due to the different low-level kernel implementations. Moreover, Flash attention does not currently support custom attention masks, as there is little interest in such functionality from a language modelling perspective.

在局限性方面,我们注意到现有库并未针对掩码或自定义注意力模式进行优化。这在处理极高密度图(数万条边以上)时尤为明显。从PyTorch [46] 2.0版本开始,原生支持内存高效注意力和Flash Attention,xFormers库[47]也提供相关功能。具体而言,我们测试了至少5种ESA实现方案:(1) 使用PyTorch的多头注意力模块,(2) 采用xFormers的多头调度模块,(3) 基于PyTorch缩放点积注意力函数的手动多头注意力实现,(4) 基于xFormers内存高效注意力的手动多头注意力实现,(5) 原始实现。方案(1)-(4)均可利用高效快速实现,但我们观察到四种实现之间存在性能差异,且与原始实现相比也有不同。这种现象可能源于底层内核实现的差异。此外,由于语言建模领域对此功能需求较低,Flash Attention目前暂不支持自定义注意力掩码。

Although the masks can be computed efficiently during training, all frameworks require the last two dimensions of the input mask tensor to be of shape $(L_{n},L_{n})$ for nodes or $(L_{e},L_{e})$ for edges, effectively squaring the number of nodes or edges. However, the mask tensors are sparse and a sparse tensor alternative could greatly reduce the memory consumption for large and dense graphs. Such an option exists for the PyTorch native attention, but it is currently broken.

虽然这些掩码在训练期间可以高效计算,但所有框架都要求输入掩码张量的最后两个维度必须为节点形状 $(L_{n},L_{n})$ 或边形状 $(L_{e},L_{e})$ ,这实际上使节点或边数量呈平方级增长。不过掩码张量具有稀疏性,采用稀疏张量方案能大幅降低大型稠密图的内存消耗。PyTorch原生注意力机制已提供该选项,但目前存在缺陷。
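To make the quadratic mask shape concrete, the sketch below builds a dense $(L_e, L_e)$ boolean mask from an edge list. The masking rule used here (two edges may attend to each other iff they share an endpoint) is an illustrative assumption, not a reproduction of the paper's Algorithm 1; the point is that the mask is quadratic in the number of edges even though most entries are False:

```python
def dense_edge_mask(edges):
    """Build an (L_e, L_e) boolean attention mask over a graph's edge list.

    Illustrative rule (an assumption): edges i and j may attend to each
    other iff they share an endpoint. Memory grows with the square of the
    number of edges, which is the cost discussed above; a sparse layout
    would store only the True entries.
    """
    L = len(edges)
    mask = [[False] * L for _ in range(L)]
    for i, (a, b) in enumerate(edges):
        for j, (c, d) in enumerate(edges):
            if {a, b} & {c, d}:  # shared endpoint
                mask[i][j] = True
    return mask
```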

Another possible optimisation would be to use nested (ragged) tensors to represent graphs, since padding is currently necessary to ensure identical dimensions for attention. A prototype nested tensor attention is available in PyTorch; however, not all the required operations are supported and converting between normal and nested tensors is slow.

另一个可能的优化是使用嵌套(ragged)张量来表示图,因为目前需要填充以确保注意力机制的维度一致。PyTorch中提供了嵌套张量注意力的原型实现;然而,并非所有必需的操作都受支持,且在普通张量和嵌套张量之间转换速度较慢。

For all implementations, it is required that the mask tensor is repeated by the number of attention heads (e.g., 8 or 16). However, a notable bottleneck is encountered for the MultiheadAttention and MultiHeadDispatch variants described above, which require that the repeats happen in the batch dimension, i.e., requiring 3D mask tensors of shape $(B\times H,L,L)$, where $H$ is the number of heads. The other two efficient implementations require a 4D mask instead, i.e., $(B,H,L,L)$, where one can use PyTorch's expand function instead of repeat. The expand alternative does not use any additional memory, while repeat requires $H\times$ the memory. Note that it is not possible to reshape the 4D tensor created using expand without using additional memory.

在所有实现中,都需要将掩码张量按注意力头数(如8或16)进行重复。但上述多头注意力(Multi head Attention)和多头分派(Multi Head Dispatch)变体会遇到显著瓶颈,这些变体要求重复操作发生在批次维度,即需要形状为$(\boldsymbol{B}\times\boldsymbol{H},\boldsymbol{L},\boldsymbol{L})$的3D掩码张量,其中$H$表示头数。另外两种高效实现则需要4D掩码$(B,H,L,L)$,此时可使用PyTorch的expand函数替代repeat操作。expand方案不会占用额外内存,而repeat操作需要$\times H$的内存开销。需要注意的是,在不使用额外内存的情况下,无法对通过expand创建的4D张量进行重塑操作。

Finally, we noticed a limitation involving PyTorch's nonzero() tensor method, which is required as part of the edge masking algorithm (Algorithm 1). This is covered in detail in SI 9. Currently, the nonzero() method fails for very dense graphs; a fix would require moving to 64-bit integer indexing.

最后,我们注意到一个涉及PyTorch的nonzero()张量方法的限制,这是边缘掩码算法(Algorithm 1)的必要组成部分。SI 9对此进行了详细说明。目前,nonzero()方法在图形非常密集时会失效,修复需要更新到64位整数限制。

SI 6 Helper functions

SI 6 辅助函数

consecutive is a helper function that generates runs of consecutive numbers starting from 0, with the run lengths given by the differences between adjacent elements of its tensor argument, and a second integer argument used to compute the last run length, e.g., consecutive$([1,4,6],10)=[0,1,2,0,1,0,1,2,3]$. first_unique_index finds the first occurrence of each unique element in the tensor (in sorted order), e.g., first_unique_index$([3,2,3,4,2])=[1,0,3]$. The implementations are available in our code base.

consecutive 是一个辅助函数,用于生成从0开始的连续数字,其张量参数相邻元素之间的差值指定各段长度,第二个整数参数用于计算最后一段的长度。例如,consecutive$([1,4,6],10)=[0,1,2,0,1,0,1,2,3]$,而 first_unique_index 则查找张量中每个唯一元素的首次出现位置(按排序顺序),例如 first_unique_index$([3,2,3,4,2])=[1,0,3]$。具体实现可在我们的代码库中找到。
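The two helpers above can be sketched in plain Python (the released code base operates on PyTorch tensors; this sketch reproduces only the semantics of the worked examples):

```python
def consecutive(starts, end):
    """Generate runs of consecutive integers starting at 0.

    Run lengths are the differences between adjacent elements of `starts`;
    `end` is used to compute the final run length.
    e.g. consecutive([1, 4, 6], 10) -> runs of length 3, 2, 4.
    """
    lengths = [b - a for a, b in zip(starts, starts[1:])]
    lengths.append(end - starts[-1])
    out = []
    for n in lengths:
        out.extend(range(n))
    return out

def first_unique_index(values):
    """Index of the first occurrence of each unique element, in sorted order."""
    first = {}
    for i, v in enumerate(values):
        if v not in first:
            first[v] = i
    return [first[v] for v in sorted(first)]
```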

SI 7 Experimental setup

SI 7 实验设置

We follow a simple and universal experimental protocol to ensure that it is possible to compare the results of different methods and to evaluate a large number of datasets with high throughput. Below, in SI 7.1, we describe the grid search approach and the search parameters used. For the basic hyper-parameters such as batch size and learning rate, we chose a number of reasonable settings for all methods, regardless of their nature (GNN or attention-based). These include the AdamW optimiser [68], more specifically its 8-bit version [69], a learning rate of 0.0001, a batch size of 128, mixed-precision training with the bfloat16 tensor format, early stopping with a patience of 30 epochs (100 for very small datasets such as freesolv), and gradient clipping (set to the default value of 0.5). Furthermore, we used a simple learning rate scheduler that halved the learning rate if no improvement was encountered for 15 epochs (half the early stopping patience). If these parameters led to out-of-memory errors, we attempted halving the batch size until the error was fixed, or reducing the hidden dimension as a last resort.

我们遵循一个简单通用的实验协议,以确保能够比较不同方法的结果,并以高吞吐量评估大量数据集。下文SI 7.1中描述了网格搜索方法和使用的搜索参数。对于基础超参数(如批大小和学习率),我们为所有方法选择了一系列合理的超参数和设置,无论其性质如何(图神经网络或基于注意力机制)。这包括AdamW优化器[68](更具体地说是8位版本[69])、学习率(0.0001)、批大小(128)、采用bfloat16张量格式的混合精度训练、早停耐心值设为30个周期(对于freesolv等极小数据集设为100),以及梯度裁剪(默认值设为0.5)。此外,我们使用了一个简单的学习率调度器:如果连续15个周期(早停耐心值的一半)没有改进,则将学习率减半。如果这些参数导致内存不足错误,我们会尝试将批大小减半直至错误修复,或作为最后手段减小隐藏层维度。
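The early-stopping and learning-rate schedule described above can be sketched as a small controller. This is an illustration of the stated protocol, not the paper's actual training code (which would use an optimiser and, e.g., a ReduceLROnPlateau-style scheduler):

```python
class TrainingController:
    """Halve the learning rate after 15 epochs without validation
    improvement; stop training after 30 (the protocol described above)."""

    def __init__(self, lr=1e-4, lr_patience=15, stop_patience=30):
        self.lr = lr
        self.lr_patience = lr_patience
        self.stop_patience = stop_patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True to stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        if self.bad_epochs % self.lr_patience == 0:
            self.lr *= 0.5  # halve LR every lr_patience stagnant epochs
        return self.bad_epochs >= self.stop_patience
```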

For Graphormer and TokenGT, we leverage the Hugging Face [70] implementation. We have adapted the TokenGT implementation to use Flash attention, which was not originally supported. For Graphormer, this optimisation is not possible since Flash attention is not compatible with some operations required by Graphormer. However, we did optimise the data loading process compared to the original Hugging Face implementation, leading to lower RAM usage.

对于Graphormer和TokenGT,我们采用了hugging face [70]的实现方案。我们调整了TokenGT的实现以支持原本不兼容的Flash attention优化。而Graphormer由于部分必要操作与Flash attention不兼容,无法应用该优化。不过相比原版hugging face实现,我们优化了数据加载流程,从而降低了内存占用。

SI 7.1 Hyper-parameter tuning

SI 7.1 超参数调优

For a fair and comprehensive evaluation, we tune each algorithm for each dataset using a grid search approach and a selection of reasonable hyper-parameters. These include the number of layers, the number of attention heads for GAT(v2) and graph transformers, dropout, and hidden dimensions. For graph-level GNNs, we evaluated configurations with 4 to 6 layers, hidden dimensions in $\{128,256,512\}$, the number of GAT heads in $\{8,16\}$, and GAT dropout in $\{0,0.2\}$. For node-level GNNs, we adapted the search since these tasks are more sensitive to overfitting, and used a number of layers in $\{1,2,4,6\}$, hidden dimensions in $\{64,128,256\}$, the same GAT heads and dropout settings, and dropout after each GNN layer in $\{0,0.2\}$.

For Graphormer, TokenGT, and GraphGPS, we evaluate models with the number of layers in $\{4,6,8,10\}$, the number of attention heads in $\{4,8,16\}$, and hidden dimensions in $\{128,256,512\}$. For ESA, we generally focused on models with 6 to 10 layers, with SABs at the start and end and MABs in the middle. Hidden dimensions are selected from $\{256,512\}$ and the number of attention heads from $\{8,16,32\}$. We evaluate pre-LN and post-LN architectures, standard and gated MLPs, and different MLP hidden dimensions and numbers of MLP layers on a dataset-by-dataset basis. The best configuration is selected based on the validation loss, and results are reported on the test set from 5 different runs.

Based on recent reports on the performance of GNNs [30], we augmented all 6 GNN baselines with residual connections and normalisation layers for each graph convolutional layer. These strategies are not part of the original message passing specification but lead to substantial uplifts. For datasets with established splits, such as cifar10, mnist, or zinc, we use the available splits. For dockstring, only train and test splits are available, so we randomly extract 20,000 train molecules for validation. Otherwise, we generate our own splits with a train/validation/test ratio of $80\%/10\%/10\%$. All models are trained and evaluated using mixed-precision training with the bfloat16 tensor format.

为了公平全面地评估,我们采用网格搜索方法和一组合理的超参数对每个数据集上的每种算法进行调优。这些参数包括层数、GAT(v2)和图Transformer的注意力头数量、dropout率以及隐藏层维度。对于图级GNN,我们评估了4至6层、隐藏维度在${128,256,512}$、GAT头数在${8,16}$、GAT dropout率在${0,0.2}$的配置组合。针对节点级GNN,由于这类任务对过拟合更敏感,我们将搜索范围调整为层数${1,2,4,6}$、隐藏维度${64,128,256}$,保持相同的GAT头数和dropout设置,并在每个GNN层后添加${0,0.2}$的dropout。

对于Graphormer、TokenGT和GraphGPS,我们评估了层数${4,6,8,10}$、注意力头数${4,8,16}$、隐藏维度${128,256,512}$的模型配置。在ESA模型中,我们主要关注6至10层的架构,采用首尾SAB层和中间MAB层的设计,隐藏维度选自${256,512}$,注意力头数选自${8,16,32}$。我们逐数据集评估了pre-LN与post-LN结构、标准与门控MLP,以及不同MLP隐藏维度和层数的组合。最佳配置根据验证损失选取,测试集结果报告5次运行的平均值。

基于近期关于GNN性能的研究[30],我们为所有6个GNN基线模型在每个图卷积层添加了残差连接和归一化层。这些策略虽不属于原始消息传递规范,但能显著提升性能。对于已划分数据集(如cifar10、mnist或zinc),我们直接使用现有划分。dockstring数据集仅提供训练/测试划分,因此我们随机抽取20,000个训练分子作为验证集。其余情况我们按$80\%/10\%/10\%$的比例生成训练/验证/测试划分。所有模型均采用bfloat16张量格式的混合精度训练进行训练和评估。

SI 7.2 Metrics

SI 7.2 指标

Given the scale of our evaluation, it is crucial to use an appropriate selection of performance metrics. To this end, we selected the metrics according to established and recent literature. For classification tasks, there is a growing consensus that the Matthews correlation coefficient (MCC) is the preferred metric over alternatives such as accuracy, the F-score, and the area under the receiver operating characteristic curve (AUROC or ROC-AUC) [71–74]. The AUROC in particular has been shown to be problematic [75–77]. The MCC is an informative measure of a classifier's performance as it summarises the four basic rates of a confusion matrix: sensitivity, specificity, precision, and negative predictive value [75]. Similarly, the $\mathrm{R^{2}}$ has been proven to be more informative than alternatives such as the mean absolute or squared errors for regression tasks [78]. Thus, our first choices for reporting results are the MCC and $\mathrm{R^{2}}$, depending on the task. For comparison with leaderboard results and for specialised fields such as quantum mechanics, we also report comparable metrics (i.e., accuracy, mean absolute error, or root mean squared error).

考虑到评估的规模,选择合适的性能指标至关重要。为此,我们根据既有文献和最新研究选取了相应指标。在分类任务中,越来越多的共识认为马修斯相关系数 (MCC) 优于准确率、F值和受试者工作特征曲线下面积 (AUROC或ROC-AUC) 等替代指标 [71-74]。研究证明AUROC尤其存在缺陷 [75-77]。MCC能全面反映分类器性能,因为它综合了混淆矩阵的四个基本比率:灵敏度、特异度、精确度和阴性预测值 [75]。同样,对于回归任务,$\mathrm{R^{2}}$ 也被证明比平均绝对误差或均方误差等替代指标更具信息量 [78]。因此,我们优先根据任务类型选用MCC和$\mathrm{{R^{2}}}$报告结果。为便于与排行榜结果对比及适应量子力学等专业领域,我们也会报告可比指标(如准确率、平均绝对误差或均方根误差)。
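For reference, the binary MCC follows directly from the four confusion-matrix counts; a minimal sketch:

```python
import math

def mcc(y_true, y_pred):
    """Matthews correlation coefficient for binary labels (0/1),
    computed from the four confusion-matrix counts; returns 0.0 in the
    degenerate case where the denominator vanishes."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0
```

Unlike accuracy, the MCC rewards balanced performance across all four rates, which is why it is preferred for the imbalanced tasks evaluated here.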

SI 7.3 Baselines

SI 7.3 基线方法

We include classic message passing baselines in the form of GCN, GAT, and GIN due to their recent resurgence against sophisticated graph transformers and their widespread use. We also include the improved GATv2 [24] to complement GAT, and PNA for being neglected in other works despite its remarkable empirical performance. We complete the message passing baselines by including DropGNN [79], a family of provably expressive GNNs that can solve tasks beyond 1-WL in an efficient manner by randomly dropping nodes. As in the original paper, we use GIN as the main underlying mechanism and label this technique DropGIN. Regarding transformer baselines, we select Graphormer, TokenGT, and GraphGPS, not only due to their widespread use, but also due to generally outperforming previous generation graph transformers such as SAN. This selection is balanced, in the sense that Graphormer and TokenGT are part of a class of algorithms that focuses on representing the graph structure through encodings and token identifiers, while GraphGPS relies on GNNs and is thus a hybrid approach.

我们选取了GCN、GAT和GIN作为经典消息传递基线方法,因其近期在对抗复杂图Transformer模型时的复兴态势及广泛适用性。同时补充了改进版GATv2 [24]以完善GAT基线,并纳入PNA方法——尽管其卓越的实证性能常被其他研究忽视。通过引入DropGNN [79]系列(一组可证明具有强表达力的图神经网络,能通过随机丢弃节点高效解决超越1-WL难度的任务),我们完成了消息传递基线的构建。如原论文所述,我们采用GIN作为基础机制并将该技术标记为DropGIN。在Transformer基线方面,我们选择了Graphormer、TokenGT和GraphGPS,不仅因其广泛使用,更因其普遍优于SAN等前代图Transformer模型。这一选择具有平衡性:Graphormer与TokenGT属于通过编码和Token标识符表征图结构的算法类别,而GraphGPS依赖图神经网络,属于混合方法。

SI 8 Infected graph generation

SI 8 感染图生成

The infected graphs are generated using the Infection dataset from PyTorch Geometric. We generated two Erdős–Rényi (ER) graphs with different sizes:

受感染的图是使用 PyTorch Geometric 中的 Infection 数据集生成的。我们生成了两个不同大小的 Erdős-Rényi (ER) 图:

These settings ensure a relatively balanced classification task for both graph sizes.

这些设置确保了对两种图规模都相对平衡的分类任务。
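Sampling an Erdős–Rényi graph amounts to including each possible edge independently with probability $p$. The sketch below illustrates this; the values of $n$ and $p$ used for the two infected graphs are not reproduced here, so the arguments are placeholders:

```python
import random

def erdos_renyi_edges(n, p, seed=0):
    """Sample an undirected Erdős–Rényi G(n, p) edge list: each of the
    n*(n-1)/2 possible edges is included independently with probability p.
    n, p, and seed are illustrative placeholders, not the paper's settings."""
    rng = random.Random(seed)
    return [(i, j) for i in range(n) for j in range(i + 1, n)
            if rng.random() < p]
```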

SI 9 Adaptations for Open Catalyst Project

SI 9 针对开放催化剂项目 (Open Catalyst Project) 的适配

Extending a given model to work with 3D data is not trivial, as demonstrated by the follow-up paper dedicated to extending and benchmarking Graphormer on 3D molecular problems [62]. As described in that paper, for the Open Catalyst Project (OCP) data certain pre-processing steps are taken to ensure satisfactory performance. Concretely, a set of Gaussian basis functions is used to encode atomic distances, which are not used in their raw form. The idea of encoding raw quantities to achieve expressive and orthogonal representations has also been studied by Gasteiger et al. for DimeNet [5], taking things further and using Bessel functions. This idea is prevalent in the literature and shows some of the complications of working with 3D coordinates.

将给定模型扩展到处理3D数据并非易事,后续论文专门研究了如何扩展Graphormer并在3D分子问题上进行基准测试[62],就证明了这一点。如该论文所述,对于开放催化剂项目(OCP)数据,需要采取某些预处理步骤以确保性能达标。具体而言,使用一组高斯基函数对原子距离进行编码,而非直接使用原始距离值。Gasteiger等人在DimeNet[5]中也研究了通过编码原始量来获得表达性强且正交表示的想法,并进一步使用了贝塞尔函数。这一思路在文献中很常见,也展示了处理3D坐标时的一些复杂性。

In addition, OCP exhibits several unique characteristics that must be accounted for to extract the most performance out of any given model. One such property is given by periodic boundary conditions (common for crystal systems), requiring a dedicated pre-processing step. Another characteristic is the presence of 3 types of atoms: sub-surface slab atoms, surface slab atoms, and adsorbate atoms, which must be distinguished by the model. Finally, the task chosen in the benchmarking Graphormer paper is not only relaxed energy prediction, but also relaxed structure prediction, entailing the prediction of new coordinates for all atoms. This again leverages additional data in the dataset and can be considered a task with synergistic positive effects for relaxed energy prediction.

此外,OCP 还具有若干独特特性,必须加以考虑才能充分发挥任何给定模型的性能。其中一个特性源于周期性边界条件(晶体系统的常见设定),这需要专门的预处理步骤。另一特点是存在 3 种原子类型:亚表面板原子、表面板原子和吸附原子,模型必须对它们加以区分。最后,Graphormer 基准测试论文中选择的任务不仅包含弛豫能量预测,还包含弛豫结构预测——即预测所有原子的新坐标。这再次利用了数据集中的额外数据,可视为对弛豫能量预测具有协同促进效应的任务。

On top of this, the 3D implementation of Graphormer is unavailable on Hugging Face (the version used throughout the paper). We chose to perform experiments using a 10K training set that is provided as part of OCP, and the provided validation set of size 25K. The entire OCP dataset consists of around 500K dense catalytic structures, and training both Graphormer and ESA/NSA on this task would entail a computational effort larger than for any other evaluated dataset. To complicate things further, each 'graph' is dense, leading to a significant GPU memory burden, such that only small batch sizes are possible even for high-end GPUs.

除此之外,Graphormer 的 3D 实现版本在 hugging face 上不可用(论文中使用的版本)。我们选择使用 OCP 提供的 10K 训练集和相同规模的 25K 验证集进行实验。整个 OCP 数据集包含约 500K 个密集催化结构,在此任务上训练 Graphormer 和 ESA/NSA 所需的计算量远超其他评估数据集。更复杂的是,每个"图"都是密集的,导致 GPU 内存负担显著增加,即使高端 GPU 也只能支持小批量处理。

Moreover, with higher batch sizes we have hit a limit of PyTorch: the nonzero() tensor method that is used as part of the mask computation is not defined for tensors with more elements than the 32-bit integer limit. While this operation can be chunked in smaller tensors, this induces a significant slowdown while training. Overall, this software limitation highlights the fact that current libraries are not optimised for masked attention. We present other software limitations in SI 5, as well as possible solutions, and we believe that ESA can be significantly optimised with careful software (and even hardware) design.

此外,在更大的批次规模下,我们遇到了PyTorch的极限:用于掩码计算的nonzero()张量方法无法处理元素数量超过32位整数限制的张量。虽然可以通过分块处理较小张量来执行此操作,但这会导致训练速度显著下降。总体而言,这一软件限制凸显出现有库未对掩码注意力进行优化的事实。我们在SI 5中列出了其他软件限制及潜在解决方案,并相信通过细致的软件(甚至硬件)设计能大幅优化ESA。
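The chunking workaround mentioned above can be illustrated as follows. This is a pure-Python sketch of the idea (the actual code operates on PyTorch tensors and torch.nonzero()): scan the mask a block of rows at a time so that no single call touches more elements than an indexing limit allows:

```python
def chunked_nonzero(mask, chunk_rows=1024):
    """Collect (row, col) indices of truthy mask entries, scanning a
    block of rows at a time. Sketch of the torch.nonzero() workaround
    described above; chunking trades speed for staying under element
    limits of the underlying backend."""
    indices = []
    for start in range(0, len(mask), chunk_rows):
        chunk = mask[start:start + chunk_rows]
        for r, row in enumerate(chunk, start=start):
            indices.extend((r, c) for c, v in enumerate(row) if v)
    return indices
```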

For the 10K train + 25K validation task, we have adapted our method to use the 3D pre-processing described above, and modified Graphormer to perform only relaxed energy prediction (i.e., without relaxed structure prediction). We used a batch size of 16 for both models, and roughly equivalent settings between NSA and Graphormer where possible, including 4 layers, 16 attention heads, an embedding/hidden size of 256, and the same learning rate (1e-4). Both methods used mixed-precision training. Here, we used the Graphormer 3D implementation from the official repository. We also note that structural information (in the form of edge index tensors in PyTorch Geometric) is provided in the OCP dataset. It is derived in a similar way to PyTorch Geometric's radius_graph() function, which connects points based on a distance cutoff. We can use this information for ESA/NSA.

对于10K训练$^+$25K验证任务,我们调整了方法以采用上述3D预处理流程,并修改Graphormer仅执行弛豫能量预测(即不包含弛豫结构预测)。两种模型均采用16的批次大小,并在NSA与Graphormer之间尽可能保持等效配置,包括4层网络结构、16个注意力头、256维嵌入/隐藏层尺寸以及相同学习率(1e-4)。两种方法均采用混合精度训练。此处我们使用官方代码库中的Graphormer 3D实现。需特别说明的是,OCP数据集中提供了结构信息(以PyTorch Geometric的边索引张量形式存在),其生成方式与PyTorch Geometric的radius graph()函数类似,通过距离截断值连接各点。该信息可用于ESA/NSA计算。
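The distance-cutoff construction attributed to radius_graph() above can be sketched directly. This is a simplified, quadratic-time illustration (no periodic boundary conditions, no neighbour limit), not PyTorch Geometric's implementation:

```python
import math

def radius_graph(pos, cutoff):
    """Connect every ordered pair of points within `cutoff` distance,
    returning a directed edge list without self-loops. A simplified
    sketch of the behaviour described above."""
    edges = []
    for i, p in enumerate(pos):
        for j, q in enumerate(pos):
            if i != j and math.dist(p, q) <= cutoff:
                edges.append((i, j))
    return edges
```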

SI 10 Experimental platform

SI 10 实验平台

Representative versions of the software used in this paper include Python 3.11, PyTorch 2.5.1 with CUDA 12.1, PyTorch Geometric 2.5.3 and 2.6.0, PyTorch Lightning 2.4.0, Hugging Face transformers 4.35.2, and xFormers 0.0.27. It is worth noting that attention masking and efficient implementations of attention are early features that are advancing quickly. This means that their behaviour might change unexpectedly and there might be bugs. For example, PyTorch 2.1.1 recently fixed a bug concerning non-contiguous custom attention masks in the scaled_dot_product_attention function.

本文使用的代表性软件版本包括Python语言 3.11、支持CUDA 12.1的PyTorch 2.5.1、PyTorch Geometric 2.5.3与2.6.0、PyTorch Lightning 2.4.0、hugging face transformers 4.35.2以及xFormers 0.0.27。需注意的是,注意力掩码(attention masking)和注意力机制的高效实现属于快速迭代的前沿功能,其行为可能出现意外变化并存在潜在缺陷。例如PyTorch 2.1.1近期修复了缩放点积注意力函数(scaled dot product attention)中非连续自定义注意力掩码的一个错误。

In terms of hardware, the GPUs used include an NVIDIA RTX 3090 with 24GB VRAM, NVIDIA V100 with 16GB or 32GB of VRAM, and NVIDIA A100 with 40GB and 80GB of VRAM. Recent, efficient implementations of attention are optimised for the newest GPU architectures, generally starting from Ampere (RTX 3090 and A100).

在硬件方面,使用的GPU包括配备24GB显存的NVIDIA RTX 3090、配备16GB或32GB显存的NVIDIA V100,以及配备40GB和80GB显存的NVIDIA A100。近期高效的注意力机制实现针对最新GPU架构进行了优化,通常从安培架构(RTX 3090和A100)起步。

SI 11 Dataset statistics

SI 11 数据集统计

We present a summary of all the used datasets, along with their size and the maximum number of nodes and edges encountered in a graph in the dataset (Supplementary Table 11). The last two are important as they determine the shape of the mask and of the inputs for the attention blocks. Technically, we require that the maximum number of nodes/edges is determined per batch and the tensors to be padded accordingly. This per-batch maximum is lower than the dataset maximum for most batches. However, certain operations such as layer normalisation, if performed over the last two dimensions, require a constant value. To enable this, we use the dataset maximum.

我们汇总了所有使用的数据集及其规模,以及数据集中单个图(graph)的最大节点数和边数(补充表11)。后两项指标至关重要,因为它们决定了注意力模块输入数据的掩码形状和输入张量维度。从技术实现角度,我们要求每批次(batch)动态确定最大节点数/边数,并对张量进行相应填充(padding)。对于大多数批次而言,这种逐批次最大值会低于数据集整体最大值。但某些操作(如对最后两个维度执行的层归一化(layer normalisation))需要固定维度值,为此我们采用数据集整体最大值。
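The padding step described above can be sketched as follows. This is an illustration in plain Python (the real code pads PyTorch tensors and builds the corresponding key-padding masks); the function name is our own:

```python
def pad_node_features(batch, max_nodes, pad_value=0.0):
    """Pad each graph's node-feature matrix (a list of feature vectors)
    to `max_nodes` rows, and return a boolean mask marking real nodes.
    Using the dataset-wide max_nodes keeps the last dimensions constant,
    as required by operations such as layer normalisation."""
    feat_dim = len(batch[0][0])
    padded, masks = [], []
    for graph in batch:
        n = len(graph)
        pad_rows = [[pad_value] * feat_dim for _ in range(max_nodes - n)]
        padded.append(graph + pad_rows)
        masks.append([True] * n + [False] * (max_nodes - n))
    return padded, masks
```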

Supplementary Table 11: Summary of used datasets, their size, and the maximum number of nodes ($\mathbf{N}$) and edges ($\mathbf{E}$) seen in a graph in the dataset.

补充表 11: 所用数据集及其规模、图中最大节点数($\mathbf{N}$)和边数($\mathbf{E}$)的汇总。

Dataset Size N E
PEPT-STRUCT 15535 444 928
PEPT-FUNC 15535 444 928
PPI 24 3480 106754
CORA 1 2708 10556
CITESEER 1 3327 9104
ROMAN EMPIRE 1 22662 65854
AMAZON RATINGS 1 24492 186100
MINESWEEPER 1 10000 78804
TOLOKERS 1 11758 1038000
SQUIRREL 1 2223 93996
CHAMELEON 1 890 17708
INFECTED 15000 1 15000 20048
INFECTED 30000 1 30000 45258
FREESOLV 642 44 92
LIPO 4200 216 438
ESOL 1128 119 252
BBBP 2039 269 562
BACE 1513 184 376
HIV 41127 438 882
ZINC 249456 38 90
PCQM4MV2 3452151 51 118
NCI1 4110 111 238
NCI109 4127 111 238
MNIST 70000 75 600
CIFAR10 60000 150 1200
ENZYMES 600 126 298
PROTEINS 1113 620 2098
DD 1178 5748 28534
SYNTHETIC 300 100 392
SYNTHETIC NEW 300 100 396
SYNTHIE 400 100 424
IMDB-BINARY 1000 136 2498
IMDB-MULTI 1500 89 2934
TWITCH EGOS 127094 52 1572
REDDIT THR. 203088 97 370
QM9 133885 29 56
DOCKSTRING 260060 164 342
MALNET TINY 5000 4994 20096
OPEN CATALYST PROJECT 35000 334 11094