n2cholas / jax-resnet
Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
Home Page: https://pypi.org/project/jax-resnet/
License: MIT License
Hi there, this repo is very useful, but could you add WideResNet-28-10? It is highlighted in the official repo and is the best-performing model on CIFAR-10 in RobustBench, so I believe many people will find it helpful.
Currently, all layers use the default Flax initialization. However, each paper uses a different strategy:
There are a few options going forward:
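For reference, such a strategy can already be injected from the outside: the ResNet constructor (quoted later in this thread) exposes conv_cls and norm_cls hooks. A minimal sketch, assuming jax.nn.initializers; the He-normal choice is an example of one paper's strategy, not a decision the repo has made:

    from functools import partial

    import flax.linen as nn
    import jax

    # He-normal (Kaiming) initialization for every convolution, as in the
    # original ResNet paper; other papers would swap in a different initializer.
    he_conv = partial(nn.Conv, kernel_init=jax.nn.initializers.he_normal())

    # Passed through the conv_cls hook of the ResNet constructor, e.g.:
    #   model = ResNet(block_cls, stage_sizes=(3, 4, 6, 3), n_classes=1000,
    #                  conv_cls=he_conv)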
Hey @n2cholas!
I was wondering how to properly do transfer learning. Maybe this feature is not implemented yet, but is it possible to select the second-to-last layer? More generally, can you select other intermediate layers?
I want to make an example of doing transfer learning in Elegy, and these pre-trained models look perfect for the task.
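Since ResNet returns a Sequential built from a plain list of layers, one possible approach is to rebuild a truncated Sequential and keep only the matching parameter subtrees. A minimal sketch, assuming Sequential is importable from jax_resnet.common and that Flax names the list entries layers_0, layers_1, ...; slice_model is a hypothetical helper, not part of the library:

    from jax_resnet import pretrained_resnet
    from jax_resnet.common import Sequential

    def slice_model(model_cls, variables, end: int):
        # Hypothetical helper: keep layers [0, end) of a jax-resnet Sequential.
        # Assumes the i-th layer's variables live under 'layers_{i}' in each
        # collection (Flax's default naming for list attributes).
        backbone = Sequential(model_cls().layers[:end])
        sliced = {
            col: {name: tree for name, tree in coll.items()
                  if int(name.split('_')[-1]) < end}
            for col, coll in variables.items()
        }
        return backbone, sliced

    model_cls, variables = pretrained_resnet(50)
    n = len(model_cls().layers) - 1   # drop the final Dense head
    backbone, backbone_vars = slice_model(model_cls, variables, n)
    # backbone.apply(backbone_vars, x, mutable=False) now returns the
    # second-to-last layer's output (the pooled features).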
Hello Nicholas, while using the pretrained ResNet-101:
I am comparing against the output size of the ResNet model in PyTorch after layer 4 (taking the output just before the average pooling), after running it on an input batch of size [1, 224, 224, 3].
It was torch.Size([1, 2048, 28, 28]).
However, when I tried to get the corresponding output from your ResNet model in JAX/Flax, I commented out these two lines in the ResNet function to get the output before the average pooling (the equivalent of layer4 in PyTorch):
def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)

    layers = [stem_cls(), pool_fn]
    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))
    # --------------------------------------------------------------------------
    # layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    # layers.append(nn.Dense(n_classes))
    # --------------------------------------------------------------------------
    return Sequential(layers)
It has a different output shape for the same input batch of shape (1, 224, 224, 3):
import jax
import jax.numpy as jnp
from jax_resnet import pretrained_resnet

resnet101_cls, variables = pretrained_resnet(101)
model = resnet101_cls()
model_out = model.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print("pretrained resnet101 size:", jax.tree_map(lambda x: x.shape, model_out))
# pretrained resnet101 size: (1, 7, 7, 2048)
So, what happened at this stage in the ResNet layer structure?
Kindly reply if you have any explanation or recommendations.
pretrained_resnest works.
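When comparing the two frameworks, note that Flax convolutions produce NHWC outputs while PyTorch uses NCHW, so the Flax shape (1, 7, 7, 2048) corresponds to (1, 2048, 7, 7) in PyTorch's layout. A minimal check (the array here is a stand-in for the model output above):

    import jax.numpy as jnp

    flax_out = jnp.ones((1, 7, 7, 2048))             # (N, H, W, C) from Flax
    as_nchw = jnp.transpose(flax_out, (0, 3, 1, 2))  # -> (N, C, H, W) for PyTorch
    print(as_nchw.shape)                             # (1, 2048, 7, 7)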
Hey @n2cholas!
This is not an immediate issue, but I was playing around with jax_resnet and noticed that ConvBlock decides whether it should update its batch statistics depending on whether the batch_stats collection is mutable. This initially sounds like a safe bet, but if you embed ResNet inside another module that by chance also uses BatchNorm, and you want to train the other module but freeze ResNet, it is not clear how you would do this.
jax-resnet/jax_resnet/common.py, lines 43 to 44 in 5b00735
To solve this you have to:
1. Add a use_running_average (or equivalent) argument to ConvBlock.__call__ and pass it to norm_cls (a sketch follows after this comment).
2. Change ResNet to be a custom Module (instead of Sequential) so you also accept this argument in __call__ and pass it around to the relevant submodules that expect it.
Some repos use a single train flag to determine the state of both BatchNorm and Dropout.
Anyway, this is not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.
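A minimal sketch of the first option, assuming Flax's nn.BatchNorm; this block is illustrative, not the repo's actual ConvBlock:

    from typing import Callable, Sequence

    import flax.linen as nn

    class ConvBlock(nn.Module):
        # Illustrative conv + norm + activation block where the caller controls
        # BatchNorm explicitly instead of relying on batch_stats mutability.
        n_filters: int
        kernel_size: Sequence[int] = (3, 3)
        strides: Sequence[int] = (1, 1)
        activation: Callable = nn.relu

        @nn.compact
        def __call__(self, x, use_running_average: bool = True):
            x = nn.Conv(self.n_filters, self.kernel_size, self.strides,
                        use_bias=False)(x)
            # use_running_average=True freezes the statistics (inference or a
            # frozen backbone); False updates them (training).
            x = nn.BatchNorm(momentum=0.9,
                             use_running_average=use_running_average)(x)
            return self.activation(x)

An embedded, frozen ResNet could then be called with use_running_average=True while the outer module's own BatchNorm layers keep training.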