google-deepmind / optax
Optax is a gradient processing and optimization library for JAX.
Home Page: https://optax.readthedocs.io
License: Apache License 2.0
__all__ in the __init__.py overrides which functions are imported when a user does from optax import *. It seems like every function in that file should be exposed through a wildcard import, so there is no need for the __all__. Besides being redundant, having it creates opportunities for bugs: right now, many of the functions (e.g. maybe_update, keep_params_nonnegative) are imported but not exposed in __all__. I believe it should be removed, and would be happy to create a PR if that makes sense.
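To illustrate the mechanism being described, here is a small standalone sketch of how __all__ filters a wildcard import. It builds a throwaway in-memory module; the module name fake_optax is hypothetical and only stands in for optax/__init__.py.

```python
import sys
import types

# Build a throwaway in-memory module (the name "fake_optax" is hypothetical).
fake = types.ModuleType("fake_optax")
exec(
    "def adam():\n    return 'adam'\n"
    "def maybe_update():\n    return 'maybe_update'\n"
    "__all__ = ['adam']\n",  # maybe_update is defined but not listed
    fake.__dict__,
)
sys.modules["fake_optax"] = fake

# Emulate `from fake_optax import *` in a fresh namespace.
star_namespace = {}
exec("from fake_optax import *", star_namespace)

print("adam" in star_namespace)          # True
print("maybe_update" in star_namespace)  # False: filtered out by __all__
print(hasattr(fake, "maybe_update"))     # True: attribute access still works
```

So a function that is imported into __init__.py but missing from __all__ stays reachable as optax.maybe_update, yet silently disappears from wildcard imports.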
See paper:
https://arxiv.org/abs/1901.11150
Hey, I am a user of optix. It seems optax is going to be the future of the library, but I don't see any documentation. Is it too early to use this library?
When using clip_by_global_norm on gradients of complex parameters, it seems we need to change jnp.square(x) to x.conj() * x in the function global_norm in _src/linear_algebra.py.
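A minimal sketch of the suggested change, written as a standalone helper (complex_aware_global_norm is a hypothetical name, not the Optax function). For complex x, jnp.square(x) is x*x, which is generally complex and is not the squared magnitude; x.conj() * x is |x|^2 with a zero imaginary part, and for real x it reduces to x**2.

```python
import jax
import jax.numpy as jnp

def complex_aware_global_norm(updates):
    # x.conj() * x == |x|^2 (zero imaginary part) for complex x and == x**2 for
    # real x, so taking the real part gives the squared magnitude in both cases.
    leaves = jax.tree_util.tree_leaves(updates)
    return jnp.sqrt(sum(jnp.sum(jnp.real(x.conj() * x)) for x in leaves))

z = jnp.array([3.0 + 4.0j])
print(jnp.square(z))                        # -7+24j: not a squared magnitude
print(complex_aware_global_norm({"w": z}))  # 5.0, i.e. |3 + 4j|
```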
What is the current status of complex number support in Optax? I'm using neural networks in quantum physics, and I'd be happy to help the JAX community enhance complex number support. I expect I'll run into more problems with it, and I'll report them as they come up.
Currently, to apply different optimizers to different sets of parameters, you need to construct multiple masks and chain them. For example, to optionally apply weight decay to some parameters and not to others:
tx = optax.chain(
    optax.masked(optax.adamw(0.01, weight_decay=0.0),
                 mask=partial(jax.tree_map, lambda p: p.ndim == 1)),
    optax.masked(optax.adamw(0.01, weight_decay=1e-3),
                 mask=partial(jax.tree_map, lambda p: p.ndim != 1)),
)
However, this has a few issues:
I'd like to propose a multi_transform (or multi_transformation; as always, naming is hard) that generalizes optax.masked:
def multi_transform(transforms: Sequence[GradientTransformation],
                    partition: Union[PyTree, Callable[[base.Params], PyTree]]):
    ...
Where partition is a pytree where the leaves contain an index (or a function that returns such a pytree given the parameters). The index corresponds to a GradientTransformation in transforms.
With this, our above example would be:
tx = optax.multi_transform(
    [optax.adamw(0.01, weight_decay=1e-3), optax.adamw(0.01, weight_decay=0.0)],
    partition=partial(jax.tree_map, lambda p: int(p.ndim == 1)))  # 0 for weight decay, 1 for no weight decay
This solves our above problems:
Minor details:
- partition could just be partition_fn: Callable[[base.Params], PyTree] (I left it as partition to mirror the signature of masked).
- partition could have a different name, perhaps indices.
cc: @mtthss @jheek @andsteing
As surfaced by PR #165, the tests for the lookahead optimizer do not check whether the fast parameters are passed to the fast optimizer.
The optax.trace() gradient transformation already has an accumulator_dtype argument, but it is not yet exposed in the optax.sgd() alias that is commonly used for creating an SGD optimizer with momentum.
Similarly, it would be nice to have a dtype argument for the accumulator(s) of optax.adam(), which would allow using Optax to replicate papers like Scaling Vision Transformers (Zhai et al, 2021).
I understand that this seems to be a Colab-notebook-specific error. If this is not the appropriate place to raise the issue, I would be happy to raise it elsewhere. :)
ImportError Traceback (most recent call last)
<ipython-input-9-72cd76e3a907> in <module>()
4 from jax.experimental import maps
5 import numpy as np
----> 6 import optax
7 import transformers
8
6 frames
/usr/local/lib/python3.7/dist-packages/optax/__init__.py in <module>()
16 """Optax: composable gradient processing and optimization, in JAX."""
17
---> 18 from optax._src.alias import adabelief
19 from optax._src.alias import adagrad
20 from optax._src.alias import adam
/usr/local/lib/python3.7/dist-packages/optax/_src/alias.py in <module>()
20 import jax.numpy as jnp
21
---> 22 from optax._src import combine
23 from optax._src import privacy
24 from optax._src import schedule
/usr/local/lib/python3.7/dist-packages/optax/_src/combine.py in <module>()
16 """Flexibly compose gradient transformations."""
17
---> 18 from optax._src import transform
19 GradientTransformation = transform.GradientTransformation
20
/usr/local/lib/python3.7/dist-packages/optax/_src/transform.py in <module>()
18 from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
19
---> 20 import chex
21 import jax
22 import jax.numpy as jnp
/usr/local/lib/python3.7/dist-packages/chex/__init__.py in <module>()
15 """Chex: Testing made fun, in JAX!"""
16
---> 17 from chex._src.asserts import assert_axis_dimension
18 from chex._src.asserts import assert_axis_dimension_gt
19 from chex._src.asserts import assert_devices_available
/usr/local/lib/python3.7/dist-packages/chex/_src/asserts.py in <module>()
29 import jax
30 import jax.numpy as jnp
---> 31 import jax.test_util as jax_test
32 import numpy as np
33 import tree as dm_tree
/usr/local/lib/python3.7/dist-packages/jax/test_util.py in <module>()
33 from . import dtypes as _dtypes
34 from . import lax
---> 35 from .config import flags, bool_env, config
36 from ._src.util import partial, prod
37 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
ImportError: cannot import name 'flags' from 'jax.config' (/usr/local/lib/python3.7/dist-packages/jax/config.py)
Hi,
in the learning rate schedule list, it may be useful to add a "reduce on plateau" scheduler, which checks over a certain number of epochs whether the loss is decreasing and, if not, divides the learning rate by a certain factor (possibly clipped at a minimum value).
In PyTorch, I have found it useful for some of my use cases of optimizing CNNs.
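A minimal host-side sketch of the behaviour described above, loosely modeled on the PyTorch scheduler. ReduceOnPlateauSketch and all its argument names are hypothetical, not an Optax API. Because it keeps Python state and looks at epoch losses, it runs outside jit.

```python
class ReduceOnPlateauSketch:
    """Hypothetical helper: shrink the learning rate when the loss stalls."""

    def __init__(self, init_lr, factor=0.1, patience=10, min_lr=0.0):
        self.lr = init_lr
        self.factor = factor      # multiply the LR by this on a plateau
        self.patience = patience  # epochs without improvement to tolerate
        self.min_lr = min_lr      # floor for the LR
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one epoch's loss and return the (possibly reduced) LR."""
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.bad_epochs = 0
        return self.lr

sched = ReduceOnPlateauSketch(init_lr=1.0, factor=0.1, patience=2, min_lr=0.05)
lrs = [sched.step(loss) for loss in [1.0] * 7]
print(lrs)  # [1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.05]
```

The value returned by step could then be written into an optimizer state built with optax.inject_hyperparams.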
In PyTorch, it is possible to have a single optimizer for different nn.Module sets of parameters, and there are various ways to combine different modules' parameters.
For example, from https://github.com/altosaar/variational-autoencoder/blob/dfb452b5421e9e5b97315c6420b8766ac86f3f4f/train_variational_autoencoder_pytorch.py#L216:
optimizer = torch.optim.RMSprop(
    list(model.parameters()) + list(variational.parameters()),
    lr=cfg.learning_rate,
    centered=True,
)
What is the equivalent in optax? Is it chaining optimizers, or is such functionality not supported at this time, so that I need a separate optax optimizer instance per haiku.Module?
Thanks so much!
Is there an equivalent to flax.optim.WeightNorm? As flax.optim is effectively deprecated in favor of optax, I would like to see it implemented in optax.
Hi there,
Is it possible to set the learning rate manually? e.g.
# Setup optimiser
opt_init, opt_update = optax.adam(learning_rate=1e-3)
opt_state = opt_init(params)
# Train
for epoch_num in range(10):
    # Compute gradients.
    grads = jax.grad(loss_function)(params, data)
    # Transform the gradients using the optimiser.
    updates, opt_state = opt_update(grads, opt_state)
    # Update parameters.
    params = optax.apply_updates(params, updates)
    # *** MY IDEA/INTENTION ***:
    if epoch_num == 5:
        opt_state.learning_rate = 1e-4  # does something like this exist?
Many thanks for any help, and for this fantastic lib! :)
Currently, optax.sgd and optax.noisy_sgd unconditionally create momentum variables for the parameters, since both rely on trace. For optax.noisy_sgd, this is unnecessary since decay is always 0. For optax.sgd, this is unexpected since momentum=0 by default (and can be wasteful for large models).
optax.noisy_sgd should only require _scale_by_learning_rate (with a negation). optax.sgd could conditionally add trace if momentum > 0.
Below are the lines of code I'm referring to:
And here's where trace automatically creates its state:
This is a very exciting project! I was just considering using flax.optim when I found optax, and I love the elegant combine.chain design of the various optimizer aliases. Very cool!
I'd like to consider learning as an iterated function of the parameters, which itself depends on meta-parameters (e.g. learning rate). Then, I can use the fixed point theorem to calculate the gradient of the loss on a batch with respect to the metaparameters.
Unfortunately, optax's GradientTransformations are implemented using functions that close over values, which means that these values cannot be JAX tracers. From my understanding, you cannot take the derivative with respect to the step_size if the step size is a closed-over value.
I know this might be a serious change, but would it be possible, instead of having:
def scale(step_size: float) -> GradientTransformation:
    ...
    return GradientTransformation(init_fn, update_fn)
To implement the abstract methods init_fn and update_fn in an inherited class:
class Scale(GradientTransformation):
    def __init__(self, step_size):
        ...
This design would allow:
- differentiating with respect to hyperparameters (e.g. step_size),
- inspecting hyperparameters for debugging (e.g. scale.step_size is available in the object-oriented approach),
- meaningful equality: today, if you call scale(1e-3) twice, you get a different object each time, and these objects will not compare equal. If these objects are passed to a jitted function, the function will be recompiled even though the objects would normally be equal.
Currently, the Markdown-flavored links (inside Python files) in Optax API docs (from Python docstrings) appear not to render well on the ReadTheDocs site.
For example, this (source: https://github.com/deepmind/optax/tree/master/optax/_src/alias.py#L322#L374):
def rmsprop(
    ...
) -> base.GradientTransformation:
  """A flexible RMSProp optimiser.
  ...
  References:
    [Tieleman and Hinton, 2012](
        www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    [Graves, 2013](https://arxiv.org/abs/1308.0850)
  ...
translates into this (source: https://optax.readthedocs.io/en/latest/api.html#rmsprop):
Potential solutions:
One option is to use the reStructuredText hyperlink syntax `Text <URL>`_ (Source: https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html#hyperlinks).
Another is to drop the Markdown syntax ([text](URL)) to keep things simple, like in Haiku docs, since most links are from arXiv and they are quite short:
- [Graves, 2013](https://arxiv.org/abs/1308.0850)
+ Graves, 2013 https://arxiv.org/abs/1308.0850
which is similar to https://dm-haiku.readthedocs.io/en/latest/api.html#id1:
class Linear(hk.Module):
  """Linear module."""
  def __init__(
      ...
  ):
    """Constructs the Linear module.
    Args:
      ...
      w_init: Optional initializer for weights. By default, uses random values
        from truncated normal, with stddev ``1 / sqrt(fan_in)``. See
        https://arxiv.org/abs/1502.03167v3.
LMKWYT @mtthss
I can open a PR.
Hi,
The optax.dpsgd optimizer is special in that it takes per-example gradients as input and takes care of aggregating them.
There is currently no documentation of how this implementation is supposed to be used with multiple devices (using pmap), and there might be a few options to do it.
@n2cholas is this something you have thought about?
It would be convenient to support specifying an end_value for exponential decay. So, exponential_decay would require init_value and transition_steps, and one of end_value or decay_rate. Here is an example of what I'm proposing.
I would be happy to contribute this, if you think it is suitable.
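The conversion behind this proposal is small. Assuming the non-staircase form value(t) = init_value * decay_rate ** (t / transition_steps), requiring value(transition_steps) == end_value gives decay_rate = end_value / init_value. exponential_decay_to below is a hypothetical helper that shows this, not an Optax function.

```python
def exponential_decay_to(init_value, end_value, transition_steps):
    """Hypothetical helper: decay from init_value to end_value over transition_steps."""
    # Solve init_value * decay_rate ** (transition_steps / transition_steps) == end_value.
    decay_rate = end_value / init_value

    def schedule(count):
        return init_value * decay_rate ** (count / transition_steps)

    return schedule

schedule = exponential_decay_to(init_value=1e-2, end_value=1e-4, transition_steps=1000)
print(schedule(0), schedule(500), schedule(1000))  # ~1e-2, ~1e-3, ~1e-4
```

Note that this form keeps decaying past transition_steps, so a real implementation would also have to decide whether to clamp at end_value.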
Hi,
I'm the lead developer of NetKet, an established machine learning / quantum physics package.
We have recently finished rewriting our core on top of Jax (and flax) and released a beta version.
Since many physicists seem to use anaconda, we would also like to update our conda recipe.
However, since we depend on Optax, we would need Optax to have a Conda recipe.
Is that something you'd consider? I already contributed the work for Chex and have the Optax recipe ready to go.
I am willing to contribute the recipe myself; I just need another member of the flax-team to be listed as a maintainer of the recipe.
The recipe itself will be low maintenance, as it will pick-up pypi releases automatically and release new versions unless errors arise.
cc @hbq1
This is due to a bug with integer exponentiation, leading to a divide by zero during bias correction on certain iteration multiples (e.g., 64 for b1=0). This issue is the closest I could get to the root cause. In the interim, while the underlying issue is fixed, it could be guarded against in optax by casting the decay to a float in the bias correction helper.
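A sketch of the suggested guard, written as a standalone function rather than a patch to Optax's internal helper. bias_correction_sketch is a hypothetical name. As an extra assumption beyond the issue text, it casts the step count as well as the decay, so that decay ** count goes through a floating-point power rather than an integer-exponent path.

```python
import jax
import jax.numpy as jnp

def bias_correction_sketch(moment, decay, count):
    # Hypothetical variant of Adam's bias-correction helper: cast decay and the
    # step count to float32 so `decay ** count` is computed as a float power.
    decay = jnp.asarray(decay, dtype=jnp.float32)
    count = jnp.asarray(count, dtype=jnp.float32)
    correction = 1.0 - decay ** count
    return jax.tree_util.tree_map(lambda t: t / correction.astype(t.dtype), moment)

# At step 1 with decay 0.9, the correction is 1 - 0.9 = 0.1, so the moment scales by 10.
corrected = bias_correction_sketch({"m": jnp.ones(3)}, decay=0.9, count=1)
```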
Add documentation for maybe_update, as currently the relevant reference page on GitHub is empty.
A lot of papers using RMSProp have hparams that were tuned with the original TensorFlow impl.
The Optax impl is missing the momentum option and initializes the rms value differently.
TF1 RMSProp (https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/training/rmsprop.py#L126)
Keras RMSProp (https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/optimizer_v2/rmsprop.py#L35-L299)
I've spent time replicating results of papers like EfficientNet (and related) in PyTorch and ended up using my own RMSProp impl that matches the TF1 variant (the PyTorch one does not match it either).
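For reference, here is a sketch of one non-centered RMSProp step in the TF1 style, based on my reading of the TF1 source linked above. Please treat the details as assumptions to check against that file: the mean-square slot starts at ones, eps sits inside the square root, and momentum is applied to the already-scaled step. tf1_rmsprop_step is a hypothetical name.

```python
import jax.numpy as jnp

def tf1_rmsprop_step(param, grad, ms, mom, lr, decay=0.9, momentum=0.0, eps=1e-10):
    """One TF1-style (non-centered) RMSProp step, as I understand that implementation."""
    ms = decay * ms + (1.0 - decay) * grad ** 2            # running mean square
    mom = momentum * mom + lr * grad / jnp.sqrt(ms + eps)  # eps inside the sqrt
    return param - mom, ms, mom

param = jnp.zeros(3)
ms = jnp.ones(3)   # TF1-style init: mean-square slot starts at ones, not zeros
mom = jnp.zeros(3)
param, ms, mom = tf1_rmsprop_step(param, jnp.ones(3), ms, mom, lr=1e-3, momentum=0.9)
```

Because ms starts at 1, the first steps are roughly lr * grad instead of the large normalized steps a zero-initialized accumulator produces. That is one reason hparams tuned with TF1 may not transfer directly.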
This issue provides a home for work on adding informative examples of using optax (for instance examples reproducing results from interesting optimisation papers).
Reach out on this issue if you are interested and/or have suggestions.
Right now, there's some boilerplate for defining simple gradient transformations. If a user wants to implement custom weight decay, clipping, constraints, etc. without any state, they still have to define a nested function with an init that does nothing, and handle the empty state.
I think it'd be convenient to provide a stateless transformation that accepts a function to apply to the updates and params. We can also do the jax.tree_multimap for the user by default.
weight_decay = optax.stateless(lambda g, p: g + 0.1 * p)
If the user wants to define a function that does the tree_multimap themselves:
def my_function(updates, params):
    return jax.tree_multimap(lambda g, p: ..., updates, params)
optim = optax.stateless(my_function, on_leaves=False)
In my view, this is a very clean way to implement simple stateless transformations. I think the on_leaves argument could use a better name, though. I'd be happy to implement this if it sounds reasonable.
Currently, optax.masked accepts a boolean mask that has the same structure as (or is a prefix of) the parameters pytree. This breaks the pattern of only requiring the parameters during init and not before. This issue is to discuss the possibility of changing the mask argument to mask_fn, which would be a function that takes a parameter pytree as input and returns a mask that has the same/prefix structure as the params.
One clear advantage of the mask_fn approach is users can define an optimizer independent of the model (as is the case with every other transformation in optax).
The current use pattern is still possible by passing in mask=lambda _: mask for a premade mask if desired.
This was originally proposed by @jheek in the New Optimizer API for Flax discussion.
It would allow disabling optimization of some parameters.
See paper:
log-cosh is a twice-differentiable alternative to the Huber loss. A naive implementation is prone to overflow (since cosh has an e^x term), so I think it'd be a useful addition to the library. Plus, it's implemented in other libraries, such as TensorFlow.
If this sounds like a relevant addition, I'd be happy to contribute it!
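A sketch of one overflow-safe formulation, using the identity log(cosh(x)) = logaddexp(x, -x) - log(2). This is one standard trick, not necessarily how another library implements it. The input values are arbitrary; in the default float32, cosh(100) already exceeds the largest representable value.

```python
import jax
import jax.numpy as jnp

def log_cosh(x):
    # log(cosh(x)) = log((e^x + e^-x) / 2) = logaddexp(x, -x) - log(2).
    # logaddexp never materializes e^|x|, so large |x| does not overflow.
    return jnp.logaddexp(x, -x) - jnp.log(2.0)

x = jnp.array([0.0, 1.0, 100.0])
naive = jnp.log(jnp.cosh(x))  # the last entry overflows in float32
stable = log_cosh(x)          # last entry is about 100 - log(2)
```

Its derivative is tanh(x), so the gradient stays bounded in [-1, 1] like the Huber loss, while the function stays smooth everywhere.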
Hi, I used optax to implement the following convnet to classify the MNIST dataset. Why is it not learning?
import itertools
import time
import haiku as hk
import jax
import jax.numpy as jnp
import numpy.random as npr
import optax
from examples import datasets
from jax import grad, jit, random
from jax.experimental import optimizers, stax
from jax.experimental.stax import (
    Dense,
    Flatten,
    GeneralConv,
    LogSoftmax,
    Relu,
    elementwise,
)
def net_fn(x) -> jnp.ndarray:
    """Standard LeNet-300-100 MLP network."""
    mlp = hk.Sequential(
        [
            hk.Conv2D(output_channels=16, kernel_shape=[5, 5], padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=[2, 2], strides=[2, 2], padding="VALID"),
            hk.Conv2D(output_channels=32, kernel_shape=[5, 5], padding="SAME"),
            jax.nn.relu,
            hk.MaxPool(window_shape=[2, 2], strides=[2, 2], padding="VALID"),
            hk.Flatten(),
            hk.Linear(10),
        ]
    )
    return mlp(x)
net = hk.without_apply_rng(hk.transform(net_fn))
def loss(params, batch):
    inputs, targets = batch
    preds = net.apply(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))
def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(net.apply(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)
if __name__ == "__main__":
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9
    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train_images[batch_idx].reshape(-1, 1, 28, 28), train_labels[
                    batch_idx
                ]
    batches = data_stream()
    optimizer = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), optax.scale(-step_size)
    )
    @jit
    def update(params, optimizer_state, batch):
        grads = grad(loss)(params, batch)
        optim_update, optimizer_state = optimizer.update(grads, optimizer_state, params)
        params = optax.apply_updates(params, optim_update)
        return params, optimizer_state
    params = net.init(jax.random.PRNGKey(42), next(batches)[0])
    # params = init_params
    optimizer_state = optimizer.init(params)
    itercount = itertools.count()
    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            params, optimizer_state = update(params, optimizer_state, next(batches))
        epoch_time = time.time() - start_time
        train_acc = accuracy(
            params, (train_images.reshape(-1, 1, 28, 28), train_labels)
        )
        test_acc = accuracy(params, (test_images.reshape(-1, 1, 28, 28), test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))
I got the following results:
Starting training...
Epoch 0 in 1.73 sec
Training set accuracy 0.13341666758060455
Test set accuracy 0.14079999923706055
Epoch 1 in 0.34 sec
Training set accuracy 0.09730000048875809
Test set accuracy 0.10090000182390213
Epoch 2 in 0.32 sec
Training set accuracy 0.09966666996479034
Test set accuracy 0.10300000756978989
Epoch 3 in 0.33 sec
Training set accuracy 0.12701666355133057
Test set accuracy 0.12640000879764557
Epoch 4 in 0.34 sec
Training set accuracy 0.13825000822544098
Test set accuracy 0.13690000772476196
Epoch 5 in 0.37 sec
Training set accuracy 0.11961666494607925
Test set accuracy 0.1193000078201294
Epoch 6 in 0.34 sec
Training set accuracy 0.16324999928474426
Test set accuracy 0.1599000096321106
Epoch 7 in 0.33 sec
Training set accuracy 0.1628333330154419
Test set accuracy 0.16090001165866852
Epoch 8 in 0.38 sec
Training set accuracy 0.16438333690166473
Test set accuracy 0.1623000055551529
Epoch 9 in 0.33 sec
Training set accuracy 0.10546667128801346
Test set accuracy 0.10450000315904617
Epoch 10 in 0.35 sec
Training set accuracy 0.1001666709780693
Test set accuracy 0.09920000284910202
Epoch 11 in 0.34 sec
Training set accuracy 0.09881667047739029
Test set accuracy 0.09830000251531601
Epoch 12 in 0.32 sec
Training set accuracy 0.09871666878461838
Test set accuracy 0.09800000488758087
Epoch 13 in 0.33 sec
Training set accuracy 0.09881667047739029
Test set accuracy 0.09830000251531601
Epoch 14 in 0.34 sec
Training set accuracy 0.10068333148956299
Test set accuracy 0.10000000149011612
Epoch 15 in 0.35 sec
It would be great if someone knows the reason!
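One likely culprit, offered as a diagnosis sketch rather than a confirmed answer: net_fn ends in hk.Linear(10) with no log-softmax, so the loss -mean(sum(preds * targets)) works on raw logits and is unbounded below. The optimizer can keep lowering it by inflating logits without improving predictions. Applying jax.nn.log_softmax first gives the usual cross-entropy. The function names below are illustrative. Separately, hk.Conv2D defaults to channels-last (NHWC) input, so reshape(-1, 28, 28, 1) is probably what the model expects rather than reshape(-1, 1, 28, 28).

```python
import jax
import jax.numpy as jnp

def buggy_loss(logits, targets):
    # Loss from the question: raw logits, no normalization.
    return -jnp.mean(jnp.sum(logits * targets, axis=1))

def fixed_loss(logits, targets):
    # Normalize to log-probabilities first, giving the standard cross-entropy.
    log_probs = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(jnp.sum(log_probs * targets, axis=1))

targets = jax.nn.one_hot(jnp.array([0, 1]), 10)
logits = jnp.zeros((2, 10))
# Shifting every logit by the same constant leaves predictions unchanged,
# yet the buggy loss keeps dropping while the fixed loss stays at log(10).
print(buggy_loss(logits + 1000.0, targets))  # -1000.0
print(fixed_loss(logits + 1000.0, targets))  # about log(10) = 2.3026
```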
Now that Jax supports dataclasses as PyTrees, would it be possible to switch to using them instead of namedtuple? The benefits are explained here.
The biggest benefit would be preventing unnecessary recompilation. The current Optax code uses closures, which will cause Jax to unnecessarily recompile a jitted function that accepts a GradientTransformation. (The closures are different objects that hash differently, which means that changing the parameters to the GradientTransformation must cause the jitted function to recompile.)
A dataclass version of Optax would look something like this.
I am happy to submit a pull request if this change is okay.
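A small sketch of the recompilation point, assuming the transformation is passed to jax.jit as a static argument (which is hashed and compared for equality). Two closures built with identical hyperparameters miss the cache, while two equal frozen dataclass instances hit it. make_scale_closure and Scale are toy stand-ins, not Optax code.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp

def make_scale_closure(step_size):
    # Closure style: each call returns a new, distinct function object.
    def update(x):
        return step_size * x
    return update

@dataclasses.dataclass(frozen=True)
class Scale:
    # Dataclass style: equality and hashing come from the fields.
    step_size: float

    def __call__(self, x):
        return self.step_size * x

trace_count = 0

@functools.partial(jax.jit, static_argnums=0)
def apply_tx(tx, x):
    global trace_count
    trace_count += 1  # Python side effect: only runs when JAX (re)traces
    return tx(x)

x = jnp.ones(3)
c1, c2 = make_scale_closure(1e-3), make_scale_closure(1e-3)
apply_tx(c1, x)
apply_tx(c2, x)
closure_traces = trace_count    # 2: same hyperparameters, still retraced

trace_count = 0
apply_tx(Scale(1e-3), x)
apply_tx(Scale(1e-3), x)
dataclass_traces = trace_count  # 1: the second call hits the jit cache
```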
OneCycle learning rate from this paper: https://arxiv.org/abs/1708.07120.
The paper describes a 3-stage piecewise linear schedule. PyTorch and FastAI implement a 2 stage variant with cosine annealing (which apparently works better in practice).
I would like to contribute both to this library, if you think they are suitable additions.
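For concreteness, here is a sketch of the two-stage cosine variant. Its defaults are modeled on what I understand the PyTorch OneCycleLR defaults to be (pct_start=0.3, div_factor=25, final_div_factor=1e4); treat them as assumptions. cosine_onecycle_sketch is a hypothetical name. Because it uses Python control flow, it is meant for host-side use, not inside jit.

```python
import math

def cosine_onecycle_sketch(total_steps, peak_value, pct_start=0.3,
                           div_factor=25.0, final_div_factor=1e4):
    """Hypothetical 2-stage one-cycle schedule: cosine warm-up, then cosine decay."""
    init_value = peak_value / div_factor
    final_value = init_value / final_div_factor
    warmup_steps = int(pct_start * total_steps)

    def cosine_interp(start, end, frac):
        # frac = 0 -> start, frac = 1 -> end, following half a cosine period.
        return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * frac))

    def schedule(step):
        if step < warmup_steps:
            return cosine_interp(init_value, peak_value, step / warmup_steps)
        frac = min((step - warmup_steps) / max(total_steps - warmup_steps, 1), 1.0)
        return cosine_interp(peak_value, final_value, frac)

    return schedule

s = cosine_onecycle_sketch(total_steps=1000, peak_value=1.0)
print(s(0), s(300), s(1000))  # ~0.04, ~1.0, ~4e-06
```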
optax.softmax_cross_entropy accepts inputs of shape [..., num_classes] and returns a result of shape [...]. This behavior should be clearly documented but is not.
Likewise for other loss functions.
Sorry if this already exists in optax as a feature, but how would you go about making a multi-optimizer (similar to Flax optim) that could use different learning rates for different parts of a network?
Specifically, I have a full model in Haiku with one learning rate for most of the parameters, but different learning rates for specific subsets. I can partition the params appropriately and create a separate optimizer for each subset, but ideally I'd like to keep the simplicity of a single set of params and optimizer_state. Is there a common approach to this?
Since the dataset is turned into a list, the same batch order and batches are used for each epoch:
Training loop:
The dataset should not be turned into a list at all, but there is a significant performance drop when the tf.data.Dataset is used directly (~4s/epoch with the tf.data.Dataset, ~0.7 seconds with the list). I was not able to improve this, so I would appreciate it if someone with tf.data expertise could take a look. Thanks!
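If the dataset fits in host memory (as MNIST does), one workaround sketch is to keep it as NumPy arrays and draw a fresh permutation per epoch. This keeps the speed of the in-memory list while avoiding the fixed batch order. epoch_batches is a hypothetical helper, and the tiny arrays below stand in for the real data.

```python
import numpy as np

def epoch_batches(images, labels, batch_size, rng):
    """Yield one epoch of (images, labels) batches in a freshly shuffled order."""
    perm = rng.permutation(len(images))  # new order every time this is called
    for start in range(0, len(images), batch_size):
        idx = perm[start:start + batch_size]
        yield images[idx], labels[idx]

rng = np.random.default_rng(0)
images = np.arange(10, dtype=np.float32).reshape(10, 1)
labels = np.arange(10)
epoch_1 = [b[1].tolist() for b in epoch_batches(images, labels, 4, rng)]
epoch_2 = [b[1].tolist() for b in epoch_batches(images, labels, 4, rng)]
```

Each epoch covers every example exactly once, but the two epochs see different batch compositions.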
I want to use rectification from radam with adabelief optimizer, as suggested by the author of adabelief. How can I do that?
The Readme currently does not mention wrappers and functionality like MultiSteps. Add documentation of this to the Readme to make it easier to discover these features.
Currently, additive_weight_decay will decay all the parameters. Jia et al. 2018 show it is beneficial not to decay the bias parameters (and only decay the weights). Many training examples implement this too, such as Flax's Imagenet example.
One way to support this is to add a decay_bias: bool argument to additive_weight_decay, then within the decay update:
updates = jax.tree_multimap(
    lambda g, p: g + weight_decay * p if p.ndim > 1 or decay_bias else g,
    updates, params)
As far as I'm aware, for common NN layers the bias always has one dimension, so checking if p.ndim > 1 is sufficient (please correct me if I'm wrong).
I'd be happy to contribute this if it sounds reasonable.
EDIT: I realized batch norm scale parameters also have only one dimension, so this filter would wrongly include those. However, many training pipelines do not regularize batchnorm scale/bias (e.g. the reference imagenet implementation from mlperf).
See paper:
https://arxiv.org/abs/1907.08610
For details
See paper:
https://arxiv.org/abs/2010.07468
This should be a trivial modification of the existing scale_by_adam transform.
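To make the "trivial modification" concrete, here is a sketch of one AdaBelief step following my reading of the paper linked above (arXiv 2010.07468). The only change from Adam is that the second-moment estimate tracks (g - m)^2, plus eps, instead of g^2. adabelief_step is a hypothetical standalone function, not an Optax transform.

```python
import jax.numpy as jnp

def adabelief_step(g, m, s, count, b1=0.9, b2=0.999, eps=1e-16):
    """One AdaBelief step (count starts at 1); returns the update direction and new moments."""
    m = b1 * m + (1 - b1) * g
    # The Adam-vs-AdaBelief difference: track the "belief" (g - m)^2 rather than g^2.
    s = b2 * s + (1 - b2) * (g - m) ** 2 + eps
    m_hat = m / (1 - b1 ** count)
    s_hat = s / (1 - b2 ** count)
    update = m_hat / (jnp.sqrt(s_hat) + eps)
    return update, m, s

g = jnp.ones(3)
update, m, s = adabelief_step(g, jnp.zeros(3), jnp.zeros(3), count=1)
# m_hat = 1, s_hat = (1 - 0.1) ** 2 = 0.81, so update = 1 / 0.9
```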
The MultiSteps wrapper https://github.com/deepmind/optax/blob/30947fbc5743adc3e997c6242fa3775834862a74/optax/_src/wrappers.py#L179 for gradient accumulation does not feed parameters to the wrapped optimizer. These are needed for e.g. weight decay https://github.com/deepmind/optax/blob/30947fbc5743adc3e997c6242fa3775834862a74/optax/_src/transform.py#L547 which I think is used in almost every optimizer chain.
As far as I can tell every transformation accepts params, so is there any reason not to pipe them through?
As mentioned in #144, we would like to help enhance the support for complex numbers in the JAX community. The detailed proposal document is here: https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
In brief, we need to properly implement the norm of complex variables in the optimizers. We need to decide whether to implement the complex norm, the split real norm, or both of them.
Although there is also a comment zone below the gist, I would like to keep the discussion in this issue thread. Feel free to leave your comments!
Hi! I'm trying to extract the learning rate from an optax optimizer directly for logging to Tensorboard.
I know I could get it from my learning rate schedule object instead by passing in step, but we've previously run into situations where the optimizer step # and expected step # went out of sync (our fault, not optax's), so to be safe we'd like to get it directly from the optimizer object. In TensorFlow you can do self._optimizer._get_hyper('learning_rate') to access it, since it gets logged via _set_hyper. Is there an easy way to do a similar thing in optax?
Schedules type their step variable as int instead of Union[float, int].
This is because the schedules are used for controlling learning rate schedules from integer step counts.
Users have requested that we make it possible to use schedules in other contexts, where the input would no longer be an integer.
One option here would be to admit more general types in the typings, and then remove specific references to steps in the schedules.
I don't understand what optax.masked does.
I would expect that the masked optimizer
optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
would only apply the optimisation to sub-leaves of Dense and not optimise sub-leaves of bias.
This means that the masked gradient should match the sgd one for Dense and be zero for bias.
However, it seems to me that the masked updates are correct for sub-leaves of Dense (where the mask is True), but they are the identity where the mask is False.
Is this intended behaviour? It seems rather strange to me.
I was trying to update only a subset of the weights of my model, but this was not working.
MWE:
import jax.numpy as jnp
import jax
import optax
pars = {"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)}
grad = jax.tree_map(jnp.ones_like, pars)
op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
op_state = op.init(pars)
masked_updates, new_op_state = op.update(grad, op_state, pars)
>>> masked_updates
{'Dense': {'bias': DeviceArray([-0.1, -0.1, -0.1], dtype=float32), 'kernel': DeviceArray([[-0.1, -0.1, -0.1],
[-0.1, -0.1, -0.1]], dtype=float32)}, 'bias': DeviceArray([1., 1.], dtype=float32)}
Is it possible to have schedules for hyper-parameters such as momentum?
For example, onecycle momentum from this paper: https://arxiv.org/abs/1708.07120. There are also benefits to using schedules for parameters like epsilon (as a trust-region / damping parameter), highlighted by this blog post: http://zna.do/epsilon.
The mask must be returned/given as a FrozenDict, which is annoying.
I'm not sure this is really an optax bug, but could something be done to alleviate this?
This also shows up in multi_transform, where the fix is less obvious because it internally builds a dict; therefore, the only way to make it work is to unfreeze the params before giving them to the optimiser, which is... inconvenient.
>>> import jax.numpy as jnp
>>> import jax
>>> import optax
>>> from flax.core import freeze, unfreeze
>>>
>>> pars = freeze({"Dense": {"kernel": jnp.zeros((2,3)), "bias": jnp.zeros((3))}, "bias":jnp.zeros(2)})
>>> op = optax.masked(optax.sgd(0.1), {"Dense": True, "bias": False})
>>> op.init(pars)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/filippovicentini/Documents/pythonenvs/netket_env/lib64/python3.8/site-packages/optax/_src/wrappers.py", line 311, in init_fn
flat_params = treedef.flatten_up_to(params)
ValueError: Expected dict, got FrozenDict({
Dense: {
kernel: DeviceArray([[0., 0., 0.],
[0., 0., 0.]], dtype=float32),
bias: DeviceArray([0., 0., 0.], dtype=float32),
},
bias: DeviceArray([0., 0.], dtype=float32),
}).
Differentially Private SGD (https://cseweb.ucsd.edu/~kamalika/pubs/scs13.pdf) is an important algorithm in private machine learning. Essentially, it is SGD except you clip and add Gaussian noise to per-example gradients before averaging across the batch. I think this would be a useful addition to Optax.
The implementation could be based on the example in the JAX repo: https://github.com/google/jax/blob/master/examples/differentially_private_sgd.py
The usage would be slightly different from other transforms, since it requires per-example gradients as inputs. It can still be composed with other transforms as long as it is the first one in the chain. Alternatively, we can expose a stand-alone utility function that does the clipping/noise/aggregation that the user could then pass to a GradientTransform. I think the former option (making it a transform) is more convenient since this algorithm would have some state (the RNG key).
I'd be happy to work on this if it seems like a good addition.
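A sketch of the clip/noise/aggregate step described above, written as a plain function rather than a GradientTransformation. dp_aggregate is a hypothetical name. The noise standard deviation of l2_norm_clip * noise_multiplier follows my understanding of the JAX example linked above, so treat it as an assumption. Per-example gradients would typically come from jax.vmap(jax.grad(loss), in_axes=(None, 0)).

```python
import jax
import jax.numpy as jnp

def dp_aggregate(per_example_grads, key, l2_norm_clip, noise_multiplier):
    """Clip each example's gradient, sum, add Gaussian noise, and average.

    per_example_grads: a pytree whose leaves all have a leading batch axis.
    """
    leaves, treedef = jax.tree_util.tree_flatten(per_example_grads)
    batch_size = leaves[0].shape[0]
    # Global L2 norm of each example's gradient, taken across all leaves.
    sq_norms = sum(jnp.sum(jnp.reshape(g, (batch_size, -1)) ** 2, axis=1) for g in leaves)
    norms = jnp.sqrt(sq_norms)
    # Scale each example down so its norm is at most l2_norm_clip.
    factors = jnp.minimum(1.0, l2_norm_clip / jnp.maximum(norms, 1e-12))
    clipped = [g * factors.reshape((batch_size,) + (1,) * (g.ndim - 1)) for g in leaves]
    keys = jax.random.split(key, len(clipped))
    noised = [
        (jnp.sum(g, axis=0)
         + l2_norm_clip * noise_multiplier * jax.random.normal(k, g.shape[1:], g.dtype))
        / batch_size
        for g, k in zip(clipped, keys)
    ]
    return jax.tree_util.tree_unflatten(treedef, noised)

per_example = {"w": jnp.array([[3.0, 4.0], [3.0, 4.0]])}  # two examples, each with norm 5
avg = dp_aggregate(per_example, jax.random.PRNGKey(0), l2_norm_clip=1.0, noise_multiplier=0.0)
```

With noise switched off, each example is clipped from norm 5 to norm 1, so avg["w"] is [0.6, 0.8]. The RNG key is the state that would make a transform-based version convenient.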