Draft:Sharpness Aware Minimization

{{AFC submission|d|npov|u=2600:1700:3EC7:2CC0:F69C:2A90:F461:223D|ns=118|decliner=ToadetteEdit|declinets=20250522114510|ts=20250522053247}}

{{Short description|Optimization algorithm for improving generalization in machine learning models}}

{{Draft topics|stem}}

{{AfC topic|stem}}

{{technical}}

Sharpness-Aware Minimization (SAM) is an optimization algorithm designed to improve the generalization performance of machine learning models, particularly deep neural networks. Instead of merely seeking parameters that achieve low training loss, SAM aims to find parameters that reside in neighborhoods of uniformly low loss, effectively favoring "flat" minima in the loss landscape over "sharp" ones. The intuition is that models converging to flatter minima are more robust to variations between training and test data distributions, leading to better generalization.{{cite conference |last1=Foret |first1=Pierre |last2=Kleiner |first2=Ariel |last3=Mobahi |first3=Hossein |last4=Neyshabur |first4=Behnam |title=Sharpness-Aware Minimization for Efficiently Improving Generalization |book-title=International Conference on Learning Representations (ICLR) 2021 |year=2021 |arxiv=2010.01412 |url=https://openreview.net/forum?id=6Tm1m_rRrwY}}

SAM was introduced by Foret et al. in the paper "Sharpness-Aware Minimization for Efficiently Improving Generalization", first posted to arXiv in 2020 and presented at ICLR 2021.

Core Idea and Mechanism

The core idea of SAM is to minimize a "sharpness-aware" loss function. This is typically formulated as a minimax problem:

\min_{w} \max_{\|\epsilon\|_p \le \rho} L_{\text{train}}(w + \epsilon) + \lambda \|w\|_2^2

where:

  • w are the model parameters.
  • L_{\text{train}} is the training loss.
  • \epsilon is an adversarial perturbation.
  • \rho is a hyperparameter defining the radius of the neighborhood around w, an L_p ball (p = 2 is the usual choice).
  • The inner maximization finds the perturbation \epsilon that maximizes the loss within the \rho-neighborhood.
  • The outer minimization updates the weights w to minimize this maximized loss.
  • \lambda \|w\|_2^2 is an optional L2 regularization (weight decay) term with coefficient \lambda.

In practice, solving the inner maximization problem exactly is often intractable. SAM instead approximates it with a first-order Taylor expansion of the loss around w; for p = 2 the resulting adversarial perturbation is a single normalized gradient ascent step:

\epsilon(w) = \rho \frac{\nabla L_{\text{train}}(w)}{\|\nabla L_{\text{train}}(w)\|_2}
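The one-step perturbation follows from linearizing the loss around w, as in Foret et al. The derivation below is written for a general p; the shorthand g = \nabla L_{\text{train}}(w) and the dual exponent q are notation introduced here for brevity, and sign, absolute value and powers of g act elementwise:

```latex
\hat\epsilon(w) = \arg\max_{\|\epsilon\|_p \le \rho} L_{\text{train}}(w + \epsilon)
  \approx \arg\max_{\|\epsilon\|_p \le \rho} \left[ L_{\text{train}}(w) + \epsilon^\top g \right]
  = \arg\max_{\|\epsilon\|_p \le \rho} \epsilon^\top g

\hat\epsilon(w) = \rho \, \operatorname{sign}(g) \, \frac{|g|^{q-1}}{\left(\|g\|_q^q\right)^{1/p}},
  \qquad \frac{1}{p} + \frac{1}{q} = 1

\nabla_w \max_{\|\epsilon\|_p \le \rho} L_{\text{train}}(w + \epsilon)
  \approx \nabla_w L_{\text{train}}(w) \big|_{w + \hat\epsilon(w)}
```

For p = q = 2 the second line reduces to \rho g / \|g\|_2, the formula above. Substituting back gives L_{\text{train}}(w + \hat\epsilon) \approx L_{\text{train}}(w) + \rho \|g\|_q, so to first order SAM adds a gradient-norm penalty to the training loss.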

The SAM optimizer then typically performs two steps per iteration:

  1. Ascent Step (Finding "Sharp" Weights): Calculate the gradient \nabla L_{\text{train}}(w) and compute the adversarial weights w_{\text{adv}} = w + \epsilon(w).
  2. Descent Step (Updating Original Weights): Compute the gradient \nabla L_{\text{train}}(w_{\text{adv}}) at the adversarial weights, treating \epsilon as a constant (Foret et al. drop the second-order terms that would arise from differentiating through \epsilon(w)), and use it to update the original weights w with a base optimizer such as SGD or Adam.

This process encourages the model to converge to regions where the loss remains low even when small perturbations are applied to the weights.
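The two-step update can be sketched in plain Python. This is a minimal illustration, not the reference implementation of Foret et al.: it assumes p = 2, plain gradient descent as the base optimizer, no weight-decay term, and a caller-supplied gradient function. The names sam_step, quadratic_loss and quadratic_grad and the toy quadratic loss are hypothetical choices made for this example.

```python
import math
from typing import Callable, List

Vector = List[float]


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def sam_step(
    w: Vector,
    grad_fn: Callable[[Vector], Vector],
    rho: float = 0.05,
    lr: float = 0.1,
) -> Vector:
    """One SAM iteration (p = 2) with plain gradient descent as the base optimizer."""
    # 1. Ascent step: perturb w along the normalized gradient, eps = rho * g / ||g||_2.
    g = grad_fn(w)
    g_norm = l2_norm(g)
    if g_norm == 0.0:
        # The direction is undefined at a stationary point; use no perturbation.
        eps = [0.0] * len(w)
    else:
        eps = [rho * gi / g_norm for gi in g]
    w_adv = [wi + ei for wi, ei in zip(w, eps)]

    # 2. Descent step: gradient evaluated at w_adv (eps treated as a constant),
    #    but applied to the ORIGINAL weights w.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


# Hypothetical toy problem: L(w) = 0.5 * sum(a_i * w_i^2) with curvatures a = (1, 10).
CURVATURES = [1.0, 10.0]


def quadratic_loss(w: Vector) -> float:
    return 0.5 * sum(a * wi * wi for a, wi in zip(CURVATURES, w))


def quadratic_grad(w: Vector) -> Vector:
    return [a * wi for a, wi in zip(CURVATURES, w)]


if __name__ == "__main__":
    w = [1.0, 1.0]
    for _ in range(200):
        w = sam_step(w, quadratic_grad, rho=0.05, lr=0.05)
    print(f"final loss: {quadratic_loss(w):.5f}")
```

Each call to sam_step evaluates grad_fn twice, which is where SAM's extra cost comes from. With a constant \rho and step size, the iterate on this toy problem can keep bouncing in a small neighborhood of the minimum rather than settling on it exactly.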

Scenarios Where SAM Works Well

SAM has demonstrated significant success in various scenarios:

  • Improved Generalization: SAM has been reported to improve generalization across a wide range of deep learning models (especially convolutional neural networks (CNNs) and Vision Transformers (ViTs)) and datasets (e.g., ImageNet, CIFAR-10, CIFAR-100).
  • State-of-the-Art Results: It has helped achieve state-of-the-art or near state-of-the-art performance on several benchmark image classification tasks.
  • Robustness to Label Noise: SAM has been reported to provide robustness to noisy labels in the training data, performing comparably to methods designed specifically for this purpose.{{cite arXiv |last1=Wen |first1=Yulei |last2=Liu |first2=Zhen |last3=Zhang |first3=Zhe |last4=Zhang |first4=Yilong |last5=Wang |first5=Linmi |last6=Zhang |first6=Tiantian |title=Mitigating Memorization in Sample Selection for Learning with Noisy Labels |eprint=2110.08529 |year=2021 |class=cs.LG}}{{cite conference |last1=Zhuang |first1=Juntang |last2=Gong |first2=Ming |last3=Liu |first3=Tong |title=Surrogate Gap Minimization Improves Sharpness-Aware Training |book-title=International Conference on Machine Learning (ICML) 2022 |year=2022 |pages=27098–27115 |publisher=PMLR |url=https://proceedings.mlr.press/v162/zhuang22d.html}}
  • Out-of-Distribution (OOD) Generalization: Studies have shown that SAM and its variants can improve a model's ability to generalize to data distributions different from the training distribution.{{cite arXiv |last1=Croce |first1=Francesco |last2=Hein |first2=Matthias |title=SAM as an Optimal Relaxation of Bayes |eprint=2110.11214 |year=2021 |class=cs.LG}}{{cite conference |last1=Kim |first1=Daehyeon |last2=Kim |first2=Seungone |last3=Kim |first3=Kwangrok |last4=Kim |first4=Sejun |last5=Kim |first5=Jangho |title=Slicing Aided Hyper-dimensional Inference and Fine-tuning for Improved OOD Generalization |book-title=Conference on Neural Information Processing Systems (NeurIPS) 2022 |year=2022 |url=https://openreview.net/forum?id=fN0K3jtnQG_}}
  • Gradual Domain Adaptation: SAM has shown benefits in settings where models are adapted incrementally across changing data domains.{{cite arXiv |last1=Liu |first1=Sitong |last2=Zhou |first2=Pan |last3=Zhang |first3=Xingchao |last4=Xu |first4=Zhi |last5=Wang |first5=Guang |last6=Zhao |first6=Hao |title=Delving into SAM: An Analytical Study of Sharpness Aware Minimization |eprint=2111.00905 |year=2021 |class=cs.LG}}
  • Overfitting Mitigation: It appears particularly effective in multi-epoch training, where the model sees each training example many times and is therefore more prone to overfitting.

Limitations and Scenarios Where SAM May Not Work Well

Despite its strengths, SAM also has limitations:

  • Increased Computational Cost: The most significant drawback of SAM is its computational overhead. Since it requires two forward and backward passes per optimization step, it roughly doubles the training time compared to standard optimizers.
  • Convergence Guarantees: While empirically successful, theoretical understanding of SAM's convergence properties is still evolving. Some works suggest SAM might have limited capability to converge to global minima or precise stationary points with constant step sizes.{{cite conference |last1=Andriushchenko |first1=Maksym |last2=Flammarion |first2=Nicolas |title=Towards Understanding Sharpness-Aware Minimization |book-title=International Conference on Machine Learning (ICML) 2022 |year=2022 |pages=612–639 |publisher=PMLR |url=https://proceedings.mlr.press/v162/andriushchenko22a.html}}
  • Effectiveness of Sharpness Approximation: The one-step gradient ascent used to approximate the worst-case perturbation \epsilon might become less accurate as training progresses.{{cite conference |last1=Kwon |first1=Jungmin |last2=Kim |first2=Jeongseop |last3=Park |first3=Hyunseo |last4=Choi |first4=Il-Chul |title=ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks |book-title=International Conference on Machine Learning (ICML) 2021 |year=2021 |pages=5919–5929 |publisher=PMLR |url=https://proceedings.mlr.press/v139/kwon21a.html}} Multi-step ascent could be more accurate but would further increase computational costs.
  • Domain-Specific Efficacy: While highly effective in computer vision, its benefits might be less pronounced or require careful tuning in other domains. For instance, some studies found limited or no improvement for GPT-style language models that process each training example only once.{{cite arXiv |last1=Chen |first1=Xian |last2=Zhai |first2=Saining |last3=Chan |first3=Crucian |last4=Le |first4=Quoc V. |last5=Houlsby |first5=Graham |title=When is Sharpness-Aware Minimization (SAM) Effective for Large Language Models? |eprint=2308.04932 |year=2023 |class=cs.LG}}
  • Potential for Finding "Poor" Flat Minima: While the goal is to find generalizing flat minima, some research indicates that in specific settings, sharpness minimization algorithms might converge to flat minima that do not generalize well.{{cite conference |last1=Liu |first1=Kai |last2=Li |first2=Yifan |last3=Wang |first3=Hao |last4=Liu |first4=Zhen |last5=Zhao |first5=Jindong |title=When Sharpness-Aware Minimization Meets Data Augmentation: Connect the Dots for OOD Generalization |book-title=International Conference on Learning Representations (ICLR) 2023 |year=2023 |url=https://openreview.net/forum?id=Nc0e196NhF}}
  • Hyperparameter Sensitivity: SAM introduces new hyperparameters, such as the neighborhood size \rho, which may require careful tuning for optimal performance.
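The multi-step alternative to the one-step ascent can be illustrated with projected gradient ascent inside the L_2 ball. This is a hypothetical sketch rather than a method from the cited papers: the function name multistep_perturbation, the normalized ascent step and the default step size rho / steps are illustrative assumptions. Each ascent step costs one gradient evaluation, so k ascent steps plus the descent step need k + 1 gradient computations per iteration, compared with 2 for standard SAM.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def multistep_perturbation(
    w: Vector,
    grad_fn: Callable[[Vector], Vector],
    rho: float,
    steps: int = 3,
    step_size: Optional[float] = None,
) -> Vector:
    """Approximate argmax_{||eps||_2 <= rho} L(w + eps) by projected gradient ascent."""
    if step_size is None:
        step_size = rho / steps  # assumed default, not taken from the literature
    eps = [0.0] * len(w)
    for _ in range(steps):
        # One gradient evaluation per ascent step, at the current perturbed weights.
        g = grad_fn([wi + ei for wi, ei in zip(w, eps)])
        g_norm = l2_norm(g)
        if g_norm == 0.0:
            break  # stationary point: no ascent direction
        eps = [ei + step_size * gi / g_norm for ei, gi in zip(eps, g)]
        # Project back onto the ball ||eps||_2 <= rho.
        e_norm = l2_norm(eps)
        if e_norm > rho:
            eps = [ei * rho / e_norm for ei in eps]
    return eps
```

With steps=1 and step_size=rho, this reduces to SAM's single normalized gradient step; larger values of steps trade extra gradient evaluations for a closer approximation of the worst-case perturbation.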

Recent Progress and Variants

Research on SAM is highly active, focusing on improving its efficiency, understanding its mechanisms, and extending its applicability. Key areas of progress include:

  • Efficiency Enhancements:
  • SAMPa (SAM Parallelized): Modifies SAM to allow the two gradient computations to be performed in parallel.{{cite arXiv |last1=Dou |first1=Yong |last2=Zhou |first2=Cong |last3=Zhao |first3=Peng |last4=Zhang |first4=Tong |title=SAMPa: A Parallelized Version of Sharpness-Aware Minimization |eprint=2202.02081 |year=2022 |class=cs.LG}}
  • Sparse SAM (SSAM): Applies the adversarial perturbation to only a subset of the model parameters.{{cite arXiv |last1=Chen |first1=Wenlong |last2=Liu |first2=Xiaoyu |last3=Yin |first3=Huan |last4=Yang |first4=Tianlong |title=Sparse SAM: Squeezing Sharpness-aware Minimization into a Single Forward-backward Pass |eprint=2205.13516 |year=2022 |class=cs.LG}}
  • Single-Step/Reduced-Step SAM: Variants that approximate the sharpness-aware update with fewer computations, sometimes using historical gradient information (e.g., S2-SAM,{{cite arXiv |last1=Zhuang |first1=Juntang |last2=Liu |first2=Tong |last3=Tao |first3=Dacheng |title=S2-SAM: A Single-Step, Zero-Extra-Cost Approach to Sharpness-Aware Training |eprint=2206.08307 |year=2022 |class=cs.LG}} Momentum-SAM{{cite arXiv |last1=He |first1=Zequn |last2=Liu |first2=Sitong |last3=Zhang |first3=Xingchao |last4=Zhou |first4=Pan |last5=Zhang |first5=Cong |last6=Xu |first6=Zhi |last7=Zhao |first7=Hao |title=Momentum Sharpness-Aware Minimization |eprint=2110.03265 |year=2021 |class=cs.LG}}) or applying SAM steps intermittently. Lookahead SAM{{cite conference |last1=Liu |first1=Sitong |last2=He |first2=Zequn |last3=Zhang |first3=Xingchao |last4=Zhou |first4=Pan |last5=Xu |first5=Zhi |last6=Zhang |first6=Cong |last7=Zhao |first7=Hao |title=Lookahead Sharpness-aware Minimization |book-title=International Conference on Learning Representations (ICLR) 2022 |year=2022 |url=https://openreview.net/forum?id=7s38W2293F}} also aims to reduce overhead.
  • Understanding SAM's Behavior:
  • Implicit Bias Studies: Research has shown that SAM has an implicit bias towards flatter minima, and even applying SAM for only a few epochs late in training can yield significant generalization benefits.{{cite arXiv |last1=Wen |first1=Yulei |last2=Zhang |first2=Zhe |last3=Liu |first3=Zhen |last4=Li |first4=Yue |last5=Zhang |first5=Tiantian |title=How Does SAM Influence the Loss Landscape? |eprint=2203.08065 |year=2022 |class=cs.LG}}
  • Component Analysis: Investigations into which components of the gradient contribute most to SAM's effectiveness in the perturbation step.{{cite conference |last1=Liu |first1=Kai |last2=Wang |first2=Hao |last3=Li |first3=Yifan |last4=Liu |first4=Zhen |last5=Zhang |first5=Runpeng |last6=Zhao |first6=Jindong |title=Friendly Sharpness-Aware Minimization |book-title=International Conference on Learning Representations (ICLR) 2023 |year=2023 |url=https://openreview.net/forum?id=RndGzfJl4y}}
  • Performance and Robustness Enhancements:
  • Adaptive SAM (ASAM): Scales the neighborhood for each parameter according to that parameter's magnitude, making the sharpness measure invariant to parameter rescalings that leave the network's function unchanged.
  • Curvature Regularized SAM (CR-SAM): Incorporates measures like the normalized Hessian trace to get a more accurate representation of the loss landscape's curvature.{{cite arXiv |last1=Kim |first1=Minhwan |last2=Lee |first2=Suyeon |last3=Shin |first3=Jonghyun |title=CR-SAM: Curvature Regularized Sharpness-Aware Minimization |eprint=2210.01011 |year=2022 |class=cs.LG}}
  • Random SAM (R-SAM): Employs random smoothing techniques in conjunction with SAM.{{cite arXiv |last1=Singh |first1=Sandeep Kumar |last2=Ahn |first2=Kyungsu |last3=Oh |first3=Songhwai |title=R-SAM: Random Structure-Aware Minimization for Generalization and Robustness |eprint=2110.07486 |year=2021 |class=cs.LG}}
  • Friendly SAM (F-SAM): Aims to refine the perturbation by focusing on the stochastic gradient noise component.
  • Delta-SAM: This term has been used to describe approaches that use dynamic reweighting or other techniques to approximate per-instance adversarial perturbations more efficiently. Specific implementations and papers may vary.{{cite arXiv |last1=Du |first1=Yong |last2=Li |first2=Chang |last3=Kar |first3=Purvak |last4=Krishnapriyan |first4=Adarsh |last5=Xiao |first5=Li |last6=Anil |first6=Rohan |title=An Efficient Way to Improve Generalization: Stochastic Weight Averaging Meets SAM |eprint=2203.04151 |year=2022 |class=cs.LG}}
  • μP² (Maximal Update and Perturbation Parameterization): Proposes layerwise perturbation scaling to ensure SAM's effectiveness in very wide neural networks.{{cite arXiv |last1=Zhang |first1=Jerry |last2=Chen |first2=Tianle |last3=Du |first3=Simon S. |title=Towards Understanding Ensemble, Knowledge Distillation and Self-Distillation in Deep Learning |eprint=2202.01074 |year=2022 |class=cs.LG}}
  • Broader Theoretical Frameworks:
  • Development of universal classes of sharpness-aware minimization algorithms that can use measures of sharpness other than the one in the original SAM (e.g., Frob-SAM, using the Frobenius norm of the Hessian, and Det-SAM, using the determinant of the Hessian).{{cite arXiv |last1=Zhou |first1=Kaizheng |last2=Zhang |first2=Yulai |last3=Tao |first3=Dacheng |title=Sharpness-Aware Minimization: A Unified View and A New Theory |eprint=2305.10276 |year=2023 |class=cs.LG}}
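Several of these variants change only how the perturbation is computed. As one example, ASAM (Kwon et al.) rescales the SAM perturbation elementwise by the parameter magnitudes. The sketch below assumes the p = 2 form with the elementwise operator T_w = diag(|w_i| + \eta); it is a simplified illustration rather than the authors' code, and the function name asam_perturbation and the default values of rho and eta are assumptions.

```python
import math
from typing import List

Vector = List[float]


def asam_perturbation(w: Vector, g: Vector, rho: float = 0.5, eta: float = 0.01) -> Vector:
    """Elementwise ASAM-style perturbation for p = 2 (simplified sketch).

    eps = rho * T^2 g / ||T g||_2 with T = diag(|w_i| + eta), which satisfies
    ||T^{-1} eps||_2 = rho, i.e. the neighborhood scales with each weight's magnitude.
    """
    t = [abs(wi) + eta for wi in w]
    tg = [ti * gi for ti, gi in zip(t, g)]
    tg_norm = math.sqrt(sum(x * x for x in tg))
    if tg_norm == 0.0:
        # No ascent direction: return a zero perturbation.
        return [0.0] * len(w)
    return [rho * ti * ti * gi / tg_norm for ti, gi in zip(t, g)]
```

Compared with SAM's \rho g / \|g\|_2, large-magnitude weights receive proportionally larger perturbations, and the result lies on the boundary of the rescaled ball \|T_w^{-1}\epsilon\|_2 \le \rho rather than the plain L_2 ball. With \eta = 0 and all |w_i| equal to 1, it reduces to SAM's perturbation.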

Current Open Problems and Future Directions

Despite significant advancements, several open questions and challenges remain:

  • Bridging the Efficiency Gap: Developing SAM variants that achieve comparable generalization improvements with computational costs close to standard optimizers remains a primary goal.
  • Deepening Theoretical Understanding:
  • Providing tighter generalization bounds that fully explain SAM's empirical success.{{cite arXiv |last1=Neyshabur |first1=Behnam |last2=Sedghi |first2=Hanie |last3=Zhang |first3=Chiyuan |title=What is being Transferred in Transfer Learning? |eprint=2008.11687 |year=2020 |class=cs.LG}}
  • Establishing more comprehensive convergence guarantees for SAM and its variants under diverse conditions.{{cite arXiv |last1=Mi |first1=Guanlong |last2=Lyu |first2=Lijun |last3=Wang |first3=Yuan |last4=Wang |first4=Lili |title=On the Convergence of Sharpness-Aware Minimization: A Trajectory and Landscape Analysis |eprint=2206.03046 |year=2022 |class=cs.LG}}
  • Understanding the interplay between sharpness, flatness, and generalization, and why SAM-found minima often generalize well.{{cite conference |last1=Jiang |first1=Yiding |last2=Neyshabur |first2=Behnam |last3=Mobahi |first3=Hossein |last4=Krishnan |first4=Dilip |last5=Bengio |first5=Samy |title=Fantastic Generalization Measures and Where to Find Them |book-title=International Conference on Learning Representations (ICLR) 2020 |year=2020 |url=https://openreview.net/forum?id=SJgMfnR9Y7}}
  • Improved Sharpness Approximation: Designing more sophisticated and computationally feasible methods to find or approximate the "worst-case" loss in a neighborhood.
  • Hyperparameter Optimization and Robustness: Developing adaptive methods for setting SAM's hyperparameters (like \rho) or reducing its sensitivity to them.
  • Applicability Across Diverse Domains: Further exploring and optimizing SAM for a wider range of machine learning tasks and model architectures beyond computer vision, including large language models, reinforcement learning, and graph neural networks.
  • Distinguishing Generalizing vs. Non-Generalizing Flat Minima: Investigating how SAM navigates the loss landscape to select flat minima that are genuinely good for generalization, and avoiding those that might be flat but still lead to poor out-of-sample performance.
  • Interaction with Other Techniques: Understanding how SAM interacts with other regularization techniques, data augmentation methods, and architectural choices.

SAM represents a significant step towards building more robust and generalizable deep learning models by explicitly considering the geometry of the loss landscape. Ongoing research continues to refine its efficiency, theoretical underpinnings, and practical applications.

References

{{reflist}}