
(Adaptive) SAM Optimizer

Sharpness-Aware Minimization for Efficiently Improving Generalization

~ in Pytorch ~



SAM simultaneously minimizes loss value and loss sharpness. In particular, it seeks parameters that lie in neighborhoods having uniformly low loss. SAM improves model generalization and yields SoTA performance for several datasets. Additionally, it provides robustness to label noise on par with that provided by SoTA procedures that specifically target learning with noisy labels.
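In equation form, SAM optimizes the objective from the original paper,

  min_w  max_{||e||_p <= rho}  L(w + e)  +  lambda * ||w||_2^2,

i.e. instead of the plain training loss L(w), it minimizes the worst-case loss within a rho-neighborhood of the current weights (plus ordinary weight decay).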

This is an unofficial repository for Sharpness-Aware Minimization for Efficiently Improving Generalization and ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks. Implementation-wise, the SAM class is a light wrapper that computes the regularized "sharpness-aware" gradient, which is then used by the underlying optimizer (such as SGD with momentum). This repository also includes a simple WRN for CIFAR-10; as a proof of concept, it beats the performance of SGD with momentum on this dataset.

Loss landscape with and without SAM

ResNet loss landscape at the end of training with and without SAM. Sharpness-aware updates lead to a significantly wider minimum, which then leads to better generalization properties.


Usage

It should be straightforward to use SAM in your training pipeline. Just keep in mind that the training will run twice as slow, because SAM needs two forward-backward passes to estimate the "sharpness-aware" gradient. If you're using gradient clipping, make sure to change only the magnitude of gradients, not their direction.
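For example, norm-based clipping rescales the whole gradient vector and therefore preserves its direction, which makes it the safe choice here (a minimal sketch; max_norm=1.0 is just an illustrative threshold):

  # rescale the gradient if its global norm exceeds max_norm; the direction is preserved
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  # avoid element-wise clipping (torch.nn.utils.clip_grad_value_), which can change the gradient's direction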

from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for input, output in data:

  # first forward-backward pass
  loss = loss_function(output, model(input))  # use this loss for any training statistics
  loss.backward()
  optimizer.first_step(zero_grad=True)
  
  # second forward-backward pass
  loss_function(output, model(input)).backward()  # make sure to do a full forward pass
  optimizer.second_step(zero_grad=True)
...

Alternative usage with a single closure-based step function. This offers an API similar to native PyTorch optimizers like LBFGS (kindly suggested by @rmcavoy):

from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for input, output in data:
  def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

  loss = loss_function(output, model(input))
  loss.backward()
  optimizer.step(closure)
  optimizer.zero_grad()
...

Training tips

  • @hjq133: The suggested usage can potentially cause problems if you use batch normalization. The running statistics are computed in both forward passes, but they should be computed only for the first one. A possible solution is to set BN momentum to zero (kindly suggested by @ahmdtaha) to bypass the running statistics during the second pass. An example usage is on lines 51 and 58 in example/train.py:
for batch in dataset.train:
  inputs, targets = (b.to(device) for b in batch)

  # first forward-backward step
  enable_running_stats(model)  # <- this is the important line
  predictions = model(inputs)
  loss = smooth_crossentropy(predictions, targets)
  loss.mean().backward()
  optimizer.first_step(zero_grad=True)

  # second forward-backward step
  disable_running_stats(model)  # <- this is the important line
  smooth_crossentropy(model(inputs), targets).mean().backward()
  optimizer.second_step(zero_grad=True)
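For reference, the repository provides enable_running_stats and disable_running_stats helpers (in example/utility/bn.py); the sketch below shows one way such helpers can be implemented with the zero-momentum trick, and is an illustration rather than a verbatim copy of the repo's code:

import torch.nn as nn

def disable_running_stats(model):
    # momentum=0 makes BatchNorm keep its current running statistics untouched
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0

def enable_running_stats(model):
    # restore the original BatchNorm momentum before the first forward-backward pass
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum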
  • @evanatyourservice: If you plan to train on multiple GPUs, the paper states that "To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update." This can be achieved by the following code:
for input, output in data:
  # first forward-backward pass
  loss = loss_function(output, model(input))
  with model.no_sync():  # <- this is the important line
    loss.backward()
  optimizer.first_step(zero_grad=True)
  
  # second forward-backward pass
  loss_function(output, model(input)).backward()
  optimizer.second_step(zero_grad=True)
  • @evanatyourservice: Adaptive SAM reportedly performs better than the original SAM. The ASAM paper suggests using a higher rho for the adaptive updates (roughly 10x larger).
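For example, if a plain SAM run uses rho=0.05, an adaptive run could be configured roughly like this (the exact value is task-dependent; see the experiments below):

optimizer = SAM(model.parameters(), base_optimizer, rho=0.5, adaptive=True, lr=0.1, momentum=0.9)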

  • @mlaves: LR scheduling should be applied to the base optimizer; alternatively, use SAM with a single closure-based step call:

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=200)
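The scheduler is then stepped as usual, e.g. once per epoch after the SAM updates (a schematic sketch, with num_epochs standing in for your own epoch count):

for epoch in range(num_epochs):
    for input, output in data:
        ...  # two forward-backward passes with first_step/second_step as shown above
    scheduler.step()  # adjusts the learning rate of optimizer.base_optimizer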
  • @AlbertoSabater: Integration with PyTorch Lightning — you can write the training_step function as:
def training_step(self, batch, batch_idx):
    optimizer = self.optimizers()

    # first forward-backward pass
    loss_1 = self.compute_loss(batch)
    self.manual_backward(loss_1, optimizer)
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    loss_2 = self.compute_loss(batch)
    self.manual_backward(loss_2, optimizer)
    optimizer.second_step(zero_grad=True)

    return loss_1
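Note that this pattern relies on manual optimization in Lightning, so the module has to opt out of automatic optimization, e.g. in its constructor (see also the Lightning-related issues below):

def __init__(self):
    super().__init__()
    self.automatic_optimization = False  # SAM needs two optimizer steps per batch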

Documentation

SAM.__init__

Argument | Description
-------- | -----------
params (iterable) | iterable of parameters to optimize or dicts defining parameter groups
base_optimizer (torch.optim.Optimizer) | underlying optimizer that does the "sharpness-aware" update
rho (float, optional) | size of the neighborhood for computing the max loss (default: 0.05)
adaptive (bool, optional) | set this argument to True if you want to use an experimental implementation of element-wise Adaptive SAM (default: False)
**kwargs | keyword arguments passed to the __init__ method of base_optimizer
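Since any extra keyword arguments are forwarded to the base optimizer, an AdamW-backed SAM can be constructed like this (hyperparameter values are illustrative):

optimizer = SAM(model.parameters(), torch.optim.AdamW, rho=0.05, adaptive=False,
                lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)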

SAM.first_step

Performs the first optimization step that finds the weights with the highest loss in the local rho-neighborhood.

Argument | Description
-------- | -----------
zero_grad (bool, optional) | set to True if you want to automatically zero-out all gradients after this step (default: False)

SAM.second_step

Performs the second optimization step that updates the original weights with the gradient from the (locally) highest point in the loss landscape.

Argument | Description
-------- | -----------
zero_grad (bool, optional) | set to True if you want to automatically zero-out all gradients after this step (default: False)

SAM.step

Performs both optimization steps in a single call. This function is an alternative to explicitly calling SAM.first_step and SAM.second_step.

Argument | Description
-------- | -----------
closure (callable) | the closure should do an additional full forward and backward pass on the optimized model (default: None)

Experiments

I've verified that SAM works on a simple WRN 16-8 model trained on CIFAR-10; you can replicate the experiment by running train.py. The Wide-ResNet is enhanced only by label smoothing and the most basic image augmentations with cutout, so the errors are higher than those in the SAM paper. Theoretically, you could get even lower errors by training longer (1800 epochs instead of 200), because SAM shouldn't be as prone to overfitting. SAM uses rho=0.05, while ASAM uses rho=2.0, as suggested by its authors.

Optimizer | Test error rate
--------- | ---------------
SGD + momentum | 3.20 %
SAM + SGD + momentum | 2.86 %
ASAM + SGD + momentum | 2.55 %
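Concretely, the two SAM rows correspond to instantiations along these lines (a sketch; the remaining hyperparameters follow the defaults in example/train.py):

base_optimizer = torch.optim.SGD
sam_optimizer  = SAM(model.parameters(), base_optimizer, rho=0.05, adaptive=False, lr=0.1, momentum=0.9)
asam_optimizer = SAM(model.parameters(), base_optimizer, rho=2.0, adaptive=True, lr=0.1, momentum=0.9)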

Cite

Please cite the original authors if you use this optimizer in your work:

@inproceedings{foret2021sharpnessaware,
  title={Sharpness-aware Minimization for Efficiently Improving Generalization},
  author={Pierre Foret and Ariel Kleiner and Hossein Mobahi and Behnam Neyshabur},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=6Tm1mposlrM}
}
@inproceedings{pmlr-v139-kwon21b,
  title={ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks},
  author={Kwon, Jungmin and Kim, Jeongseop and Park, Hyunseo and Choi, In Kwon},
  booktitle ={Proceedings of the 38th International Conference on Machine Learning},
  pages={5905--5914},
  year={2021},
  editor={Meila, Marina and Zhang, Tong},
  volume={139},
  series={Proceedings of Machine Learning Research},
  month={18--24 Jul},
  publisher ={PMLR},
  pdf={http://proceedings.mlr.press/v139/kwon21b/kwon21b.pdf},
  url={https://proceedings.mlr.press/v139/kwon21b.html},
  abstract={Recently, learning algorithms motivated from sharpness of loss surface as an effective measure of generalization gap have shown state-of-the-art performances. Nevertheless, sharpness defined in a rigid region with a fixed radius, has a drawback in sensitivity to parameter re-scaling which leaves the loss unaffected, leading to weakening of the connection between sharpness and generalization gap. In this paper, we introduce the concept of adaptive sharpness which is scale-invariant and propose the corresponding generalization bound. We suggest a novel learning method, adaptive sharpness-aware minimization (ASAM), utilizing the proposed generalization bound. Experimental results in various benchmark datasets show that ASAM contributes to significant improvement of model generalization performance.}
}


Issues

Reproducing WRN-28-10 (SAM) for SVHN dataset

I am trying to reproduce the results for WRN-28-10 (SAM) trained on 10-class classification SVHN dataset (Percentage Error 0.99) - https://paperswithcode.com/sota/image-classification-on-svhn

I'm able to train WRN-28-10 using https://github.com/hysts/pytorch_wrn (Modified the script to incorporate SAM into it)

I'm achieving a test accuracy of 93%. How can I replicate the SOTA percentage error of 0.99 for WRN-28-10 (SAM)? Which hyperparameters should I use?

Any help is appreciated!!

Question about the Implementation

Hi!
Thanks for your awesome work!

After comparing your code with Google's official implementation, I think there may be something wrong with your code regarding Batch Normalization's running mean and running var.

Since there are 2 forward passes when using SAM, the running mean and running var are computed twice on the same input data in your code. But in Google's implementation, they save the state (parameters for which no gradient is calculated, including the running mean and running var) after the first forward pass as inner_state, and then return it as the final state of this step. That is to say, they only compute the running mean and running var once.

here is google's code :

  def get_sam_gradient(model: flax.nn.Model, rho: float):
    """Returns the gradient of the SAM loss loss, updated state and logits.
    See https://arxiv.org/abs/2010.01412 for more details.
    Args:
      model: The model that we are training.
      rho: Size of the perturbation.
    """
    # compute gradient on the whole batch
    (_, (inner_state, _)), grad = jax.value_and_grad(
        lambda m: forward_and_loss(m, true_gradient=True), has_aux=True)(model)
    if FLAGS.sync_perturbations:
      if FLAGS.inner_group_size is None:
        grad = jax.lax.pmean(grad, 'batch')
      else:
        grad = jax.lax.pmean(
            grad, 'batch',
            axis_index_groups=local_replica_groups(FLAGS.inner_group_size))
    grad = dual_vector(grad)
    noised_model = jax.tree_multimap(lambda a, b: a + rho * b,
                                     model, grad)
    (_, (_, logits)), grad = jax.value_and_grad(
        forward_and_loss, has_aux=True)(noised_model)
    return (inner_state, logits), grad

So this little difference may affect performance at evaluation time, right?

Independently computing SAM gradient on multiple accelerators

Hello,

Thank you for this great implementation. I have a usage question having to do with multi-GPU training.

There is this line from the paper on page 5:

"To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch
evenly among the accelerators, independently compute the SAM gradient on each accelerator, and
average the resulting sub-batch SAM gradients to obtain the final SAM update."

Do you think doing the first SAM step in your implementation on each accelerator, then averaging the gradients from all the accelerators before the second step would accomplish this?

The code would be:

for examples, labels in loader:
    loss = loss_fn(network(examples), labels)
    loss.backward()
    optimizer.first_step(zero_grad=True)
    loss_fn(network(examples), targets).backward()
    reduce_gradients_from_all_accelerators()
    optimizer.second_step(zero_grad=True)

I believe not reducing the gradients between the accelerators before the first step would accomplish what they describe in the paper, but I'm looking for a second opinion. Any help is much appreciated.

Support for multi-device training

Hi

Thank you for your great PyTorch implementation! I'm wondering if there is a plan to support multi-device training (with data parallelism). As you might know, SAM benefits a lot from computing a different epsilon for each replica (see section 4.1 of the paper on m-sharpness), and the best generalization is often obtained when training on multiple GPUs or TPUs. It would be an awesome addition to this codebase.

Can't use with PyTorch Lightning?

I use SAM for training with PyTorch Lightning. When I use multi-GPU, I get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256]], which is output 73 of BroadcastBackward, is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

How to use with CUDA AMP

My code:

scaler = amp.GradScaler()
...
epoch_losses.update(loss.item(), len(inputs))
scaler.scale(loss).backward()
scaler.update()

How should I apply the code you provided when I do it this way?

Import error on windows machine

On my local machine, the line
from sam import SAM
fails with the error:
ImportError: cannot import name 'Error' from partially initialized module 'sam' (most likely due to a circular import) (d:\...\python38\lib\site-packages\sam\__init__.py)

I'm on Windows 10 using Python 3.8.6 and PyTorch 1.7.0+cu110. I installed sam via pip. Am I totally overlooking something here?

Accelerated training with floating point fp16

Thanks for the work!
I'd like to know if the original code is also applicable to accelerated training, i.e. using automatic mixed precision like fp16. I tried to adopt SAM in my own training code with Apex fp16, but NaN losses happen and the computed grad norm is NaN. When I switch to fp32, it works well. Is it incompatible with fp16? What are the suggestions for making the code work with fp16? Thanks!

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

I encountered this problem, please help me, thanks a lot.

The code is:

    # compute output
    output = model(images)
    # loss = criterion(output, target)
    # SAM
    # first forward-backward pass
    loss = criterion(output, target)  # use this loss for any training statistics
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    criterion(output, target).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)

    # measure accuracy and record loss
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    losses.update(loss.item(), images.size(0))
    top1.update(acc1[0], images.size(0))
    top5.update(acc5[0], images.size(0))

The problem is:
Traceback (most recent call last):
  File "C:/softWare/SAM/main.py", line 557, in <module>
    main()
  File "C:/softWare/SAM/main.py", line 156, in main
    main_worker(args.gpu, ngpus_per_node, args)
  File "C:/softWare/SAM/main.py", line 340, in main_worker
    acc2, loss2 = train(train_loader, model, criterion, optimizer, epoch, args, f)
  File "C:/softWare/SAM/main.py", line 413, in train
    criterion(output, target).backward()  # make sure to do a full forward pass
  File "C:\softWare\Python37\lib\site-packages\torch\tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\softWare\Python37\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Implementation error

Hi,
I am pretty sure that there is an error in the implementation of the _grad_norm method.
If we look at equation 2 in the original paper, the denominator contains the q-norm of the vector grad(w) raised to the power q, and all that raised to the power 1/p. In the case of q=2 and p=2 we get (L2(grad(w)) ^ 2) ^ 0.5.
Your implementation computes the L2 norm of each parameter tensor, stacks those norms, and calculates the L2 norm of the stacked norms, which is not the same as the above formula. You should ravel and concatenate all the parameter tensors and then calculate the L2 norm two times.

Please check my reasoning, but I am almost sure that I am right.
Best Regards,
Jacek

AttributeError: 'NoneType' object has no attribute 'norm'

I am running your implementation in Colab, and everything is working well (though taking a lot of time even with a GPU runtime). However, when I try using SAM in my own setup on the same CIFAR-10, I get the following error:

Training model...


AttributeError Traceback (most recent call last)
in

in train(train_loader, model, optimizer, epoch)
32 loss = smooth_crossentropy(output, target)
33 loss.mean().backward()
---> 34 optimizer.first_step(zero_grad=True)
35
36 smooth_crossentropy(model(input_), target).mean().backward()

~/anaconda3/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_no_grad(*args, **kwargs)
47 def decorate_no_grad(*args, **kwargs):
48 with self:
---> 49 return func(*args, **kwargs)
50 return decorate_no_grad
51

~/Downloads/mila project/udacity_DP_FL/sam/sam.py in first_step(self, zero_grad)
14 @torch.no_grad()
15 def first_step(self, zero_grad=False):
---> 16 grad_norm = self._grad_norm()
17 for group in self.param_groups:
18 scale = group["rho"] / (grad_norm + 1e-12)

~/Downloads/mila project/udacity_DP_FL/sam/sam.py in _grad_norm(self)
44 torch.stack([
45 p.grad.norm(p=2)
---> 46 for group in self.param_groups for p in group["params"]
47 ]),
48 p=2

~/Downloads/mila project/udacity_DP_FL/sam/sam.py in (.0)
44 torch.stack([
45 p.grad.norm(p=2)
---> 46 for group in self.param_groups for p in group["params"]
47 ]),
48 p=2

AttributeError: 'NoneType' object has no attribute 'norm'

Grad norm computation

Hi there! I was looking at your code and I have a doubt regarding the _grad_norm(self) method: from the SAM paper, the grad norm used to compute epsilon_hat(w) is defined by equation (2); supposing p=2 and then q=2 (1/p + 1/q = 1), it would simply be the 2-norm of the gradients. So, shouldn't the _grad_norm(self) method be something like this:

@torch.no_grad()
def _grad_norm(self):
    # put everything on the same device, in case of model parallelism
    shared_device = self.param_groups[0]["params"][0].device
    norm = torch.sqrt(
        sum(
            [
                torch.sum(
                    torch.square(
                        p.grad
                        * (torch.abs(p.grad) if group["adaptive"] else 1.0)
                    )
                ).to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
        ),
    )
    return norm

I've already tested it with SAM, obtaining an accuracy of 97.17%, and I'm now running tests with ASAM.

SAM with Mixed precision training

Hi. Great work, thanks for sharing it!

I'm trying to add your SAM optimizer on top of the mixed precision training provided in PyTorch > 1.7.0.
There's no guideline about this. Here is what I've been trying:

...
scaler.scale(loss).backward()
...
scaler.step(optimizer.first_step)
scaler.update()
optimizer.zero_grad()

loss = loss_fn(image_preds, image_labels)
scaler.scale(loss).backward()
scaler.step(optimizer.second_step)
scaler.update()
optimizer.zero_grad()

Please let me know how to incorporate this into the mixed precision training process.

Moreover, it'd be helpful to clarify which versions of PyTorch are supported.

Correct way of lr scheduling

Hi,

is this the correct way of performing lr scheduling?

base_optimizer = torch.optim.Adam
optimizer = SAM(net.parameters(), base_optimizer, lr=3e-2, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=200)

Pytorch throws a warning if I perform lr scheduling directly on SAM.

About SGD parameters

Hi @davda54
Thank you for your great Pytorch implementation!

I want to know whether the momentum of SGD can be safely removed when using the SAM optimizer. That is, is momentum necessary for the SAM optimizer? Momentum has a negative impact on my task.

base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=0.05, lr=0.1, weight_decay=0.0005, momentum=0)

Thank you again, and I look forward to your reply~

About the training loss

During the training process, should the first loss be larger than the second loss? My situation is the opposite.

AMP support

Hi, thank you very much for providing this great repo!!

I'd like to use your SAM implementation in some code where we use torch's native AMP functionality.

More specifically, we use the torch.cuda.amp.GradScaler class, which invokes the optimizer via scaler.step(optimizer). This is slightly tricky because your implemented optimizer differs from the expected interface since it has two optimization steps.

I'm sure there is an elegant way to still use your implementation in such settings. If you have any advice on this, I'd highly appreciate it!

m-sharpness

Hi, may I know how the m-sharpness setting works?
Is it set up via distributed training, where each GPU receives m examples?
Thanks in advance.

Can SAM be used for GAN training?

Hi,

Thanks for sharing this good repo. I'm wondering whether it is suitable to combine SAM with GAN training, e.g. a WGAN-GP loss. Normally WGAN-GP needs to compute a gradient penalty, and I'm not sure if this conflicts with SAM. Do you have any idea? Thanks!

'AdamW' object is not callable

How do I solve this problem?

def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)

    optimizer_sam = SAM(parameters, optimizer, momentum=0.9)

    return optimizer_sam

Pytorch Lightning Integration

Hey @davda54 ,

This looks awesome. I wondered if you would be interested in making a PR contribution with SAM in Pytorch Lightning.
I am sure the community would love it.

Best,
T.C

Adaptive Sharpness Aware Minimization (ASAM) Implementation Help

Hello! I've been getting very good results using SAM. I stumbled on this paper, ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks, from Samsung Research, which improves upon SAM. ASAM is SAM but scale-invariant, so it minimizes over a relative neighborhood of weights rather than a fixed-radius ball of weights, and further improves generalization. Would you like to help me implement it?

This is what I have so far but I am not sure it is correct

They add an operator, T_w, described on page 8 of their paper, into the epsilon equation in a couple of places. I am using their element-wise T_w. They show their epsilon equation on page 11 (under where it says "for t = 0, 1, 2, · · ·. Especially, if p = 2,"), multiplying the grads by T_w before the norm calculation, and multiplying the epsilon numerator by T_w squared.

This is my thought for the norm calc. I am not sure if I have it right, if I need to do something with torch.diag or not.

        norm = torch.norm(
                    torch.stack([
                        (torch.abs(p) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )

And this is my thought for epsilon, but I again am not sure if I need to do something with torch.diag, or need to flatten and concat things in a certain way or not.

e_w = torch.pow(p, 2) * p.grad * scale.to(p)

Any thoughts? I'm sorry if you are busy, but figured I'd ask! I've only tested on mnist so far, and did get a slightly improved result vs. SAM, so I think this is at least on the right track if not correct.

Question: Why step_lr?

Thank you for the great implementation!

Out of curiosity, why have you added your own step_lr scheduler and are not using the cosine annealing scheduler as suggested in the original paper?

Question about Cutout

Hello,

Thanks for the nice implementation! I have a question about the Cutout augmentation in the code below.

class Cutout:
    def __init__(self, size=16, p=0.5):
        self.size = size
        self.half_size = size // 2
        self.p = p

    def __call__(self, image):
        if torch.rand([1]).item() > self.p: return image

        left = torch.randint(-self.half_size, image.shape[0] - self.half_size, [1]).item()
        top = torch.randint(-self.half_size, image.shape[1] - self.half_size, [1]).item()
        right = min(image.shape[0], left + self.size)
        bottom = min(image.shape[1], top + self.size)

        image[max(0,left):right, max(0,top):bottom, :] = 0
        return image

It seems like you are assuming that the shape of image is [Height, Width, Channel], but the input image is actually [Channel, Height, Width]. I think the Cutout function needs to be modified so that it assumes image has the shape [Channel, Height, Width]. Is my understanding correct, or am I missing something? Thanks!

The value of rho

Hi,
Thank you for your implementation.
I have a question about rho. Should the value of rho be adjusted according to the lr? If rho is much larger than the lr, perhaps it will cause the model to fail to converge. And if I use an lr scheduler to adjust the lr, should rho be adjusted as well?

Experiments on object detection task

Hi, has anyone applied SAM on the COCO object detection dataset? How should we set the hyper-parameter rho? I tested rho=0.05 but got a performance degradation compared with SGD.

Multi GPU

Hi,

I am currently running some experiments with SAM. Previously I used the Jax/Flax code provided by the authors of SAM. However, for some reasons I also have to run experiments with a PyTorch version, and then I found your PyTorch implementation of SAM/ASAM.

First of all, thank you for your work; it saved a lot of work and time for me. Especially listing known issues in the Readme was very helpful.

I realized that a multi-GPU version of SAM is yet to be implemented (and also not planned to be implemented in the near future, according to your comments in a closed issue). So I have the following questions regarding multi-GPU settings.

  1. Are the experiment results you presented in the Readme obtained using only 1 GPU?
  2. In the Readme, there are some remarks about Data Parallel by @evanatyourservice. However, to my knowledge, DDP in PyTorch automatically averages the gradients on .backward(). So rather than reducing all gradients after the second pass, shouldn't we NOT SYNC the gradients during the first pass? If not, it would be great if @evanatyourservice could provide the code for the reduce_all_gradients part.

Once again thank you for the great work!

model does not converge after adding disable_bn and enable_bn

Hi. Following the suggestion posted in the README, I disabled BatchNorm2d as below, and now my model does not converge anymore. Did you literally mean BatchNorm2d or SyncBN? Hopefully you can find out what is wrong with this modification. Many thanks!

    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)
           # first forward-backward step
            enable_bn(model)    # update BN in the wk
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_bn(model)   # do not update BN as we are in the perturbation point w_adv
            smooth_crossentropy(model(inputs), targets).mean().backward()
            optimizer.second_step(zero_grad=True)
            train_total_number += targets.shape[0]

training error rate or test error rate?

Hello, the error rate of CIFAR-100 in your paper is 3.92%. Is this the training error rate, or is it the error rate obtained by applying the trained model to the test set? Looking forward to your reply.

rho for Adaptive Sharpness Aware Minimization (ASAM)

Hi.
This is Jungmin Kwon, one of the authors of Adaptive Sharpness Aware Minimization (ASAM).
We really appreciate your great implementation!
I have performed cifar10 tests with your code and we found that ASAM with rho=2.0 shows the best accuracy among [0.5, 1.0, 2.0, 5.0].
The test error rates obtained from the grid search are as follows:

rho | Test error rate
--- | ---------------
0.5 | 2.75 %
1.0 | 2.69 %
2.0 | 2.55 %
5.0 | 2.90 %

In our implementation without bias (or beta for BatchNorm) normalization (https://github.com/SamsungLabs/ASAM), ASAM with rho=0.5 shows the best accuracy (2.37 % for WRN16-8), so we performed all the cifar10 tests with rho=0.5.
If you don't mind, could you update the table of test error rate with the result of rho=2.0 (2.55 %)?
Thank you.

Gradient Clipping

What is the proper way to perform gradient clipping with the SAM/ASAM optimizer? Should it be performed before both of the optimizer steps or just one of them?

Code example:

  # first forward-backward pass
  loss = loss_function(output, model(input))  # use this loss for any training statistics
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), v)
  optimizer.first_step(zero_grad=True)
  
  # second forward-backward pass
  loss_function(output, model(input)).backward()  # make sure to do a full forward pass
  torch.nn.utils.clip_grad_norm_(model.parameters(), v)
  optimizer.second_step(zero_grad=True)

About gradient accumulation

@davda54 Thank you for providing the PyTorch implementation. The following is about implementing gradient accumulation; I am not sure whether the code is correct, so I hope to get your help~
[screenshot of the proposed gradient accumulation code]

Issue on weight decay

Thanks for the great job. It seems the first step does not consider the effect of weight decay, which is quite different from the TensorFlow version. Does this have a negative effect on performance?

How do we give credit?

We're interested in integrating SAM into HomebrewNLP, and straight-up copying the code without copyright or notice doesn't sit right with me.
How would you prefer us to cite you?
I propose adding a header to sam.py as it's done in Shampoo.

How to resume SAM in pytorch?

Hi!
Thanks for your great work! When I follow your instructions and insert SAM into my code, it works well. But the training loss collapses every time I resume the model, so I suspect that I missed some critical step needed to resume SAM. Do you have any suggestions?

TypeError: 'LightningSAM' object is not iterable. (with Pytorch Lightning.)

Hi, nice to meet you.
SAM is an exciting paper. Very good job, thank you!
I want to use the SAM optimizer with PyTorch Lightning, but I get an error and cannot find any information about it.
Could you kindly help me?

Here is my code:

class my_class(LightningModule):
    def __init__(self):
        self.SAM = True
        if self.SAM:
            self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        if self.SAM:
            optimizer = self.optimizers()
            # first forward-backward pass
            self.manual_backward(loss, optimizer)
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            loss_2 = self.shared_step(batch)
            self.manual_backward(loss_2, optimizer)
            optimizer.second_step(zero_grad=True)
        return loss

    def configure_optimizers(self):
        if self.SAM:
            base_optimizer = torch.optim.SGD
            optimizer = SAM(self.parameters(), base_optimizer, lr=0.01)
        return optimizer

SAM with PyTorch Lightning

Hi, I have implemented SAM to be used along with PyTorch Lightning, but it crashes at the beginning of the training:

This is a summary of the key lines of my LightningModule that involve SAM:

class pl_module(LightningModule):
    def __init__(self, params...):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

    def configure_optimizers(self):
        optim = SAM(self.parameters(), SGD, lr=0.01)
        return optim

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        def closure():
            loss = self.compute_loss(batch)
            self.manual_backward(loss, optimizer)
            return loss

        loss = self.compute_loss(batch)
        self.manual_backward(loss, optimizer)
        optimizer.optimizer.step(closure=closure)
        optimizer.optimizer.zero_grad()

        return loss 

This is the error:

Traceback (most recent call last):
  File "trainer_v2.py", line 409, in <module>
    train(folder_name, path_results, data_params, backbone_params, clf_params, contrastive_params, 
  File "trainer_v2.py", line 311, in train
    trainer.fit(model, dm)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 871, in run_train
    self.train_loop.run_training_epoch()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 499, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 714, in run_training_batch
    self.training_step_and_backward(
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 823, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 290, in training_step
    training_step_output = self.trainer.accelerator.training_step(args)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 155, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "trainer_v2.py", line 187, in training_step
    optimizer.second_step(zero_grad=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/asabater/2021_projects/evnet/sam.py", line 34, in second_step
    p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
KeyError: 'e_w'

Variable `args.label_smoothing` not being passed to `smooth_crossentropy()`

ISSUE SUMMARY: The command-line argument label_smoothing is silently ignored and hard-coded values are used instead.

Command-line argument label_smoothing is set to a default value of 0.1

parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")

The default argument is set to 0.1 in smooth_crossentropy(smoothing=0.1):

def smooth_crossentropy(pred, gold, smoothing=0.1):

As a result, any argument passed to label_smoothing is silently ignored. To fix this, please pass args.label_smoothing to smooth_crossentropy() where it is called in train.py:

sam/example/train.py

Lines 53 to 75 in cdcbdc1

loss = smooth_crossentropy(predictions, targets)
loss.mean().backward()
optimizer.first_step(zero_grad=True)

# second forward-backward step
disable_running_stats(model)
smooth_crossentropy(model(inputs), targets).mean().backward()
optimizer.second_step(zero_grad=True)

with torch.no_grad():
    correct = torch.argmax(predictions.data, 1) == targets
    log(model, loss.cpu(), correct.cpu(), scheduler.lr())
    scheduler(epoch)

model.eval()
log.eval(len_dataset=len(dataset.test))

with torch.no_grad():
    for batch in dataset.test:
        inputs, targets = (b.to(device) for b in batch)
        predictions = model(inputs)
        loss = smooth_crossentropy(predictions, targets)

Please see: https://github.com/davda54/sam/search?q=smoothing

Sharpness Aware Minimization is a normal (Closure) base optimizer

In your implementation of SAM, you split the optimization into first_step and second_step, and the code gives an error if step is called, because it is not a normal optimizer. This doesn't necessarily have to be the case because of the optional argument to step called closure, which is a feature of optimizers like LBFGS that evaluate the loss and gradients at intermediate states (https://pytorch.org/docs/stable/optim.html?highlight=bfgs#torch.optim.LBFGS). If we use a closure, the training routine for SAM looks like:

model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

for input, output in data:
    def closure():
        optimizer.zero_grad()
        loss = loss_function(output, model(input))
        loss.backward()
        return loss

    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()

where step(closure) is defined as,

def step(self, closure=None):
    assert closure is not None, "Sharpness Aware Minimization requires closure but it was not provided"

    self.first_step()
    closure()
    self.second_step()

This change would bring it more in compliance with the standard pytorch optimizer implementation and make it easier to adopt since some fraction of people already implement closure based optimizers. At the very least, it gives people the option of which to adopt.
