
jax-resnet's People

Contributors

n2cholas, sauravmaheshkar


jax-resnet's Issues

WideResNet-28-10

Hi there, this repo is very useful, but could you add WideResNet-28-10? It is highlighted in the official repo and is the best-performing model on CIFAR-10 in RobustBench, so I believe many people will find it helpful.

Initialization is incorrect

Currently, all layers use the default Flax initialization. However, each paper uses a different strategy:

  1. ResNet, WideResNet, and ResNeXt use Kaiming Normal.
  2. ResNet-D uses Xavier Uniform.
  3. The ResNeSt paper says it uses Kaiming Normal, but the code uses the PyTorch default, which is Kaiming Uniform with a=sqrt(5).

There are a few options going forward:

  1. Set all the models to use Kaiming [Normal or Uniform], which has been shown to work best with ReLU activations. With this choice, we would probably deviate from the PyTorch default gain (which is for LeakyReLU) to a gain suitable for vanilla ReLU (see the sketch after this list).
  2. Set all the models to the initialization prescribed in their respective papers.
  3. Provide no default, forcing users to select one, but suggest suitable candidates in the docstring.
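
For reference, here is a minimal sketch of how a Kaiming Normal kernel initializer could be passed to a Flax conv layer; the exact scale, mode, and distribution would of course depend on which option is picked:

import flax.linen as nn
from jax.nn.initializers import variance_scaling

# Kaiming/He Normal with the ReLU gain: scale=2.0, fan_in, normal distribution.
# Swap "normal" for "uniform" to get Kaiming Uniform instead.
kaiming_normal = variance_scaling(2.0, "fan_in", "normal")

conv = nn.Conv(features=64, kernel_size=(3, 3), kernel_init=kaiming_normal)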

Transfer Learning API

Hey @n2cholas!

I was wondering how to properly do transfer learning. Maybe this feature is not implemented yet, but is it possible to select the second-to-last layer? More generally, can you select other intermediate layers?

I want to make an example of doing Transfer Learning in Elegy and these pre-trained models look perfect for the task.
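
For what it's worth, since ResNet returns a plain Sequential over a list of layers, one rough way to get the features from the second-to-last layer is to slice that list and filter the pretrained variables to match. A sketch, assuming Sequential is importable from jax_resnet and that Flax's default naming stores each layer's variables under keys like 'layers_0', 'layers_1', ...:

import jax.numpy as jnp
from jax_resnet import pretrained_resnet, Sequential  # assumes Sequential is exported

ResNet50, variables = pretrained_resnet(50)
model = ResNet50()

# Drop the final Dense head; model.layers is just a list, so slicing it gives a backbone.
n_kept = len(model.layers) - 1
backbone = Sequential(model.layers[:n_kept])

# Keep only the variables that belong to the retained layers (key format assumed;
# check the actual keys in `variables`).
backbone_variables = {
    col: {k: v for k, v in entries.items() if int(k.split('_')[-1]) < n_kept}
    for col, entries in variables.items()
}

features = backbone.apply(backbone_variables, jnp.ones((1, 224, 224, 3)), mutable=False)

A built-in helper that slices both the layer list and the variables in one call would still be a nicer API, of course.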

Structure Difference between PyTorch ResNet and JAX resnet (at layer 4)

Hello Nicholas. I am using the pretrained ResNet-101 and comparing against PyTorch's ResNet output after layer 4 (i.e. the output just before the average pooling). For an input batch of shape [1, 224, 224, 3], the PyTorch output was torch.Size([1, 2048, 28, 28]). However, when I take the equivalent output from your JAX/Flax ResNet (I commented out the two lines marked below in the ResNet function to get the output before the average pooling, the layer4-equivalent point in PyTorch):

from functools import partial
from typing import Callable, Optional, Sequence

import flax.linen as nn
import jax.numpy as jnp

# ConvBlock, ModuleDef, ResNetStem, and Sequential are defined elsewhere in jax_resnet.


def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)

    layers = [stem_cls(), pool_fn]

    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides))
    # --------------------------------------------------------------------------
    # layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    # layers.append(nn.Dense(n_classes))
    # --------------------------------------------------------------------------
    return Sequential(layers)

I get a different output shape (for the same input batch shape (1, 224, 224, 3)):

import jax
import jax.numpy as jnp
from jax_resnet import pretrained_resnet

ResNet101, variables = pretrained_resnet(101)
model = ResNet101()
model_out = model.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
print("pretrained resnet101 output shape:", jax.tree_map(lambda x: x.shape, model_out))

pretrained resnet101 output shape: (1, 7, 7, 2048)
So what has happened at this stage in the ResNet layer structure?
Kindly reply if you have any explanation or recommendations.
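
One way to pin down where the two models diverge is to dump the per-layer output shapes on the JAX side. A minimal sketch using Flax's capture_intermediates, reusing the model and variables built above:

import jax
import jax.numpy as jnp

# Run the model once and collect every submodule's output in the 'intermediates' collection.
out, state = model.apply(variables, jnp.ones((1, 224, 224, 3)),
                         capture_intermediates=True, mutable=['intermediates'])
print(jax.tree_map(lambda x: x.shape, state['intermediates']))

For what it's worth, in NHWC layout the spatial resolution should halve at the stem, at the max pool, and at each of the three stage transitions with stride (2, 2), so a 224x224 input reaches 7x7 just before the average pool.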

Training state of ResNet coupled with mutable batch_stats collection

Hey @n2cholas!

This is not an immediate issue, but I was playing around with jax_resnet and noticed that ConvBlock decides whether to update its batch statistics depending on whether the batch_stats collection is mutable. This initially sounds like a safe bet, but if you embed ResNet inside another module that also uses BatchNorm, and you want to train the other module but freeze ResNet, it is not clear how you would do this.

mutable = self.is_mutable_collection('batch_stats')
x = self.norm_cls(use_running_average=not mutable, scale_init=scale_init)(x)

To solve this you have to:

  • Accept a use_running_average (or equivalent) argument in ConvBlock.__call__ and pass it to norm_cls.
  • Refactor ResNet to be a custom Module (instead of Sequential) so that it also accepts this argument in __call__ and passes it to the relevant submodules that expect it.

Some repos use a single train flag to determine the state of both BatchNorm and Dropout.
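
A rough sketch of the first bullet above; the names and defaults here are illustrative, not jax_resnet's actual signature:

from functools import partial
from typing import Any, Optional

import flax.linen as nn


class ConvBlock(nn.Module):  # illustrative stand-in, not the library's exact ConvBlock
    n_filters: int
    norm_cls: Any = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        x = nn.Conv(self.n_filters, (3, 3))(x)
        if use_running_average is None:
            # Fall back to the current behaviour: update stats iff batch_stats is mutable.
            use_running_average = not self.is_mutable_collection('batch_stats')
        x = self.norm_cls(use_running_average=use_running_average)(x)
        return nn.relu(x)

An explicit argument like this (or a single train flag) would let an outer module keep its own BatchNorm in training mode while holding the embedded ResNet's statistics fixed.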

Anyway, not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.
