
Residual Attention Networks

**The results thus far are preliminary. I am still in the process of rigorously validating the effectiveness of this deep learning model.**

I came across this network while studying attention mechanisms and found the architecture really intriguing. After reading the paper "Residual Attention Network for Image Classification" by Fei Wang et al., I set myself the task of implementing it to understand the network in more depth.

To give a brief summary of the paper: a Residual Attention Network (RAN) is a convolutional neural network that incorporates both an attention mechanism and residual units, components that use skip connections to jump over two to three layers of convolutions with nonlinearities (e.g. ReLU) and batch normalization. Its prime feature is the attention module.
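A residual unit can be sketched in a few lines of Keras. This is a minimal, illustrative version (the pre-activation variant from "Identity Mappings in Deep Residual Networks", referenced below), not the exact unit used in my implementation:

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_unit(x, filters):
    """Pre-activation residual unit: BN -> ReLU -> Conv, twice,
    with a skip connection that jumps over both convolutions."""
    shortcut = x
    out = layers.BatchNormalization()(x)
    out = layers.ReLU()(out)
    out = layers.Conv2D(filters, 3, padding="same")(out)
    out = layers.BatchNormalization()(out)
    out = layers.ReLU()(out)
    out = layers.Conv2D(filters, 3, padding="same")(out)
    # Project the shortcut with a 1x1 conv if the channel counts differ.
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, padding="same")(shortcut)
    return layers.Add()([out, shortcut])
```

Because the skip connection carries the input forward unchanged, each unit only has to learn a residual correction, which is what makes very deep stacks trainable.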

Figure: Full Architecture

Attention Modules

The RAN is built by stacking Attention Modules, which generate attention-aware features that adaptively change as layers move deeper into the network.

Figure: Attention Module

Each Attention Module comprises two branches: a trunk branch and a mask branch.

  • Trunk Branch performs feature processing with Residual Units

  • Mask Branch uses a bottom-up, top-down structure to softly weight the output features, with the goal of improving the trunk branch's features

    • Bottom-Up Step: collects global information from the whole image by downsampling it (e.g. with max pooling)
    • Top-Down Step: combines that global information with the original feature maps by upsampling (e.g. with interpolation), so the output size matches the input feature map
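The two steps above can be sketched as a simplified soft mask branch in Keras. This is a hypothetical one-level version for illustration; the paper nests several downsampling/upsampling levels with skip connections between them:

```python
import tensorflow as tf
from tensorflow.keras import layers

def mask_branch(x, filters):
    """Simplified soft mask branch: one bottom-up (downsampling) step,
    one top-down (upsampling) step, and a sigmoid so mask values
    lie in [0, 1]."""
    # Bottom-up: max pooling gathers global context at lower resolution.
    out = layers.MaxPooling2D(pool_size=2)(x)
    out = layers.Conv2D(filters, 3, padding="same", activation="relu")(out)
    # Top-down: bilinear interpolation restores the input resolution.
    out = layers.UpSampling2D(size=2, interpolation="bilinear")(out)
    out = layers.Conv2D(filters, 1, padding="same")(out)
    # Sigmoid squashes the mask into soft per-pixel, per-channel weights.
    return layers.Activation("sigmoid")(out)
```

The sigmoid at the end is what makes the weighting "soft": every spatial position and channel gets a continuous weight between 0 and 1 rather than a hard on/off gate.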

Once both branches have run, their outputs are combined using the team's attention residual learning formula, H(x) = (1 + M(x)) · T(x), where M(x) is the mask output and T(x) is the trunk output. Because a near-zero mask leaves the trunk features unchanged rather than suppressing them, very deep Residual Attention Networks can be scaled to hundreds of layers without a drop in performance. Increasing the number of Attention Modules therefore leads to consistent performance improvement, as different types of attention are captured at different depths.
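The combination rule is simple enough to demonstrate directly in NumPy; the key property is that an all-zero mask acts as an identity mapping, so stacking modules cannot degrade the trunk features:

```python
import numpy as np

def attention_residual(trunk, mask):
    """Attention residual learning: H(x) = (1 + M(x)) * T(x).
    With mask values near 0 the trunk features pass through unchanged;
    positive mask values amplify the features they attend to."""
    return (1.0 + mask) * trunk

# An all-zero mask leaves the trunk untouched (identity mapping);
# a mask of 1.0 doubles the attended features.
trunk = np.array([1.0, 2.0, 3.0])
```

Compare this with naive attention, H(x) = M(x) · T(x), where repeated multiplication by values in [0, 1] shrinks the feature magnitudes layer after layer; the added identity term is what the paper credits for enabling hundreds of layers.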

Implementation

I've certainly glossed over a lot of details from the paper, so I'd definitely recommend reading it.

Once I gained a solid understanding of the RAN, I used Keras to put it into action. For my use case, I trained the model to classify cats vs. dogs image data. You can find the results in the notebook here.

In short, the network consistently improved despite being very deep. This cats vs. dogs classification was only a small example, however. I plan on using this network in further studies to examine its power and report back.

Acknowledgements

Paper: "Residual Attention Network for Image Classification" https://arxiv.org/pdf/1704.06904.pdf
Authors: Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Cheng Li, Honggang Zhang, Xiaogang Wang, Xiaoou Tang
GitHub: https://github.com/fwang91/residual-attention-network

Paper: "Deep Residual Learning for Image Recognition" https://arxiv.org/pdf/1512.03385.pdf
Paper: "Identity Mappings in Deep Residual Networks" https://arxiv.org/pdf/1603.05027.pdf
Authors: Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun

GitHub: https://github.com/qubvel/residual_attention_network

  • Though my final implementation is ultimately different, I gained a more solid understanding of the network thanks to the GitHub user qubvel. Thanks!
