
Implementing Stand-Alone Self-Attention in Vision Models using PyTorch (13 Jun 2019)

  • Stand-Alone Self-Attention in Vision Models paper
  • Authors:
    • Prajit Ramachandran (Google Research, Brain Team)
    • Niki Parmar (Google Research, Brain Team)
    • Ashish Vaswani (Google Research, Brain Team)
    • Irwan Bello (Google Research, Brain Team)
    • Anselm Levskaya (Google Research, Brain Team)
    • Jonathon Shlens (Google Research, Brain Team)
  • Awesome :)

Method

  • Attention Layer

    • Equation 1:

      y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \mathrm{softmax}_{ab}\left(q_{ij}^{\top} k_{ab}\right) v_{ab}

  • Relative Position Embedding

    • The row offset a − i and column offset b − j are associated with embeddings r_{a-i} and r_{b-j} respectively, each with dimension \frac{1}{2} d_{out}. The row and column offset embeddings are concatenated to form r_{a-i,\,b-j}. Spatial-relative attention is then defined by the equation below.

    • Equation 2:

      y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \mathrm{softmax}_{ab}\left(q_{ij}^{\top} k_{ab} + q_{ij}^{\top} r_{a-i,\,b-j}\right) v_{ab}

    • I referred to the following paper when implementing this part. A minimal sketch of the two equations above follows.
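
A minimal, single-head sketch of Equations 1 and 2, as an illustrative reimplementation. This is not the repository's attention.py (which is grouped and multi-headed); the module name, the unfold-based windowing, and all shapes here are my own assumptions.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LocalSelfAttention2d(nn.Module):
        """Single-head local self-attention over a k x k neighborhood (Equation 1)
        with concatenated row/column relative embeddings (Equation 2)."""

        def __init__(self, channels, kernel_size=7):
            super().__init__()
            assert channels % 2 == 0, "row/column embeddings each take half the channels"
            self.k = kernel_size
            self.query = nn.Conv2d(channels, channels, 1, bias=False)
            self.key = nn.Conv2d(channels, channels, 1, bias=False)
            self.value = nn.Conv2d(channels, channels, 1, bias=False)
            # r_{a-i} and r_{b-j}, each with dimension channels / 2
            self.rel_h = nn.Parameter(torch.randn(channels // 2, kernel_size, 1))
            self.rel_w = nn.Parameter(torch.randn(channels // 2, 1, kernel_size))

        def forward(self, x):
            b, c, h, w = x.shape
            pad = self.k // 2
            q = self.query(x)                                  # [b, c, h, w]
            k = self.key(F.pad(x, [pad] * 4))                  # pad so every pixel gets a full window
            v = self.value(F.pad(x, [pad] * 4))
            # extract the k x k neighborhood around every position: [b, c, h, w, k, k]
            k = k.unfold(2, self.k, 1).unfold(3, self.k, 1)
            v = v.unfold(2, self.k, 1).unfold(3, self.k, 1)
            # concatenate row and column offset embeddings into r_{a-i, b-j}: [c, k, k]
            rel = torch.cat([self.rel_h.expand(-1, self.k, self.k),
                             self.rel_w.expand(-1, self.k, self.k)], dim=0)
            # logits q^T (k + r) for each neighborhood entry, summed over channels
            logits = torch.einsum('bchw,bchwij->bhwij',
                                  q, k + rel.view(1, c, 1, 1, self.k, self.k))
            attn = F.softmax(logits.reshape(b, h, w, -1), dim=-1)  # softmax over the window
            # weighted sum of the values over the neighborhood (Equation 1)
            return torch.einsum('bhwn,bchwn->bchw', attn, v.reshape(b, c, h, w, -1))

For example, LocalSelfAttention2d(64)(torch.randn(2, 64, 32, 32)) returns a [2, 64, 32, 32] tensor, so the layer can stand in for a stride-1 3 × 3 convolution.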

  1. Replacing Spatial Convolutions
    - This work applies the transform to the ResNet family of architectures: the proposed transform swaps each 3 × 3 spatial convolution for a self-attention layer as defined in Equation 2.
    - A 2 × 2 average pooling operation with stride 2 follows the attention layer whenever spatial downsampling is required.
  2. Replacing the Convolutional Stem
    - The initial layers of a CNN, sometimes referred to as the stem, play a critical role in learning local features such as edges, which later layers use to identify global objects.
    - The stem performs self-attention within each 4 × 4 spatial block of the original image, followed by batch normalization and a 4 × 4 max pool operation.
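
As a hedged sketch of the first replacement, reusing the LocalSelfAttention2d module above (spatial_mixing is a hypothetical helper, not code from this repository; the attention stem, which attends within 4 × 4 pixel blocks before batch normalization and max pooling, is more involved and not shown):

    import torch.nn as nn

    def spatial_mixing(channels, downsample=False):
        # stands in for nn.Conv2d(channels, channels, 3, padding=1): the 3 x 3
        # spatial convolution is swapped for the self-attention layer of Equation 2
        layers = [LocalSelfAttention2d(channels, kernel_size=7)]
        if downsample:
            # 2 x 2 average pooling with stride 2 follows whenever downsampling is required
            layers.append(nn.AvgPool2d(2, stride=2))
        return nn.Sequential(*layers)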

Experiments

Setup

  • Spatial extent: 7
  • Attention heads: 8
  • Layers:
    • ResNet 26: [1, 2, 4, 1]
    • ResNet 38: [2, 3, 5, 2]
    • ResNet 50: [3, 4, 6, 3]

| Dataset | Model | Accuracy | Parameters (my model, paper model) |
| --- | --- | --- | --- |
| CIFAR-10 | ResNet 26 | 90.94% | 8.30M, - |
| CIFAR-10 | Naive ResNet 26 | 94.29% | 8.74M |
| CIFAR-10 | ResNet 26 + stem | 90.22% | 8.30M, - |
| CIFAR-10 | ResNet 38 (work in progress) | 89.46% | 12.1M, - |
| CIFAR-10 | Naive ResNet 38 | 94.93% | 15.0M |
| CIFAR-10 | ResNet 50 (work in progress) | - | 16.0M, - |
| ImageNet | ResNet 26 (work in progress) | - | 10.3M, 10.3M |
| ImageNet | ResNet 38 (work in progress) | - | 14.1M, 14.1M |
| ImageNet | ResNet 50 (work in progress) | - | 18.0M, 18.0M |

Usage

Requirements

  • torch==1.0.1

Todo

  • Experiments on ImageNet
  • Review the relative position embedding and attention stem
  • Code refactoring

Reference

  • Stand-Alone Self-Attention in Vision Models, Ramachandran et al., 2019


Issues

Question about einsum.

Hello. I was going through the attention implementation and had a question about the operation performed by the einsum function.
I am not very familiar with the Einstein summation convention, and I believe there are many others like me.
I would be very grateful if explanations of the operation were added as comments.
As it stands, I find it difficult to tell which parts correspond to what in the paper.
Many thanks if you could help me out.
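
Not an authoritative answer, but here is one reading of that einsum as a self-contained illustration (the tensor sizes are arbitrary assumptions):

    import torch

    # 'bnchwk,bnchwk -> bnchw' keeps b, n, c, h, w and sums over k: for every
    # (batch, group, channel, position) it takes the k*k attention weights over
    # the local window and forms the weighted sum of the k*k value entries --
    # the summation over the neighborhood N_k(i, j) in Equation 1
    weights = torch.softmax(torch.randn(2, 4, 8, 5, 5, 9), dim=-1)  # 3x3 window -> k = 9
    values = torch.randn(2, 4, 8, 5, 5, 9)
    out = torch.einsum('bnchwk,bnchwk -> bnchw', weights, values)
    print(out.shape)  # torch.Size([2, 4, 8, 5, 5])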

v_out = torch.cat((v_out_h + self.rel_h, v_out_w + self.rel_w), dim=1)

Hi, is the following code wrong?

    v_out_h, v_out_w = v_out.split(self.out_channels // 2, dim=1)
    v_out = torch.cat((v_out_h + self.rel_h, v_out_w + self.rel_w), dim=1)

Shouldn't it be something like the following?

    k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
    k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)

The relative distance embedding is supposed to be added to the keys rather than the values, right?

problem with unfold

After

    k_out = k_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)

it gives the error

    RuntimeError: shape '[2, 1, 16, 34, 34, -1]' is invalid for input of size 294912

This is because of the use of unfold on k_out before this line, so height and width are no longer consistent.

How can you v_out_h + self.rel_h

Hello, v_out_h's shape is [n, c//2, kernel_size_h, kernel_size_w], but self.rel_h's shape is [out_channels // 2, 1, 1, kernel_size, 1]. How can you add two tensors with different shapes? Hoping for your early reply!
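
The addition works through broadcasting, which right-aligns dimensions. A self-contained check, taking rel_h's shape from the issue and assuming v_out_h is six-dimensional after the two unfold calls (batch, H, and W sizes are arbitrary):

    import torch

    batch, C, H, W, k = 2, 16, 8, 8, 7
    v_out_h = torch.randn(batch, C // 2, H, W, k, k)   # after unfold and split
    rel_h = torch.randn(C // 2, 1, 1, k, 1)            # [out_channels // 2, 1, 1, kernel_size, 1]
    # rel_h's dims line up with (C//2, H, W, k, k); the size-1 axes broadcast
    print((v_out_h + rel_h).shape)  # torch.Size([2, 8, 8, 8, 7, 7])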

Can anyone train resnet50 successfully without NaN

Hi, I am facing issues with training the ResNet 50 model on CIFAR-10. Even with a learning rate of 0.01, it suddenly throws NaN after around 10 epochs, so I am not quite sure how to train the ResNet 50 model. Hoping for a quick reply! Thanks.

2d embedding

Hi,

I'm confused about the embedding in the stem attention:

    emb = emb_logit_a + emb_logit_b

It seems that the column and row embeddings are the same, so features are aggregated across columns and rows with the same softmax values. Isn't emb supposed to be two-dimensional here? Like this:

        emb = emb_logit_a.unsqueeze(2) + emb_logit_b.unsqueeze(1) # [m, ks, ks] p(m, a, b)
        emb = F.softmax(emb.view(self.m, -1), dim=-1)
        emb = emb.view(self.m, 1, 1, 1, 1, self.kernel_size, self.kernel_size)
        v_out = v_out * emb

Loss is NaN

Hello,
I am testing your ResNet 50 model with stem set to True, and at the first training step my loss is NaN while the accuracy keeps decreasing. Is that a bug?

Also, I didn't see this problem when training the ResNet 26 model.

Has anyone tried changing the batch size

I am having issues trying to change the batch size; the code produces a mismatch error each time I change it. Below is the error I am getting:

    RuntimeError: size mismatch, m1: [4 x 8192], m2: [2048 x 5] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:283

Error loading pretrained model

When I try to load the pretrained model, I keep getting this error:

RuntimeError: Error(s) in loading state_dict for Model:
Missing key(s) in state_dict: "layer1.0.conv2.0.rel_h", "layer1.0.conv2.0.rel_w", "layer1.0.conv2.0.key_conv.weight", "layer1.0.conv2.0.query_conv.weight", "layer1.0.conv2.0.value_conv.weight", "layer2.0.conv2.0.rel_h", "layer2.0.conv2.0.rel_w", "layer2.0.conv2.0.key_conv.weight", "layer2.0.conv2.0.query_conv.weight", "layer2.0.conv2.0.value_conv.weight", "layer2.1.conv2.0.rel_h", "layer2.1.conv2.0.rel_w", "layer2.1.conv2.0.key_conv.weight", "layer2.1.conv2.0.query_conv.weight", "layer2.1.conv2.0.value_conv.weight", "layer3.0.conv2.0.rel_h", "layer3.0.conv2.0.rel_w", "layer3.0.conv2.0.key_conv.weight", "layer3.0.conv2.0.query_conv.weight", "layer3.0.conv2.0.value_conv.weight", "layer3.1.conv2.0.rel_h", "layer3.1.conv2.0.rel_w", "layer3.1.conv2.0.key_conv.weight", "layer3.1.conv2.0.query_conv.weight", "layer3.1.conv2.0.value_conv.weight", "layer3.2.conv2.0.rel_h", "layer3.2.conv2.0.rel_w", "layer3.2.conv2.0.key_conv.weight", "layer3.2.conv2.0.query_conv.weight", "layer3.2.conv2.0.value_conv.weight", "layer3.3.conv2.0.rel_h", "layer3.3.conv2.0.rel_w", "layer3.3.conv2.0.key_conv.weight", "layer3.3.conv2.0.query_conv.weight", "layer3.3.conv2.0.value_conv.weight", "layer4.0.conv2.0.rel_h", "layer4.0.conv2.0.rel_w", "layer4.0.conv2.0.key_conv.weight", "layer4.0.conv2.0.query_conv.weight", "layer4.0.conv2.0.value_conv.weight".
Unexpected key(s) in state_dict: "layer1.0.conv2.0.weight", "layer1.0.conv2.0.bias", "layer2.0.conv2.0.weight", "layer2.0.conv2.0.bias", "layer2.1.conv2.0.weight", "layer2.1.conv2.0.bias", "layer3.0.conv2.0.weight", "layer3.0.conv2.0.bias", "layer3.1.conv2.0.weight", "layer3.1.conv2.0.bias", "layer3.2.conv2.0.weight", "layer3.2.conv2.0.bias", "layer3.3.conv2.0.weight", "layer3.3.conv2.0.bias", "layer4.0.conv2.0.weight", "layer4.0.conv2.0.bias".

My code for loading the pretrained model is below:

    file_path = 'path to file_ckpt.tar'
    checkpoint = torch.load(file_path)
    
    model.load_state_dict(checkpoint['state_dict'])
    model = nn.DataParallel(model, dim=0)
    model.cuda()
    start_epoch = checkpoint['epoch']
    best_acc = checkpoint['best_acc']

    model_parameters = checkpoint['parameters']

Excessive Memory Usage

ResNet 26 on CIFAR-10 (8.3M parameters), using your implementation without any changes, fails to fit in memory on HAL and even on Colab (P100, T4, and P4). What was your system configuration for running the benchmarks?

Large memory consumption

Hi, thanks for your nice work.
On a large dataset like ImageNet, the proposed self-attention mechanism consumes large amounts of memory because of the unfold operation (im2col). One would want to share the point-wise feature vectors among the sliding windows. Do you have any ideas?
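
For intuition, a back-of-the-envelope estimate of a single unfolded tensor (the layer sizes here are assumptions for illustration, not measurements of this repository):

    # one unfolded key/value tensor holds batch * channels * H * W * k * k floats,
    # i.e. every feature is duplicated k*k times by im2col
    batch, channels, H, W, k = 64, 64, 56, 56, 7
    elems = batch * channels * H * W * k * k
    print(f"{elems * 4 / 2**30:.2f} GiB")  # ~2.34 GiB in float32, per tensor, per layer

Both keys and values are unfolded at every attention layer, which is why the im2col-style expansion dominates memory.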

how about replacing einsum with normal multiplication

In attention.py, class AttentionConv, replacing

    out = torch.einsum('bnchwk,bnchwk -> bnchw', out, v_out)

with

    out = (out * v_out).sum(dim=5)

made the running time more than 2x faster while training on ImageNet (2 min vs. 53 s per 100 steps, batch size 25), which is still 3.5x slower than training a normal ResNet on ImageNet.

(I'm not sure whether this model works for ImageNet or not.)
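
A rough way to check the equivalence and timing of the two forms on your own machine (tensor sizes are arbitrary; this times CPU only, and CUDA timing would need torch.cuda.synchronize()):

    import time
    import torch

    a = torch.randn(8, 8, 8, 28, 28, 49)  # [batch, groups, channels, H, W, window]
    b = torch.randn_like(a)

    t0 = time.time()
    out1 = torch.einsum('bnchwk,bnchwk -> bnchw', a, b)
    t1 = time.time()
    out2 = (a * b).sum(dim=5)
    t2 = time.time()

    print(torch.allclose(out1, out2))  # True, up to floating-point tolerance
    print(f"einsum {t1 - t0:.3f}s, mul+sum {t2 - t1:.3f}s")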

Train with IMAGENET

Dear @leaderj1001

Have you succeeded in training with the ImageNet dataset? I get this error when using ImageNet:

    RuntimeError: CUDA out of memory. Tried to allocate 3.66 GiB (GPU 0; 10.91 GiB total capacity; 8.04 GiB already allocated; 1.66 GiB free; 8.05 GiB reserved in total by PyTorch)

Thank you.

Problems about groups

Should the key_conv, query_conv, and value_conv not be defined with a groups parameter? I am not clear on what the groups parameter does in the following lines.

    k_out = k_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)
    v_out = v_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)
    q_out = q_out.view(batch, self.groups, self.out_channels // self.groups, height, width, 1)
    out = q_out * k_out
    out = F.softmax(out, dim=-1)
    out = torch.einsum('bnchwk,bnchwk -> bnchw', out, v_out).view(batch, -1, height, width)
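
For reference, a self-contained shape walkthrough of that view (the sizes are arbitrary assumptions). As the question implies, the 1 × 1 query/key/value convolutions themselves are not grouped; the view only partitions the channel axis into groups for the softmax and einsum that follow:

    import torch

    batch, groups, out_channels, height, width, window = 2, 8, 64, 16, 16, 49
    k_out = torch.randn(batch, out_channels, height, width, window)  # after unfolding
    k_out = k_out.contiguous().view(batch, groups, out_channels // groups, height, width, -1)
    print(k_out.shape)  # torch.Size([2, 8, 8, 16, 16, 49])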

A question about relative position embeddings

Hello, I want to ask you a question: in the relative position embeddings, why is the number of channels divided by 2 (out_channels // 2)?

    self.rel_h = nn.Parameter(torch.randn(out_channels // 2, 1, 1, kernel_size, 1), requires_grad=True)
    self.rel_w = nn.Parameter(torch.randn(out_channels // 2, 1, 1, 1, kernel_size), requires_grad=True)
