Code Monkey home page Code Monkey logo

adan's Introduction

Tensorflow Adan

Unofficial implementation of Adan optimizer.

This implementation differs from the official pytorch implementation. The main difference is that gradient parameters aren't updated for categorical values which aren't present in the current batch. It's especially important for tasks when the batch doesn't contain all possible categorical values.

See "Test sparse - a lot of categories" in notebooks/test_adan.ipynb for illustation.

See the paper for details - Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models.

See official pytorch implementation - Adan.

Install

pip install adan-tensorflow

Usage example

from tf_adan.adan import Adan

model.compile(
    optimizer=Adan(),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

See notebooks/example.ipynb for an example.

Running tests

To test the correctness of the implementation, we're running official pytorch implementation and tensorflow implementation on the same data. If the hparams of the optimizers are the same (lr, betas, etc) and initial data is the same, loss history and weights after optimization must be the same too.

  1. Build docker image
docker build -t latest .
docker run -p 8888:8888 -v $(pwd):/work latest jupyter notebook --ip 0.0.0.0 --port=8888 --allow-root
  1. Run notebooks/test_adan.ipynb

adan's People

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

Forkers

dl-cnn

adan's Issues

Производительность, память и нюансы sparse-обновлений

Вдохновившись Вашей реализацией решил "позаимствовать" оптимизатор для TF себе в копилку и попутно порефакторить. Вот что нашлось:

Потребление памяти
SGD без момента вообще не использует доп. переменные и как следствие не потребляет дополнительной памяти. С моментом = память x2
Adam если правильно помню хранит 3 слота = память x3
Adan только для dense-реалзиации потребляет x4 памяти, а в Вашей реализации x5 из-за вот этого https://github.com/DenisVorotyntsev/Adan/blob/main/tf_adan/adan.py#L50

Хорошего способа совсем убрать счетчик апдейтов я не нашел, но можно значительно урезать объем этого 5го слота (сценарий когда будут делать срезы внутри канала имхо маловероятен) https://github.com/shkarupa-alex/tfmiss/blob/develop/tfmiss/keras/optimizers/adan.py#L51

Точность sparse-обновлений #1
При расчете каждого обновления Adan использует текущий номер шага в расчете bias_correction_* https://github.com/DenisVorotyntsev/Adan/blob/main/tf_adan/adan.py#L74
При sparse-обновлениях нужно использовать текущий номер апдейта для срезов, а не глобальный номер шага (можно проверить поэлементно пропуская те шаги в которых нет индекса этого элемента)

Чтобы получить точное соответствие dense-ветке пришлось вынести bias_correction_* в каждую из веток и считать их в dense-ветке на основе глобального шага (как в Вашей реализации) а вот в sparse-ветке на основе того самого счетчика обновлений https://github.com/shkarupa-alex/tfmiss/blob/develop/tfmiss/keras/optimizers/adan.py#L148

Точность sparse-обновлений #2
_resource_scatter_update возвращает всю переменную, а не только текущий срез
В Вашей реализации это приводит к обновлению всей переменной каждый раз, а не только к обновлению текущего среза (что во-первых менее производительно, во вторых кажется несет ошибку).
Т.е. если какой-то категориальной переменной в срезе не было она все равно обновится по данным предыдущих итераций.
Пришлось немного переструктурировать код и все sparse-обновления делать после расчетов на текущих срезах

Прочее
Если вот так брать скорость обучения https://github.com/DenisVorotyntsev/Adan/blob/main/tf_adan/adan.py#L59 подозреваю что не будут работать расписания lr
Кажется правильнее брать lr_t который появляется после super()._prepare_local(...)


По графикам не все однозначно (sparse-часть блокнота).
Моя реализация оказывается ближе к оригинальной почти везде кроме 1го графика где ведет себя лучше и стабильнее чем оригинальная и Ваша.
Снимок экрана 2022-10-04 в 13 30 56

_set_hyper() was obsoleted in latest version of tf

        self._set_hyper("learning_rate", learning_rate)
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        self._set_hyper("beta_3", beta_3)
        self._set_hyper("epsilon", epsilon)
        self._set_hyper("weight_decay", weight_decay)

In init() fucntion, the code used the function _set_hyper(), which was obsoleted .

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.