
3D-Attention-in-tf2: Additive, Position, and Channel Attention

In recent years, many classification and segmentation networks have added attention mechanisms to improve performance. This repo contains 3D implementations of attention mechanisms commonly used in imaging.

Additive Attention Gate (AG)-3D

Oktay, Ozan, et al. "Attention u-net: Learning where to look for the pancreas." arXiv preprint arXiv:1804.03999 (2018).


Code

import tensorflow as tf

def Attention_mechanism(X, G, out_filters, kernel_size=1, strides=(1, 1, 1), use_bias=False,
                        kernel_initializer=tf.keras.initializers.VarianceScaling(distribution='uniform'),
                        bias_initializer=tf.zeros_initializer(),
                        kernel_regularizer=tf.keras.regularizers.l2(l=0.001),
                        bias_regularizer=None,
                        **kwargs):

    conv_params = {'padding': 'same',
                   'use_bias': use_bias,
                   'kernel_initializer': kernel_initializer,
                   'bias_initializer': bias_initializer,
                   'kernel_regularizer': kernel_regularizer,
                   'bias_regularizer': bias_regularizer}

    # Skip-connection features that will be re-weighted by the attention coefficients.
    Original_x = G

    # Project both inputs into a common intermediate space with 1x1x1 convolutions.
    X1 = tf.keras.layers.Conv3D(filters=out_filters, kernel_size=kernel_size, strides=strides, **conv_params)(X)
    X1 = tf.keras.layers.BatchNormalization()(X1)
    G1 = tf.keras.layers.Conv3D(filters=out_filters, kernel_size=kernel_size, strides=strides, **conv_params)(G)
    G1 = tf.keras.layers.BatchNormalization()(G1)

    # Additive combination followed by a ReLU6 non-linearity.
    X1_G1 = X1 + G1
    A1 = tf.nn.relu6(X1_G1)

    # A second 1x1x1 convolution and a sigmoid yield the attention coefficients.
    A1 = tf.keras.layers.Conv3D(filters=out_filters, kernel_size=1, strides=1, **conv_params)(A1)
    A1 = tf.keras.activations.sigmoid(A1)

    # Element-wise re-weighting of the skip-connection features.
    final_attention = tf.math.multiply(Original_x, A1)

    return final_attention
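
A quick way to sanity-check the gate is to call it on random tensors. The shapes below are hypothetical placeholders, chosen so that out_filters matches the channel count of the skip connection being re-weighted:

# Hypothetical smoke test: 5D tensors with layout [batch, H, W, D, C].
X = tf.random.normal([1, 32, 32, 32, 32])  # gating signal (assumed shape)
G = tf.random.normal([1, 32, 32, 32, 32])  # skip-connection features (assumed shape)

gated = Attention_mechanism(X, G, out_filters=32)
print(gated.shape)  # expected: (1, 32, 32, 32, 32)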

Position Attention-3D

Fu, Jun, et al. "Dual attention network for scene segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.


Code

def Position_attention(position_attention_input):

    # Static input shape: [batch, H, W, D, C].
    in_shp = position_attention_input.get_shape().as_list()

    # Query/key branch: reduce the channels to C/8 with a 1x1x1 convolution.
    C1 = tf.keras.layers.Conv3D(filters=int(in_shp[4] / 8), kernel_size=1, strides=(1, 1, 1))(position_attention_input)
    C1_shp = C1.get_shape().as_list()

    # First branch: flatten the spatial dimensions -> [batch, H*W*D, C/8].
    F1_HWDxC = tf.reshape(C1, [-1, C1_shp[1] * C1_shp[2] * C1_shp[3], C1_shp[4]])

    # Second branch: transpose and multiply to obtain the [batch, HWD, HWD]
    # spatial affinity matrix, normalised with a softmax.
    F2_CxHWD = tf.transpose(F1_HWDxC, perm=[0, 2, 1])
    attention_map = tf.matmul(F1_HWDxC, F2_CxHWD)
    attention_map = tf.keras.activations.softmax(attention_map)

    # Third branch (value): keep the full channel count, flatten, and
    # aggregate the features according to the spatial attention map.
    C2 = tf.keras.layers.Conv3D(filters=in_shp[4], kernel_size=1, strides=(1, 1, 1))(position_attention_input)
    F3_HWDxC = tf.reshape(C2, [-1, in_shp[1] * in_shp[2] * in_shp[3], in_shp[4]])
    F3xF2 = tf.matmul(attention_map, F3_HWDxC)
    F3 = tf.reshape(F3xF2, [-1, in_shp[1], in_shp[2], in_shp[3], in_shp[4]])

    # Fuse the attended features with the input and project back to C channels.
    position_attention_output = tf.keras.layers.Multiply()([position_attention_input, F3])
    position_attention_output = tf.keras.layers.Conv3D(filters=in_shp[4], kernel_size=1, strides=(1, 1, 1))(position_attention_output)

    return position_attention_output
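
Note that the HWDxHWD affinity matrix grows quadratically with the number of voxels, so this module is best applied to small feature maps deep in the network. A minimal smoke test, with a hypothetical input shape:

# Hypothetical [batch, H, W, D, C] feature map; C is divisible by 8.
x = tf.random.normal([1, 8, 8, 8, 16])
y = Position_attention(x)
print(y.shape)  # expected: (1, 8, 8, 8, 16)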

Channel Attention-3D

Fu, Jun, et al. "Dual attention network for scene segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.


Code

def Channel_attention(Channel_attention_input):

    # Static input shape: [batch, H, W, D, C].
    in_shp = Channel_attention_input.get_shape().as_list()

    # Flatten the spatial dimensions -> [batch, H*W*D, C].
    channel_C1 = tf.reshape(Channel_attention_input, [-1, in_shp[1] * in_shp[2] * in_shp[3], in_shp[4]])

    # Channel affinity: [batch, C, HWD] x [batch, HWD, C] -> [batch, C, C],
    # normalised with a softmax over each row.
    channel_C2 = tf.transpose(channel_C1, perm=[0, 2, 1])
    attention_map = tf.matmul(channel_C2, channel_C1)
    attention_map = tf.keras.activations.softmax(attention_map)

    # Re-weight the channels with the affinity matrix:
    # [batch, HWD, C] x [batch, C, C] -> [batch, HWD, C].
    channel_C3 = tf.matmul(channel_C1, attention_map, transpose_b=True)
    channel_C3 = tf.reshape(channel_C3, [-1, in_shp[1], in_shp[2], in_shp[3], in_shp[4]])

    # Fuse with the input and project back to C channels.
    Channel_attention_output = tf.keras.layers.Multiply()([channel_C3, Channel_attention_input])
    Channel_attention_output = tf.keras.layers.Conv3D(filters=in_shp[4], kernel_size=1, strides=(1, 1, 1))(Channel_attention_output)

    return Channel_attention_output
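
In the dual attention network paper, the position and channel modules run in parallel on the same feature map and their outputs are fused by element-wise summation. A minimal sketch of that fusion, again with a hypothetical input shape:

# Hypothetical [batch, H, W, D, C] feature map shared by both modules.
x = tf.random.normal([1, 8, 8, 8, 16])
fused = Position_attention(x) + Channel_attention(x)  # DANet-style element-wise sum
print(fused.shape)  # expected: (1, 8, 8, 8, 16)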
