
Comments (10)

mtthss commented on August 18, 2024

A lot of papers using RMSProp have hparams that were tuned with the original TensorFlow implementation.

While it's true that the TF1 optimisers initialised the stats with ones, it seems the community has moved in the opposite direction: optax's initialisation with zeros matches Keras, PyTorch, flax.optim, and jax.experimental.optimizers. So I would be wary of overfitting to the TF1 initialisation. What we could do is add an initial_scale argument, with a default of 0, that you can use to initialise the stats with ones as in TF1 if needed.

Would that address your needs?

Regarding momentum

You can easily compose optax's gradient transformations into custom optimisers, for instance by chaining scale_by_rms and trace.
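
For example, using such a chained transformation for a single update step could look roughly like this (a sketch assuming the public optax names chain, scale, scale_by_rms, trace and apply_updates; module paths may differ at this commit):

import jax
import jax.numpy as jnp
import optax

# Hypothetical composition: RMSProp-style scaling, learning-rate scaling,
# then heavy-ball momentum applied to the scaled updates (as in TF1).
optimiser = optax.chain(
    optax.scale_by_rms(decay=0.9, eps=1e-8),
    optax.scale(-0.1),                       # scale by -learning_rate
    optax.trace(decay=0.9, nesterov=False),  # momentum
)

params = {'w': jnp.ones((3,))}
opt_state = optimiser.init(params)

grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(params)
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)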


8bitmp3 commented on August 18, 2024

@rwightman Sonnet (for TensorFlow) also has a momentum arg, if I'm not mistaken: https://github.com/deepmind/sonnet/blob/5a465696383f967d5bffb6599347d9e6c15cef4b/sonnet/src/optimizers/rmsprop.py 🤔 (though it's "not Nesterov momentum")


8bitmp3 commented on August 18, 2024

To provide more context, here's a code comparison. Note that the optax rmsprop alias has no momentum term at all, while Sonnet folds momentum (mu * mom) in after the learning-rate scaling:

Optax https://github.com/deepmind/optax/blob/70e0dd04c7bad9359b372260c56cd85c731b3cbe/optax/_src/alias.py

def rmsprop(learning_rate: ScalarOrSchedule,
            decay: float = 0.9,
            eps: float = 1e-8,
            centered: bool = False) -> GradientTransformation:
  if centered:
    return combine.chain(
        transform.scale_by_stddev(decay=decay, eps=eps),
        _scale_by_learning_rate(learning_rate),
    )
  return combine.chain(
      transform.scale_by_rms(decay=decay, eps=eps),
      _scale_by_learning_rate(learning_rate),
  )

Sonnet https://github.com/deepmind/sonnet/blob/5a465696383f967d5bffb6599347d9e6c15cef4b/sonnet/src/optimizers/rmsprop.py

def rmsprop_update(update, decay, learning_rate, epsilon, mu, mom, ms, mg):
  """Computes a single RMSProp update."""
  ms = tf.square(update) * (1. - decay) + ms * decay
  if mg is not None:  # centered
    mg = update * (1. - decay) + mg * decay
    denominator = ms - tf.square(mg) + epsilon
  else:
    denominator = ms + epsilon
  mom = (mu * mom) + (learning_rate * update * tf.math.rsqrt(denominator))
  return mom, ms, mg

class RMSProp(base.Optimizer):
  """RMSProp module.
  ...
  Attributes:
    learning_rate: Learning rate.
    decay: Learning rate decay over each update.
    momentum: Momentum scalar.
    epsilon: Small value to avoid zero denominator.
    centered: `True` if centered.
    mom: Accumulated mom for each parameter.
    ms: Accumulated ms for each parameter.
    mg: Accumulated mg for each parameter.
  """
...


rwightman commented on August 18, 2024

@8bitmp3 Thanks, I do have my own JAX impl too, but I'd rather not maintain it myself, so I'm seeing if there is a desire to have a TF1 variant supported here...

https://github.com/rwightman/efficientnet-jax/blob/master/jeffnet/linen/optim/rmsprop_tensorflow.py#L26-L55

class RMSPropTensorflow(OptimizerDef):
    """RMSProp optimizer that matches Tensorflow impl."""

    def __init__(self, learning_rate: float = None, beta1=0., beta2=0.9, eps=1e-8, weight_decay=0.):
        """Constructor for the RMSProp optimizer
        Args:
            learning_rate: the step size used to update the parameters.
            beta1 (float): gradient momentum factor (default: 0.)
            beta2 (float): discounting factor for the history/coming gradient magnitude (default: 0.9)
            eps: the term added to the gradient magnitude estimate for numerical stability.
        """
        hyper_params = _RMSPropHyperParams(learning_rate, beta1, beta2, eps, weight_decay)
        super().__init__(hyper_params)

    def init_param_state(self, param):
        """Initialize parameter state"""
        return _RMSPropTfParamState(jnp.ones_like(param), jnp.zeros_like(param))

    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        """Apply per-parameter gradients"""

        assert hyper_params.learning_rate is not None, 'no learning rate provided.'
        new_rms = hyper_params.beta2 * state.rms + (1.0 - hyper_params.beta2) * jnp.square(grad)
        new_mom = hyper_params.beta1 * state.mom + \
                  hyper_params.learning_rate * grad * lax.rsqrt(new_rms + hyper_params.eps)
        new_param = param - new_mom
        if hyper_params.weight_decay != 0.:
            new_param -= hyper_params.learning_rate * hyper_params.weight_decay * param
        new_state = _RMSPropTfParamState(new_rms, new_mom)
        return new_param, new_state
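
For what it's worth, a minimal usage sketch of that class, assuming the flax.optim OptimizerDef API it builds on (create / target / apply_gradient); the numbers are illustrative, not from the linked file:

import jax
import jax.numpy as jnp

params = {'w': jnp.ones((3,))}
# Build the optimizer definition and bind it to the parameter pytree.
optimizer = RMSPropTensorflow(learning_rate=0.1, beta1=0.9).create(params)

grads = jax.grad(lambda p: jnp.sum(p['w'] ** 2))(optimizer.target)
optimizer = optimizer.apply_gradient(grads)  # new Optimizer with updated params and state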


8bitmp3 commented on August 18, 2024

I wonder if (Nesterov) momentum is something you currently add outside RMSProp 🤷‍♂️, via trace in optax/_src/transform.py:

https://github.com/deepmind/optax/blob/01e846a4ca66aa1118428d152cf97ad7b7acaed9/optax/_src/transform.py

"""Gradient transformations."""
class TraceState(OptState):
  """Holds an aggregation of past updates."""
  trace: Params


def trace(decay: float, nesterov: bool) -> GradientTransformation:
  """Compute a trace of past updates.
  Args:
    decay: the decay rate for the tracing of past updates.
    nesterov: whether to use Nesterov momentum.
  Returns:
    An (init_fn, update_fn) tuple.
  """

  ...
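
For reference, the elided update rule is roughly the following (a sketch of the heavy-ball/Nesterov accumulator, not the exact code at that commit):

import jax

def trace_update(updates, trace, decay, nesterov):
  """Sketch of what trace does to a pytree of updates."""
  # Accumulate a decayed sum of past updates (heavy-ball momentum).
  new_trace = jax.tree_util.tree_map(lambda g, t: g + decay * t, updates, trace)
  # Nesterov momentum looks ahead by applying the decayed new trace once more.
  if nesterov:
    new_updates = jax.tree_util.tree_map(lambda g, t: g + decay * t, updates, new_trace)
  else:
    new_updates = new_trace
  return new_updates, new_trace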

Sorry I can't read code well 😄


rwightman commented on August 18, 2024

@mtthss yes, an initial_scale arg could achieve my goals

I had taken a look, and it appears the snippet below would match RMSProp w/ momentum. I imagine someone else will request momentum as an out-of-the-box capability at some point, as it's fairly common to have it enabled.

One thing to mention re the centered variant: there are two scale params; in TF1 one is initialised to zero and the other to one. I don't really care as much about this case since I never use it. Ignore it, or initial_scale: Union[float, Tuple[float, float]]?

def rmsprop_momentum(
    learning_rate: ScalarOrSchedule,
    decay: float = 0.9,
    momentum: float = 0,
    eps: float = 1e-8,
    centered: bool = False,
    initial_scale: float = 1.) -> GradientTransformation:
  if centered:
    return combine.chain(
        transform.scale_by_stddev(decay=decay, eps=eps),
        _scale_by_learning_rate(learning_rate),
        transform.trace(decay=momentum, nesterov=False),
    )
  return combine.chain(
      transform.scale_by_rms(decay=decay, eps=eps, initial_scale=initial_scale),
      _scale_by_learning_rate(learning_rate),
      transform.trace(decay=momentum, nesterov=False),
  )
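
Usage would then be something like the following (a sketch; the hyperparameter values are illustrative TF1-style settings, not anything prescribed here):

import jax
import jax.numpy as jnp

# Hypothetical TF1-style settings: decay=0.9, momentum=0.9, eps=1e-3, lr=0.016.
tx = rmsprop_momentum(
    learning_rate=0.016,
    decay=0.9,
    momentum=0.9,
    eps=1e-3,
    initial_scale=1.0,  # match the TF1 initialisation of the rms accumulator
)

params = {'w': jnp.zeros((3,))}
state = tx.init(params)
grads = jax.grad(lambda p: jnp.sum((p['w'] - 1.0) ** 2))(params)
updates, state = tx.update(grads, state)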


mtthss commented on August 18, 2024

The accumulators of rmsprop/adam have, one, the semantics of a mean and, the other, of a scale factor. I think everyone sets the mean one to 0; it's only the scale one that is more controversial (either 0 or 1 depending on the framework). So I would suggest initial_scale to refer only to that one (i.e. be just a float); the name then also corresponds well to its semantics.
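
Concretely, that could look something like this inside scale_by_rms (a sketch reusing the module's ScaleByRmsState and GradientTransformation; not the final implementation):

import jax
import jax.numpy as jnp

def scale_by_rms(decay: float = 0.9, eps: float = 1e-8,
                 initial_scale: float = 0.) -> GradientTransformation:
  """Sketch: only the init changes; the update rule stays as it is."""

  def init_fn(params):
    # Initialise the second-moment accumulator to initial_scale
    # (0. keeps the current behaviour, 1. reproduces the TF1 optimisers).
    nu = jax.tree_util.tree_map(lambda p: jnp.full_like(p, initial_scale), params)
    return ScaleByRmsState(nu=nu)

  def update_fn(updates, state, params=None):
    nu = jax.tree_util.tree_map(
        lambda g, n: (1 - decay) * (g ** 2) + decay * n, updates, state.nu)
    updates = jax.tree_util.tree_map(
        lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu)
    return updates, ScaleByRmsState(nu=nu)

  return GradientTransformation(init_fn, update_fn)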

Your code looks good, want to put together a PR adding this option?


rwightman commented on August 18, 2024

@mtthss Valid point, I have some variants of this in testing right now and can create a PR.

Would you rather the PR be limited to the initial_scale arg addition, or also include adding trace to the rmsprop alias? With #52/#54 there wouldn't be any overhead when momentum defaults to 0.

There are a lot of popular models trained with RMSProp using momentum=0.9. It was/is a default setting for the TF Slim CNN training scripts, and also for many of the TF TPU example scripts. So models like EfficientNet, MNasNet, MobileNetV3/V2, NASNet, PNASNet, AmoebaNet, InceptionV2/V3/V4, etc. were all trained with momentum=0.9.


mtthss commented on August 18, 2024

I think the PR can include both


rwightman commented on August 18, 2024

@mtthss Since I already have this for my testing, a pre-PR review: aside from my formatting changes, which will not be included in the PR, any issues with the following?

initial_scale changes for rms/stddev:
https://github.com/rwightman/efficientnet-jax/blob/optax/jeffnet/common/optim/rmsprop.py#L8-L64

rmsprop alias update:
https://github.com/rwightman/efficientnet-jax/blob/optax/jeffnet/common/optim/rmsprop.py#L67-L86

