
Keras Attention Mechanism


Simple attention mechanism implemented in Keras for the following layers:

  • Dense (attention 2D block)
  • LSTM, GRU (attention 3D block)


Example: Attention block

Dense Layer

from keras.layers import Input, Dense, merge  # 'merge' is the legacy Keras 1.x helper (Keras 2: keras.layers.multiply)

inputs = Input(shape=(input_dims,))
# One attention weight per input feature, normalised with a softmax.
attention_probs = Dense(input_dims, activation='softmax', name='attention_probs')(inputs)
# Element-wise product of the inputs with their attention weights.
attention_mul = merge([inputs, attention_probs], output_shape=input_dims, name='attention_mul', mode='mul')

Let's consider this Hello World example:

  • A vector v of 32 values as input to the model (simple feedforward neural network).
  • v[1] = target.
  • Target is binary (either 0 or 1).
  • All the other values of the vector v (v[0] and v[2:32]) are purely random and do not contribute to the target.
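To make this setup concrete, here is a minimal data-generation sketch (the helper name get_toy_data and the sample shapes are illustrative choices, not the repository's API):

import numpy as np

def get_toy_data(n, input_dims=32):
    # Illustrative generator: every feature is uniform noise except v[1],
    # which is an exact copy of the binary target.
    y = np.random.randint(low=0, high=2, size=(n, 1))
    x = np.random.uniform(size=(n, input_dims))
    x[:, 1] = y[:, 0]
    return x, y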

We expect the attention to focus on v[1] only, or at least to concentrate very strongly on it. We recap the setup with this drawing:

Attention Mechanism explained

The first two plots are samples taken randomly from the training set. The last plot is the attention vector that we expect: a high peak at index 1 and values close to zero everywhere else.

Let's train this model and visualize the attention vector applied to the inputs:

Attention Mechanism explained

We can clearly see that the network figures this out at inference time.

Behind the scenes

The attention mechanism can be implemented in three lines with Keras:

inputs = Input(shape=(input_dims,))
attention_probs = Dense(input_dims, activation='softmax', name='attention_probs')(inputs)
attention_mul = merge([inputs, attention_probs], output_shape=input_dims, name='attention_mul', mode='mul')

We apply a Dense layer with a softmax activation and the same number of output units as the Input layer. The weight matrix of this Dense layer therefore has a shape of input_dims x input_dims, and its softmax output is the attention vector of length input_dims.

Then we merge the inputs with the attention probabilities by multiplying them element-wise.
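To turn these three lines into a trainable model, here is a minimal sketch; the sigmoid output head, the optimizer and the reuse of the get_toy_data helper from above are assumptions, and it uses the Keras 2 multiply helper in place of the legacy merge call:

from keras.models import Model
from keras.layers import Input, Dense, multiply

input_dims = 32
inputs = Input(shape=(input_dims,))
attention_probs = Dense(input_dims, activation='softmax', name='attention_probs')(inputs)
attention_mul = multiply([inputs, attention_probs], name='attention_mul')
# Assumed output head: a single sigmoid unit for the binary target.
output = Dense(1, activation='sigmoid')(attention_mul)

m = Model(inputs=inputs, outputs=output)
m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
x_train, y_train = get_toy_data(10000)  # toy generator sketched earlier
m.fit(x_train, y_train, epochs=10, batch_size=64, validation_split=0.1)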

Finally, the activation vector (probability distribution) can be derived with:

attention_vector = get_activations(m, testing_inputs_1, print_shape_only=True)[1].flatten()

where 1 is the index of the attention layer in the model definition (the Input layer is indexed by 0).
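get_activations is a small helper from the repository. A rough equivalent, assuming a single-input model evaluated in test mode, could look like:

from keras import backend as K

def get_activations(model, model_inputs, print_shape_only=False):
    # Rough sketch of a layer-output extractor: evaluate every layer output
    # for the given inputs, with the learning phase set to 0 (test mode).
    outputs = [layer.output for layer in model.layers]
    func = K.function([model.input, K.learning_phase()], outputs)
    activations = func([model_inputs, 0.])
    for act in activations:
        print(act.shape if print_shape_only else act)
    return activations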

Recurrent Layers (LSTM, GRU...)

Application of attention at input level

We consider the same example as the one used for the Dense layers. The informative value is now at index 10, so we expect an attention spike around this position. There are two main ways to apply attention to recurrent layers:

  • Directly on the inputs (same as the Dense example above): APPLY_ATTENTION_BEFORE_LSTM = True

Attention vector applied on the inputs (before)

Application of attention on the LSTM's output

  • After the LSTM layer: APPLY_ATTENTION_BEFORE_LSTM = False

Attention vector applied on the output of the LSTM layer (after)

Both have their own advantages and disadvantages. One obvious advantage of applying the attention directly to the inputs is that we clearly understand this space. The high-dimensional space spanned by the LSTM is a bit trickier to interpret, although it shares the same time steps as the inputs (return_sequences=True is used here).
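A possible wiring for the two cases, sketched with the Keras 2 functional API: the layer sizes, the structure of attention_3d_block (a Permute followed by a softmax Dense over the time axis) and the single-unit output head are assumptions consistent with the description above, not a copy of the repository's code.

from keras.layers import Input, Dense, LSTM, Flatten, Permute, multiply
from keras.models import Model

TIME_STEPS = 20
INPUT_DIM = 2   # arbitrary example dimensions
APPLY_ATTENTION_BEFORE_LSTM = True  # flag referenced in the text above

def attention_3d_block(inputs):
    # Assumed structure: one softmax over the time axis per feature dimension.
    a = Permute((2, 1))(inputs)                      # (batch, features, time)
    a = Dense(TIME_STEPS, activation='softmax')(a)   # attention scores over time
    a_probs = Permute((2, 1), name='attention_vec')(a)
    return multiply([inputs, a_probs], name='attention_mul')

inputs = Input(shape=(TIME_STEPS, INPUT_DIM))
if APPLY_ATTENTION_BEFORE_LSTM:
    x = attention_3d_block(inputs)                   # attention on the raw inputs
    x = LSTM(32, return_sequences=False)(x)
else:
    x = LSTM(32, return_sequences=True)(inputs)      # keep the time axis
    x = attention_3d_block(x)                        # attention on the LSTM output
    x = Flatten()(x)
output = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inputs, outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy')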

Attention of multi dimensional time series

Sometimes the time series is N-dimensional, and it can be interesting to have one attention vector per dimension. Say we have a 2-D time series of 20 steps. Setting SINGLE_ATTENTION_VECTOR = False gives an attention vector of shape (20, 2), one attention vector per input dimension. Setting SINGLE_ATTENTION_VECTOR = True gives an attention vector of shape (20,), shared across the input dimensions (a sketch follows the two cases below).

  • SINGLE_ATTENTION_VECTOR = False

Attention defined per time series (each TS has its own attention)

  • SINGLE_ATTENTION_VECTOR = True

Attention shared across all the time series
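Extending the attention_3d_block sketch above, the shared attention vector can be obtained by averaging the per-dimension scores over the input dimensions and repeating the result, roughly as follows (again a sketch under the same assumptions, not the repository's exact code):

from keras import backend as K
from keras.layers import Dense, Lambda, Permute, RepeatVector, multiply

SINGLE_ATTENTION_VECTOR = True
TIME_STEPS = 20
INPUT_DIM = 2

def attention_3d_block(inputs):
    a = Permute((2, 1))(inputs)                      # (batch, INPUT_DIM, TIME_STEPS)
    a = Dense(TIME_STEPS, activation='softmax')(a)   # one score vector per dimension
    if SINGLE_ATTENTION_VECTOR:
        # Average the per-dimension scores into a single (TIME_STEPS,) vector
        # and repeat it so it is shared across all input dimensions.
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(INPUT_DIM)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    return multiply([inputs, a_probs], name='attention_mul')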

