bojone / accum_optimizer_for_keras
Wrapping a Keras optimizer to implement gradient accumulation.
Hi Jianlin Su,
Thank you for this accumulator. I suspect I don't understand something.
I'm expecting similar results between these 2 experiments:
1) With your accumulator (batch size halved, gradients accumulated over 2 batches):
img_gen = generator(minibatch_size=minibatch_size // 2)  # // so the batch size stays an int
optimizer = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
opt = accum_optimizer.AccumOptimizer(optimizer, 2)
model.compile(loss=keras.losses.binary_crossentropy, optimizer=opt, metrics=["accuracy"])
Epoch 1/100 loss: 0.6930
Epoch 2/100 loss: 0.6918
Epoch 3/100 loss: 0.6897
Epoch 4/100 loss: 0.6901
Epoch 5/100 loss: 0.6841
Epoch 1/100 loss: 0.6911 - acc: 0.5279 - val_loss: 0.6785 - val_acc: 0.5632
Epoch 2/100 loss: 0.6898 - acc: 0.5485 - val_loss: 0.6689 - val_acc: 0.5704
Epoch 3/100 loss: 0.6876 - acc: 0.5471 - val_loss: 0.6712 - val_acc: 0.6172
Epoch 4/100 loss: 0.6900 - acc: 0.5418 - val_loss: 0.6936 - val_acc: 0.4876
Epoch 5/100 loss: 0.6883 - acc: 0.5476 - val_loss: 0.6661 - val_acc: 0.6184
2) Without your accumulator (nominal batch size, no accumulation):
img_gen = generator(minibatch_size = minibatch_size)
optimizer = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
model.compile(loss=keras.losses.binary_crossentropy, optimizer=optimizer, metrics=["accuracy"])
Epoch 1/100 loss: 0.6039
Epoch 2/100 loss: 0.4983
Epoch 3/100 loss: 0.4373
Epoch 4/100 loss: 0.4184
Epoch 5/100 loss: 0.3928
Epoch 1/100 loss: 0.6640
Epoch 2/100 loss: 0.5702
Epoch 3/100 loss: 0.4759
Epoch 4/100 loss: 0.4389
Epoch 5/100 loss: 0.4094
Question:
If I halve the batch size and accumulate gradients over 2 steps, shouldn't I see similar results?
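For plain SGD without momentum, accumulating the gradients of two half-batches and averaging is mathematically the same update as one full-batch step; momentum, lr decay, and BatchNorm statistics break the exact equivalence. A minimal numpy sketch (toy data and a hypothetical linear model, not the repo's code) shows the identity:

```python
import numpy as np

# Toy linear model: loss = mean((x @ w - y)**2); gradient is analytic.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
y = rng.normal(size=(8,))
w = np.zeros(3)

def grad(xb, yb, w):
    # d/dw mean((xb @ w - yb)**2) = 2/len(xb) * xb.T @ (xb @ w - yb)
    return 2.0 / len(xb) * xb.T @ (xb @ w - yb)

lr = 0.1

# One full-batch SGD step.
w_full = w - lr * grad(x, y, w)

# Two half-batches, gradients accumulated then averaged
# (what a gradient-accumulation wrapper effectively does).
g_acc = grad(x[:4], y[:4], w) + grad(x[4:], y[4:], w)
w_accum = w - lr * (g_acc / 2)

print(np.allclose(w_full, w_accum))  # True for momentum-free SGD
```

With momentum=0.9 and Nesterov as in the snippet above, the optimizer state advances once per accumulated step instead of once per mini-batch, so some divergence between the two runs is expected, though a gap as large as the logs show may point at something else as well.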
AttributeError Traceback (most recent call last)
~/Kit/Classification/Siamese-Prostate-MRI-Similarity-master/TripleSiamese.py in
159 siamese_net = Model(inputs=[left_inputt2w, right_inputt2w, left_inputhbv,
160 right_inputhbv, left_inputsag, right_inputsag], outputs=distance)
--> 161 adam = AccumOptimizer(Adam(0.00001), 9)
162 siamese_net.compile(loss=contrastive_loss, optimizer=adam, metrics=[accuracy])
163
~/Kit/Classification/Siamese-Prostate-MRI-Similarity-master/accum_optimizer.py in __init__(self, optimizer, steps_per_update, **kwargs)
     37         for attr in self.optimizer.get_config():
     38             if not hasattr(self, attr):
---> 39                 value = getattr(self.optimizer, attr)
     40                 setattr(self, attr, value)
     41         # override the original gradient-getting method to point to the accumulated gradients
~/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in __getattribute__(self, name)
    676       if name in self._hyper:
    677         return self._get_hyper(name)
--> 678       raise e
    679
    680   def __setattr__(self, name, value):
~/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py in __getattribute__(self, name)
    666     """Overridden to support hyperparameter access."""
    667     try:
--> 668       return super(OptimizerV2, self).__getattribute__(name)
    669     except AttributeError as e:
    670       # Needed to avoid infinite recursion with __setattr__.
AttributeError: 'Adam' object has no attribute 'name'
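The crash comes from `getattr(self.optimizer, attr)`: a tf.keras OptimizerV2 raises for config keys like `'name'` that are not plain attributes. Reading the values directly from the config dict avoids the optimizer's attribute machinery entirely. A stub-based sketch of that change (FakeAdam and AccumOptimizerSketch are hypothetical stand-ins, not the repo's classes):

```python
class FakeAdam:
    """Stub standing in for a tf.keras OptimizerV2 Adam: 'name' is in the
    config but deliberately not a readable attribute, as in the traceback."""
    def get_config(self):
        return {'name': 'Adam', 'learning_rate': 1e-5}

class AccumOptimizerSketch:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        # Take each value straight from get_config() instead of calling
        # getattr(optimizer, attr), which raises AttributeError for 'name'.
        for attr, value in self.optimizer.get_config().items():
            if not hasattr(self, attr):
                setattr(self, attr, value)

opt = AccumOptimizerSketch(FakeAdam())
print(opt.name, opt.learning_rate)  # Adam 1e-05
```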
Horovod uses from_config to read the optimizer's parameters and then instantiates a new optimizer in each process, so this wrapper's get_config does not work for that use case.
Wrapping an already-constructed optimizer instance can also cause confusing behavior during training.
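To survive a serialize/deserialize round trip (which is what Horovod relies on), the wrapper would need its own get_config/from_config that carries the inner optimizer along. A hedged sketch of the shape that could take, with a stub optimizer in place of a real Keras one (in real code the inner dict would come from keras.optimizers.serialize/deserialize):

```python
class AccumOptimizer:
    """Sketch of config round-tripping for the wrapper (hypothetical; the
    real class would also merge in the base Optimizer's own config)."""
    def __init__(self, optimizer, steps_per_update=1):
        self.optimizer = optimizer
        self.steps_per_update = steps_per_update

    def get_config(self):
        # Serialize the wrapped optimizer as class name + its config,
        # the same shape keras.optimizers.serialize() produces.
        return {
            'optimizer': {
                'class_name': type(self.optimizer).__name__,
                'config': self.optimizer.get_config(),
            },
            'steps_per_update': self.steps_per_update,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        inner = config['optimizer']
        # In real code: keras.optimizers.deserialize(inner)
        optimizer = custom_objects[inner['class_name']](**inner['config'])
        return cls(optimizer, config['steps_per_update'])

class SGDStub:
    def __init__(self, lr=0.01):
        self.lr = lr
    def get_config(self):
        return {'lr': self.lr}

cfg = AccumOptimizer(SGDStub(0.01), 2).get_config()
clone = AccumOptimizer.from_config(cfg, custom_objects={'SGDStub': SGDStub})
print(clone.steps_per_update, clone.optimizer.lr)  # 2 0.01
```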
Nice work! Will you release a pure-TensorFlow implementation?
When using a Keras Embedding layer, the gradient is an IndexedSlices object instead of a Tensor, so the accumulation line needs a conversion (assuming `import tensorflow`):
# self.updates.append(K.update(ag, K.switch(self.cond, g, ag + g)))
self.updates.append(K.update(ag, K.switch(self.cond, tensorflow.convert_to_tensor(g), ag + tensorflow.convert_to_tensor(g))))
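An Embedding layer only touches the rows looked up in the batch, so its gradient arrives as tf.IndexedSlices (row indices plus values) rather than a dense tensor, and `ag + g` fails. tf.convert_to_tensor densifies it, as in the patch above; the densification itself is just a scatter-add into a zero tensor, which a numpy sketch (hypothetical shapes) makes concrete:

```python
import numpy as np

# A sparse embedding gradient: only rows 1 and 3 of a (4, 2) table were used.
indices = np.array([1, 3])
values = np.array([[0.5, -0.5],
                   [1.0,  2.0]])
dense_shape = (4, 2)

# What tf.convert_to_tensor does to an IndexedSlices: scatter the rows
# into a zero tensor of the full shape.
dense = np.zeros(dense_shape)
np.add.at(dense, indices, values)  # add, not assign: duplicate indices sum

# The dense gradient can now be accumulated like any other tensor.
accumulator = np.zeros(dense_shape)
accumulator += dense
```

Note that densifying embedding gradients every step can be expensive for large vocabularies, since the zero rows are materialized too.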
I call the optimizer this way:
optimizer_critic = AccumOptimizer(Adam(lr=self.beta), 10)
model_critic.compile(optimizer=optimizer_critic, loss='mean_squared_error')
TypeError Traceback (most recent call last)
<ipython-input-15-6e2f40e506da> in <module>
3
4 env = Environment()
----> 5 agent = Agent(alpha=0.00002, beta=0.0001, input_dims=env.input_dim, n_action=env.n_action, load=True)
6
7 num_episodes = 2000
<ipython-input-13-059153b528ed> in __init__(self, alpha, beta, gamma, n_action, load, input_dims, layer_shared, layer_actor, layer_critic)
88 self.action_space = [i for i in range(n_action)]
89
---> 90 self.actor, self.critic, self.policy = self.build_actor_critic_network()
91
92
<ipython-input-13-059153b528ed> in build_actor_critic_network(self, load)
118 return K.sum(-log_likelihood * delta)
119
--> 120 optimizer_actor = AccumOptimizer(Adam(), 10)
121 optimizer_critic = AccumOptimizer(Adam(), 10)
122
<ipython-input-13-059153b528ed> in __init__(self, optimizer, steps_per_update, **kwargs)
30 class AccumOptimizer(Optimizer):
31 def __init__(self, optimizer, steps_per_update=1, **kwargs):
---> 32 super(AccumOptimizer, self).__init__(**kwargs)
33 self.optimizer = optimizer
34 with K.name_scope(self.__class__.__name__):
TypeError: __init__() missing 1 required positional argument: 'name'
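Here the wrapper is subclassing tf.keras's OptimizerV2 (imported as Optimizer), whose __init__ takes a required positional name. Forwarding one through fixes this TypeError; a stub-based sketch (OptimizerBase below stands in for OptimizerV2, and this is not the repo's code):

```python
class OptimizerBase:
    """Stand-in for tf.keras OptimizerV2: name is a required positional arg."""
    def __init__(self, name, **kwargs):
        self.name = name

class AccumOptimizer(OptimizerBase):
    def __init__(self, optimizer, steps_per_update=1,
                 name='AccumOptimizer', **kwargs):
        # Forward the required name instead of passing only **kwargs.
        super().__init__(name, **kwargs)
        self.optimizer = optimizer
        self.steps_per_update = steps_per_update

opt = AccumOptimizer('adam-stub', 10)
print(opt.name, opt.steps_per_update)  # AccumOptimizer 10
```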
Keras version: 2.2.4-tf.
As of Keras 2.3.0, self.lr was renamed to self.learning_rate (https://github.com/keras-team/keras/releases), so references to self.lr in the wrapper should be updated to self.learning_rate.
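A small compatibility helper can cover both attribute names so the wrapper works on either side of the 2.3.0 rename (hypothetical helper and stub classes, not from the repo):

```python
def get_lr_attr(optimizer):
    """Return the learning-rate attribute name this optimizer uses:
    'learning_rate' since Keras 2.3.0, 'lr' before."""
    return 'learning_rate' if hasattr(optimizer, 'learning_rate') else 'lr'

class OldSGD:  # pre-2.3.0 style
    def __init__(self):
        self.lr = 0.01

class NewSGD:  # 2.3.0+ style
    def __init__(self):
        self.learning_rate = 0.01

print(get_lr_attr(OldSGD()), get_lr_attr(NewSGD()))  # lr learning_rate
```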