SHARPNESS-AWARE MINIMIZATION FOR EFFICIENTLY IMPROVING GENERALIZATION
ABSTRACT
In today’s heavily overparameterized models, the value of the training loss provides few guarantees on model generalization ability. Indeed, optimizing only the training loss value, as is commonly done, can easily lead to suboptimal model quality. Motivated by prior work connecting the geometry of the loss landscape and generalization, we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets (e.g., CIFAR-$\{10, 100\}$, ImageNet, finetuning tasks) and models, yielding novel state-of-the-art performance for several. Additionally, we find that SAM natively provides robustness to label noise on par with that provided by state-of-the-art procedures that specifically target learning with noisy labels. We open source our code at https://github.com/google-research/sam.
1 INTRODUCTION
Modern machine learning’s success in achieving ever better performance on a wide range of tasks has relied in significant part on ever heavier overparameterization, in conjunction with developing ever more effective training algorithms that are able to find parameters that generalize well. Indeed, many modern neural networks can easily memorize the training data and have the capacity to readily overfit (Zhang et al., 2016). Such heavy overparameterization is currently required to achieve state-of-the-art results in a variety of domains (Tan & Le, 2019; Kolesnikov et al., 2020; Huang et al., 2018). In turn, it is essential that such models be trained using procedures that ensure that the parameters actually selected do in fact generalize beyond the training set.
Unfortunately, simply minimizing commonly used loss functions (e.g., cross-entropy) on the training set is typically not sufficient to achieve satisfactory generalization. The training loss landscapes of today’s models are commonly complex and non-convex, with a multiplicity of local and global minima, and with different global minima yielding models with different generalization abilities (Shirish Keskar et al., 2016). As a result, the choice of optimizer (and associated optimizer settings) from among the many available (e.g., stochastic gradient descent (Nesterov, 1983), Adam (Kingma & Ba, 2014), RMSProp (Hinton et al.), and others (Duchi et al., 2011; Dozat, 2016; Martens & Grosse, 2015)) has become an important design choice, though understanding of its relationship to model generalization remains nascent (Shirish Keskar et al., 2016; Wilson et al., 2017; Shirish Keskar & Socher, 2017; Agarwal et al., 2020; Jacot et al., 2018). Relatedly, a panoply of methods for modifying the training process have been proposed, including dropout (Srivastava et al., 2014), batch normalization (Ioffe & Szegedy, 2015), stochastic depth (Huang et al., 2016), data augmentation (Cubuk et al., 2018), and mixed sample augmentations (Zhang et al., 2017; Harris et al., 2020).
Figure 1: (left) Error rate reduction obtained by switching to SAM. Each point is a different dataset / model / data augmentation. (middle) A sharp minimum to which a ResNet trained with SGD converged. (right) A wide minimum to which the same ResNet trained with SAM converged.
The connection between the geometry of the loss landscape—in particular, the flatness of minima— and generalization has been studied extensively from both theoretical and empirical perspectives (Shirish Keskar et al., 2016; Dziugaite & Roy, 2017; Jiang et al., 2019). While this connection has held the promise of enabling new approaches to model training that yield better generalization, practical efficient algorithms that specifically seek out flatter minima and furthermore effectively improve generalization on a range of state-of-the-art models have thus far been elusive (e.g., see (Chaudhari et al., 2016; Izmailov et al., 2018); we include a more detailed discussion of prior work in Section 5).
We present here a new efficient, scalable, and effective approach to improving model generalization ability that directly leverages the geometry of the loss landscape and its connection to generalization, and is powerfully complementary to existing techniques. In particular, we make the following contributions:
Section 2 below derives the SAM procedure and presents the resulting algorithm in full detail. Section 3 evaluates SAM empirically, and Section 4 further analyzes the connection between loss sharpness and generalization through the lens of SAM. Finally, we conclude with an overview of related work and a discussion of conclusions and future work in Sections 5 and 6, respectively.
2 SHARPNESS-AWARE MINIMIZATION (SAM)
Motivated by the connection between sharpness of the loss landscape and generalization, we propose a different approach: rather than seeking out parameter values $\boldsymbol{w}$ that simply have low training loss value $L_{\mathcal{S}}(\boldsymbol{w})$, we seek out parameter values whose entire neighborhoods have uniformly low training loss value (equivalently, neighborhoods having both low loss and low curvature). The following theorem illustrates the motivation for this approach by bounding generalization ability in terms of neighborhood-wise training loss (full theorem statement and proof in Appendix A):
Theorem 1 (stated informally). For any $\rho>0$, with high probability over training set $\mathcal{S}$ generated from distribution $\mathcal{D}$,
$$
L_{\mathcal{D}}(\boldsymbol{w}) \leq \max_{\|\boldsymbol{\epsilon}\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) + h\left(\|\boldsymbol{w}\|_{2}^{2}/\rho^{2}\right),
$$
where $h:\mathbb{R}_{+}\to\mathbb{R}_{+}$ is a strictly increasing function (under some technical conditions on $L_{\mathcal{D}}(\boldsymbol{w})$).
To make explicit our sharpness term, we can rewrite the right hand side of the inequality above as
$$
\left[\max_{\|\boldsymbol{\epsilon}\|_{2} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) - L_{\mathcal{S}}(\boldsymbol{w})\right] + L_{\mathcal{S}}(\boldsymbol{w}) + h\left(\|\boldsymbol{w}\|_{2}^{2}/\rho^{2}\right).
$$
The term in square brackets captures the sharpness of $L_{\mathcal{S}}$ at $\boldsymbol{w}$ by measuring how quickly the training loss can be increased by moving from $\boldsymbol{w}$ to a nearby parameter value; this sharpness term is then summed with the training loss value itself and a regularizer on the magnitude of $\boldsymbol{w}$. Given that the specific function $h$ is heavily influenced by the details of the proof, we substitute it with $\lambda\|\boldsymbol{w}\|_{2}^{2}$ for a hyperparameter $\lambda$, yielding a standard L2 regularization term. Thus, inspired by the terms from the bound, we propose to select parameter values by solving the following Sharpness-Aware Minimization (SAM) problem:
$$
\min_{\boldsymbol{w}} \; L_{\mathcal{S}}^{SAM}(\boldsymbol{w}) + \lambda \|\boldsymbol{w}\|_{2}^{2} \quad \text{where} \quad L_{\mathcal{S}}^{SAM}(\boldsymbol{w}) \triangleq \max_{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}), \tag{1}
$$
where $\rho\geq0$ is a hyperparameter and $p\in[1,\infty]$ (we have generalized slightly from an L2-norm to a $p$-norm in the maximization over $\boldsymbol{\epsilon}$, though we show empirically in appendix C.5 that $p=2$ is typically optimal). Figure 1 shows the loss landscape for a model that converged to minima found by minimizing either $L_{\mathcal{S}}(\boldsymbol{w})$ or $L_{\mathcal{S}}^{SAM}(\boldsymbol{w})$, illustrating that the sharpness-aware loss prevents the model from converging to a sharp minimum.
In order to minimize $L_{\mathcal{S}}^{SAM}(\boldsymbol{w})$, we derive an efficient and effective approximation to $\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{SAM}(\boldsymbol{w})$ by differentiating through the inner maximization, which in turn enables us to apply stochastic gradient descent directly to the SAM objective. Proceeding down this path, we first approximate the inner maximization problem via a first-order Taylor expansion of $L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon})$ w.r.t. $\boldsymbol{\epsilon}$ around $\mathbf{0}$, obtaining
$$
\boldsymbol{\epsilon}^{*}(\boldsymbol{w}) \triangleq \operatorname*{argmax}_{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}+\boldsymbol{\epsilon}) \approx \operatorname*{argmax}_{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} L_{\mathcal{S}}(\boldsymbol{w}) + \boldsymbol{\epsilon}^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}) = \operatorname*{argmax}_{\|\boldsymbol{\epsilon}\|_{p} \leq \rho} \boldsymbol{\epsilon}^{T}\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}).
$$
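To make the linearization above concrete, here is a small numerical sanity check of our own (not from the paper): on a toy loss, for small $\rho$ the loss increase obtained by stepping along the normalized gradient direction closely matches the first-order prediction $\rho\,\|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\|_{2}$ (the $p=2$ case). The loss, weights, and $\rho$ below are illustrative stand-ins.

```python
# Toy sanity check of the first-order Taylor approximation used above (p = 2).
import jax
import jax.numpy as jnp

def toy_loss(w):
    return jnp.sum(jnp.sin(w)) + 0.5 * jnp.sum(w ** 2)  # stand-in loss

w = jnp.array([0.3, -1.2, 2.0])
rho = 1e-3

g = jax.grad(toy_loss)(w)
eps = rho * g / jnp.linalg.norm(g)        # maximizer of the linearized inner problem (p = 2)

print(toy_loss(w + eps) - toy_loss(w))    # actual loss increase
print(rho * jnp.linalg.norm(g))           # first-order prediction: rho * ||grad||_2
```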
In turn, the value $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ that solves this approximation is given by the solution to a classical dual norm problem (here $|\cdot|^{q-1}$ denotes elementwise absolute value and power):
$$
\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}) = \rho \, \operatorname{sign}\left(\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right) \left|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right|^{q-1} \Big/ \left(\left\|\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\right\|_{q}^{q}\right)^{1/p}, \tag{2}
$$
where $1/p+1/q=1$. Substituting back into equation (1) and differentiating, we then have
$$
\begin{aligned}
\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{SAM}(\boldsymbol{w}) &\approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})) = \frac{d(\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w}))}{d\boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})} \\
&= \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})} + \frac{d\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}{d\boldsymbol{w}} \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}.
\end{aligned}
$$
This approximation to $\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{SAM}(\boldsymbol{w})$ can be straightforwardly computed via automatic differentiation, as implemented in common libraries such as JAX, TensorFlow, and PyTorch. Though this computation implicitly depends on the Hessian of $L_{\mathcal{S}}(\boldsymbol{w})$ because $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ is itself a function of $\nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})$, the Hessian enters only via Hessian-vector products, which can be computed tractably without materializing the Hessian matrix. Nonetheless, to further accelerate the computation, we drop the second-order terms, obtaining our final gradient approximation:
$$
\nabla_{\boldsymbol{w}} L_{\mathcal{S}}^{SAM}(\boldsymbol{w}) \approx \nabla_{\boldsymbol{w}} L_{\mathcal{S}}(\boldsymbol{w})\big|_{\boldsymbol{w}+\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})}. \tag{3}
$$
As shown by the results in Section 3, this approximation (without the second-order terms) yields an effective algorithm. In Appendix C.4, we additionally investigate the effect of instead including the second-order terms; in that initial experiment, including them surprisingly degrades performance, and further investigating these terms’ effect should be a priority in future work.
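For concreteness, the following is a minimal JAX sketch of our own (a hedged illustration, not the paper's released code) contrasting the exact gradient of $L_{\mathcal{S}}^{SAM}$, in which autodiff differentiates through $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ and thereby picks up the second-order (Hessian-vector-product) term, with the first-order approximation of equation (3), obtained by stopping the gradient through $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$. The loss function, weights, and data are toy stand-ins.

```python
import jax
import jax.numpy as jnp

def loss(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)   # toy stand-in loss

def eps_hat(w, batch, rho):
    g = jax.grad(loss)(w, batch)
    return rho * g / (jnp.linalg.norm(g) + 1e-12)   # p = 2 case of equation (2)

def sam_loss_exact(w, batch, rho=0.05):
    # Differentiating this also differentiates through eps_hat(w): the
    # d(eps_hat)/dw term enters as a Hessian-vector product, handled by
    # autodiff without materializing the Hessian.
    return loss(w + eps_hat(w, batch, rho), batch)

def sam_loss_first_order(w, batch, rho=0.05):
    # Dropping the second-order term amounts to stopping the gradient
    # through eps_hat, recovering the approximation in equation (3).
    return loss(w + jax.lax.stop_gradient(eps_hat(w, batch, rho)), batch)

w = jnp.ones(3)
batch = (jnp.array([[1., 2., 3.], [4., 5., 6.]]), jnp.array([1., 2.]))

exact_grad = jax.grad(sam_loss_exact)(w, batch)
approx_grad = jax.grad(sam_loss_first_order)(w, batch)   # what SAM actually uses
```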
We obtain the final SAM algorithm by applying a standard numerical optimizer such as stochastic gradient descent (SGD) to the SAM objective $L_{\mathcal{S}}^{SAM}(\boldsymbol{w})$, using equation (3) to compute the requisite objective function gradients. Algorithm 1 gives pseudo-code for the full SAM algorithm, using SGD as the base optimizer, and Figure 2 schematically illustrates a single SAM parameter update.
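Since Algorithm 1's pseudo-code is not reproduced here, the following is a minimal JAX sketch of a single SAM update with SGD as the base optimizer, using the $p=2$ default. It is a hedged illustration under our own naming (`sam_sgd_step`, `loss_fn`, a toy linear model), not the released google-research/sam implementation.

```python
import jax
import jax.numpy as jnp

def sam_sgd_step(params, batch, loss_fn, rho=0.05, lr=0.1):
    # First forward/backward pass: gradient of the training loss at w.
    grads = jax.grad(loss_fn)(params, batch)
    leaves = jax.tree_util.tree_leaves(grads)
    grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    # eps_hat(w) = rho * grad / ||grad||_2, the p = 2 (q = 2) case of equation (2);
    # for general p the formula uses sign(grad) |grad|^(q-1) / ||grad||_q^(q/p).
    eps = jax.tree_util.tree_map(lambda g: rho * g / (grad_norm + 1e-12), grads)
    # Second forward/backward pass: gradient at the perturbed point w + eps_hat(w).
    perturbed = jax.tree_util.tree_map(lambda p_, e: p_ + e, params, eps)
    sam_grads = jax.grad(loss_fn)(perturbed, batch)
    # Base-optimizer (SGD) step on the *original* weights, per equation (3).
    return jax.tree_util.tree_map(lambda p_, g: p_ - lr * g, params, sam_grads)

# Toy usage with a linear model:
def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
batch = (jnp.ones((8, 3)), jnp.zeros(8))
params = sam_sgd_step(params, batch, loss_fn)
```

The two `jax.grad` calls correspond to the two backpropagation operations per SAM update discussed in Section 3.1.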
Figure 2: Schematic of the SAM parameter update.
3 EMPIRICAL EVALUATION
In order to assess SAM’s efficacy, we apply it to a range of different tasks, including image classification from scratch (including on CIFAR-10, CIFAR-100, and ImageNet), finetuning pretrained models, and learning with noisy labels. In all cases, we measure the benefit of using SAM by simply replacing the optimization procedure used to train existing models with SAM, and computing the resulting effect on model generalization. As seen below, SAM materially improves generalization performance in the vast majority of these cases.
3.1 IMAGE CLASSIFICATION FROM SCRATCH
We first evaluate SAM’s impact on generalization for today’s state-of-the-art models on CIFAR-10 and CIFAR-100 (without pretraining): WideResNets with Shake-Shake regularization (Zagoruyko & Komodakis, 2016; Gastaldi, 2017) and PyramidNet with ShakeDrop regularization (Han et al., 2016; Yamada et al., 2018). Note that some of these models have already been heavily tuned in prior work and include carefully chosen regularization schemes to prevent overfitting; therefore, significantly improving their generalization is quite non-trivial. We have ensured that our implementations’ generalization performance in the absence of SAM matches or exceeds that reported in prior work (Cubuk et al., 2018; Lim et al., 2019).
All results use basic data augmentations (horizontal flip, padding by four pixels, and random crop). We also evaluate in the setting of more advanced data augmentation methods such as Cutout regularization (Devries & Taylor, 2017) and AutoAugment (Cubuk et al., 2018), which are utilized by prior work to achieve state-of-the-art results.
SAM has a single hyperparameter $\rho$ (the neighborhood size), which we tune via a grid search over $\{0.01, 0.02, 0.05, 0.1, 0.2, 0.5\}$ using $10\%$ of the training set as a validation set. Please see appendix C.1 for the values of all hyperparameters and additional training details. As each SAM weight update requires two backpropagation operations (one to compute $\hat{\boldsymbol{\epsilon}}(\boldsymbol{w})$ and another to compute the final gradient), we allow each non-SAM training run to execute twice as many epochs as each SAM training run, and we report the best score achieved by each non-SAM training run across either the standard epoch count or the doubled epoch count. We run five independent replicas of each experimental condition for which we report results (each with independent weight initialization and data shuffling), reporting the resulting mean error (or accuracy) on the test set and the associated $95\%$ confidence interval. Our implementations utilize JAX (Bradbury et al., 2018), and we train all models on a single host having 8 NVIDIA V100 GPUs. To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update.
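As a hedged illustration of the per-accelerator scheme just described (our own sketch, not the paper's released code), the snippet below uses `jax.pmap`: each device computes $\hat{\boldsymbol{\epsilon}}$ and a SAM gradient from its own shard of the batch, and the per-shard SAM gradients are averaged with `jax.lax.pmean`. Names such as `make_distributed_sam_grad` are our own.

```python
import jax
import jax.numpy as jnp

def make_distributed_sam_grad(loss_fn, rho=0.05):
    def per_device_sam_grad(params, shard):
        # Each accelerator computes eps_hat from its own sub-batch gradient...
        grads = jax.grad(loss_fn)(params, shard)
        norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
        eps = jax.tree_util.tree_map(lambda g: rho * g / (norm + 1e-12), grads)
        perturbed = jax.tree_util.tree_map(lambda p, e: p + e, params, eps)
        # ...then its own SAM gradient at the locally perturbed weights...
        sam_grads = jax.grad(loss_fn)(perturbed, shard)
        # ...and the sub-batch SAM gradients are averaged across accelerators.
        return jax.lax.pmean(sam_grads, axis_name="devices")
    return jax.pmap(per_device_sam_grad, axis_name="devices")

# Toy usage: params are replicated and the batch is sharded along a leading device axis.
def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

n_dev = jax.local_device_count()
params = jnp.zeros((n_dev, 3))            # replicated parameters
x = jnp.ones((n_dev, 4, 3))               # per-device sub-batches
y = jnp.zeros((n_dev, 4))
avg_sam_grads = make_distributed_sam_grad(loss_fn)(params, (x, y))
```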
As seen in Table 1, SAM improves generalization across all settings evaluated for CIFAR-10 and CIFAR-100. For example, SAM enables a simple WideResNet to attain $1.6\%$ test error, versus $2.2\%$ error without SAM. Such gains have previously been attainable only by using more complex model architectures (e.g., PyramidNet) and regularization schemes (e.g., Shake-Shake, ShakeDrop); SAM provides an easily-implemented, model-independent alternative. Furthermore, SAM delivers improvements even when applied atop complex architectures that already use sophisticated regularization: for instance, applying SAM to a PyramidNet with ShakeDrop regularization yields $10.3\%$ error on CIFAR-100, which is, to our knowledge, a new state-of-the-art on this dataset without the use of additional data.
Beyond CIFAR-$\{10, 100\}$, we have also evaluated SAM on the SVHN (Netzer et al., 2011) and Fashion-MNIST (Xiao et al., 2017) datasets. Once again, SAM enables a simple WideResNet to achieve accuracy at or above the state-of-the-art for these datasets: $0.99\%$ error for SVHN, and $3.59\%$ for Fashion-MNIST. Details are available in appendix B.1.
To assess SAM’s performance at larger scale, we apply it to ResNets (He et al., 2015) of different depths (50, 101, 152) trained on ImageNet (Deng et al., 2009). In this setting, following prior work (He et al., 2015; Szegedy et al., 2015), we resize and crop images to 224-pixel resolution, normalize them, and use batch size 4096, initial learning rate 1.0, a cosine learning rate schedule, the SGD optimizer with momentum 0.9, label smoothing of 0.1, and weight decay 0.0001. When applying SAM, we use $\rho=0.05$ (determined via a grid search on ResNet-50 trained for 100 epochs). We train all models on ImageNet for up to 400 epochs using a Google Cloud TPUv3 and report top-1 and top-5 test error rates for each experimental condition (mean and $95\%$ confidence interval across 5 independent runs).
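For reference, the ImageNet settings above can be summarized as the following plain-Python configuration sketch (a convenience restatement of the values in the text; the dict structure and key names are our own, not the paper's released configuration files):

```python
# Summary of the ImageNet training setup described above (values from the text;
# key names are illustrative assumptions).
imagenet_config = {
    "resolution": 224,            # resize-and-crop resolution
    "batch_size": 4096,
    "optimizer": "sgd",
    "momentum": 0.9,
    "base_learning_rate": 1.0,
    "lr_schedule": "cosine",
    "label_smoothing": 0.1,
    "weight_decay": 1e-4,
    "sam_rho": 0.05,              # from a grid search on ResNet-50 at 100 epochs
    "max_epochs": 400,
}
```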
Table 1: Test set error rates for SAM versus the standard non-SAM training procedure (denoted SGD) on CIFAR-10 and CIFAR-100 (AA = AutoAugment).

| Model | Augmentation | CIFAR-10 SAM | CIFAR-10 SGD | CIFAR-100 SAM | CIFAR-100 SGD |
|---|---|---|---|---|---|
| WRN-28-10 (200 epochs) | Basic | 2.7±0.1 | 3.5±0.1 | 16.5±0.2 | 18.8±0.2 |
| WRN-28-10 (200 epochs) | Cutout | 2.3±0.1 | 2.6±0.1 | 14.9±0.2 | 16.9±0.1 |
| WRN-28-10 (200 epochs) | AA | 2.1±<0.1 | 2.3±0.1 | 13.6±0.2 | 15.8±0.2 |
| WRN-28-10 (1800 epochs) | Basic | 2.4±0.1 | 3.5±0.1 | 16.3±0.2 | 19.1±0.1 |
| WRN-28-10 (1800 epochs) | Cutout | 2.1±0.1 | 2.7±0.1 | 14.0±0.1 | 17.4±0.1 |
| WRN-28-10 (1800 epochs) | AA | 1.6±0.1 | 2.2±<0.1 | 12.8±0.2 | 16.1±0.2 |
| Shake-Shake (26 2x96d) | Basic | 2.3±<0.1 | 2.7±0.1 | 15.1±0.1 | 17.0±0.1 |
| Shake-Shake (26 2x96d) | Cutout | 2.0±<0.1 | 2.3±0.1 | 14.2±0.2 | 15.7±0.2 |
| Shake-Shake (26 2x96d) | AA | 1.6±<0.1 | 1.9±0.1 | 12.8±0.1 | 14.1±0.2 |
| PyramidNet | Basic | 2.7±0.1 | 4.0±0.1 | 14.6±0.4 | 19.7±0.3 |
| PyramidNet | Cutout | 1.9±0.1 | 2.5±0.1 | 12.6±0.2 | 16.4±0.1 |
| PyramidNet | AA | 1.6±0.1 | 1.9±0.1 | 11.6±0.1 | 14.6±0.1 |
| PyramidNet+ShakeDrop | Basic | 2.1±0.1 | 2.5±0.1 | 13.3±0.2 | 14.5±0.1 |
| PyramidNet+ShakeDrop | Cutout | 1.6±<0.1 | 1.9±0.1 | 11.3±0.1 | 11.8±0.2 |
| PyramidNet+ShakeDrop | AA | 1.4±<0.1 | 1.6±<0.1 | 10.3±0.1 | 10.6±0.1 |
As seen in Table 2, SAM again consistently improves performance, for example improving the ImageNet top-1 error rate of ResNet-152 from $20.3\%$ to $18.4\%$. Furthermore, note that SAM enables increasing the number of training epochs while continuing to improve accuracy without overfitting. In contrast, the standard training procedure (without SAM) generally significantly overfits as training extends from 200 to 400 epochs.