
EEG-ATCNet


This repository provides code for the Attention Temporal Convolutional Network (ATCNet) proposed in the paper: Physics-informed attention temporal convolutional network for EEG-based motor imagery classification

Authors: Hamdi Altaheri, Ghulam Muhammad, Mansour Alsulaiman

Center of Smart Robotics Research, King Saud University, Saudi Arabia

Updates:

  • The regularization parameters of ATCNet have been modified, improving the model's performance and making it more robust to overfitting.
  • The current main_TrainTest.py file, which follows the training and evaluation method outlined in Paper 1 and Paper 2, does not align with best practice. We therefore strongly recommend adopting the methodology implemented in the refined main_TrainValTest.py file, which splits the data into train/valid/test sets following the guidelines detailed in this post (Option 2). A minimal sketch of such a split is shown below.
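
The following minimal sketch illustrates the idea of such a split (the exact ratios and routine in main_TrainValTest.py may differ; the data shapes are placeholders for BCI 4-2a trials):

    import numpy as np
    from sklearn.model_selection import train_test_split

    # Placeholder data standing in for loaded BCI 4-2a training trials
    # (trials x channels x samples); in practice, use the output of get_data().
    X = np.random.randn(288, 22, 1125)
    y = np.random.randint(0, 4, size=288)

    # Carve a validation set out of the training session; the test session
    # (session 2 of BCI 4-2a) stays untouched until the final evaluation.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y)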

In addition to the proposed ATCNet model, the models.py file includes implementations of other related methods that can be compared with ATCNet.

The following table shows the performance of ATCNet and other reproduced models based on the methodology defined in the main_TrainValTest.py file:

| Model          | #params | BCI 4-2a: training time (m) ¹ ² | BCI 4-2a: accuracy (%) | HGD*: training time (m) ¹ ² | HGD*: accuracy (%) |
|----------------|---------|---------------------------------|------------------------|-----------------------------|--------------------|
| ATCNet         | 113,732 | 13.5                            | 81.10                  | 62.6                        | 92.05              |
| TCNet_Fusion   | 17,248  | 8.8                             | 69.83                  | 65.2                        | 89.73              |
| EEGTCNet       | 4,096   | 7.0                             | 65.36                  | 36.8                        | 87.80              |
| MBEEG_SENet    | 10,170  | 15.2                            | 69.21                  | 104.3                       | 90.13              |
| EEGNet         | 2,548   | 6.3                             | 68.67                  | 36.5                        | 88.25              |
| DeepConvNet    | 553,654 | 7.5                             | 42.78                  | 43.9                        | 87.53              |
| ShallowConvNet | 47,364  | 8.2                             | 67.48                  | 61.8                        | 87.00              |

BCI 4-2a: BCI Competition IV-2a dataset; HGD: High Gamma Dataset.
¹ using Nvidia GTX 1080 Ti 12GB
² 500 epochs, without early stopping
* please note that HGD is for "executed movements", NOT "motor imagery"

This repository includes implementations of several attention mechanisms in the attention_models.py file, including multi-head self-attention ('mha'), a locality-based multi-head attention variant ('mhla'), the convolutional block attention module ('cbam'), and squeeze-and-excitation ('se').

These attention blocks can be called using the attention_block(net, attention_model) method in the attention_models.py file, where 'net' is the input layer and 'attention_model' indicates the type of the attention mechanism, which has five options: None, 'mha', 'mhla', 'cbam', and 'se'.

Example: 

    from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten
    from attention_models import attention_block

    input_layer = Input(shape=(10, 100, 1))
    block1 = Conv2D(1, (1, 10))(input_layer)
    block2 = attention_block(block1, 'mha')   # mha: multi-head self-attention
    output = Dense(4, activation="softmax")(Flatten()(block2))

The preprocess.py file loads and divides the dataset based on two approaches:

  1. Subject-specific (subject-dependent) approach. In this approach, we use the same training and testing data as the original BCI IV-2a competition division, i.e., trials from session 1 for training and trials from session 2 for testing.
  2. Leave-One-Subject-Out (LOSO) approach. LOSO is used for subject-independent evaluation. In LOSO, the model is trained and evaluated over several folds, equal to the number of subjects; in each fold, one subject is used for evaluation and the others for training. The LOSO evaluation technique ensures that the model is evaluated on subjects that were not seen in the training data.

The get_data() method in the preprocess.py file is used to load the dataset and split it into training and testing sets. This method uses the subject-specific approach by default. To use the subject-independent (LOSO) approach, set the parameter LOSO = True, as in the sketch below.
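
A minimal LOSO usage sketch (the exact signature and return values of get_data() should be checked against preprocess.py; the names below are illustrative):

    from preprocess import get_data

    data_path = "/path/to/BCI2a/"   # path to the downloaded dataset
    n_subjects = 9                  # BCI 4-2a has nine subjects

    for subject in range(n_subjects):
        # In LOSO mode, the held-out subject forms the test set and the
        # remaining subjects form the training set (assumed return values;
        # check preprocess.py for the actual ones).
        X_train, y_train, X_test, y_test = get_data(data_path, subject, LOSO=True)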

About ATCNet

ATCNet is inspired in part by the Vision Transformer (ViT). ATCNet differs from ViT in the following ways:

  • ViT uses a single-layer linear projection, while ATCNet uses a multilayer nonlinear projection, i.e., a convolutional projection specifically designed for EEG-based brain signals.
  • ViT consists of a stack of encoders, where the output of each encoder is the input of the next. ATCNet consists of parallel encoders, and the outputs of all encoders are concatenated.
  • The encoder block in ViT consists of a multi-head self-attention (MHA) followed by a multilayer perceptron (MLP), while in ATCNet the MHA is followed by a temporal convolutional network (TCN).
  • The first encoder in ViT receives the entire input sequence, while each encoder in ATCNet receives a shifted window from the input sequence.

[Figure: ATCNet vs. ViT]

The ATCNet model consists of three main blocks:

  1. Convolutional (CV) block: encodes low-level spatio-temporal information within the MI-EEG signal into a sequence of high-level temporal representations through three convolutional layers.
  2. Attention (AT) block: highlights the most important information in the temporal sequence using multi-head self-attention (MHA).
  3. Temporal convolutional (TC) block: extracts high-level temporal features from the highlighted information using a temporal convolutional layer.

In addition, the ATCNet model utilizes a convolution-based sliding window to augment the MI data and efficiently boost MI classification performance. A simplified structural sketch is given below.

Visualize the transition of data in the ATCNet model.

[Figure: the components of the proposed ATCNet model]
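
To make the three blocks and the sliding-window mechanism above concrete, below is a highly simplified Keras sketch of the high-level structure; the layer sizes, the convolutional block, and the TCN internals are placeholders, not the exact implementation in models.py. Each shifted window is passed through its own attention + temporal-convolution head, and the window outputs are then fused:

    from tensorflow.keras import layers, models

    def atcnet_like_sketch(n_classes=4, seq_len=20, n_features=32, n_windows=5):
        # Input: a sequence of high-level temporal features, i.e., the output
        # of the convolutional (CV) block (omitted here for brevity).
        seq = layers.Input(shape=(seq_len, n_features))
        window_len = seq_len - n_windows + 1

        outputs = []
        for i in range(n_windows):
            window = seq[:, i:i + window_len, :]                 # shifted window
            att = layers.MultiHeadAttention(num_heads=2, key_dim=8)(window, window)  # AT block
            tcn = layers.Conv1D(n_features, 4, padding='causal',
                                activation='elu')(att)           # stand-in for the TC block
            outputs.append(layers.Lambda(lambda t: t[:, -1, :])(tcn))  # last time step
        fused = layers.Concatenate()(outputs)                    # fuse all window outputs
        out = layers.Dense(n_classes, activation='softmax')(fused)
        return models.Model(seq, out)

Calling atcnet_like_sketch().summary() shows the parallel per-window branches whose outputs are fused before the final classification layer.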

Development environment

Models were trained and tested on a single GPU, an Nvidia GTX 2070 8GB (driver version 512.78, CUDA 11.3), using Python 3.7 with the TensorFlow framework. Anaconda 3 was used on Ubuntu 20.04.4 LTS and Windows 11. The following packages are required:

  • TensorFlow 2.7
  • matplotlib 3.5
  • NumPy 1.20
  • scikit-learn 1.0
  • SciPy 1.7

Dataset

The BCI Competition IV-2a dataset needs to be downloaded, and the data path set in the 'data_path' variable in the main.py file. The dataset can be downloaded from here.

References

If you find this work useful in your research, please use the following BibTeX entries for citation:

@article{9852687,
  title={Physics-Informed Attention Temporal Convolutional Network for EEG-Based Motor Imagery Classification},
  author={Altaheri, Hamdi and Muhammad, Ghulam and Alsulaiman, Mansour},
  journal={IEEE Transactions on Industrial Informatics},
  year={2023},
  volume={19},
  number={2},
  pages={2249--2258},
  publisher={IEEE},
  doi={10.1109/TII.2022.3197419}
}

@article{10142002,
  title={Dynamic convolution with multilevel attention for EEG-based motor imagery decoding}, 
  author={Altaheri, Hamdi and Muhammad, Ghulam and Alsulaiman, Mansour},
  journal={IEEE Internet of Things Journal}, 
  year={2023},
  volume={10},
  number={21},
  pages={18579--18588},
  publisher={IEEE},
  doi={10.1109/JIOT.2023.3281911}
}

@article{altaheri2023deep,
  title={Deep learning techniques for classification of electroencephalogram (EEG) motor imagery (MI) signals: A review},
  author={Altaheri, Hamdi and Muhammad, Ghulam and Alsulaiman, Mansour and Amin, Syed Umar and Altuwaijri, Ghadir Ali and Abdul, Wadood and Bencherif, Mohamed A and Faisal, Mohammed},
  journal={Neural Computing and Applications},
  year={2023},
  volume={35},
  number={20},
  pages={14681--14722},
  publisher={Springer},
  doi={10.1007/s00521-021-06352-5}
}


eeg-atcnet's Issues

Question about the size of input data

Hi,

I read your paper. I'm really impressed by your method.

However, I have a question about your data size.
In your paper, the temporal size of the input data is 1125.
The sampling rate of the dataset is 250 Hz, and, based on the dataset description, the time for conducting the task is four seconds, including the cue time.


I think 1125 (the temporal size of the input data) was calculated as 250 Hz × 4.5 s.


So, why did you set the task to 4.5 seconds?

model

Hello!
Regarding the model, I would like to ask: why, in the ablation experiments, can a convolution block similar to EEGNet reach an accuracy of around 81%?

LOSO : True

When setting LOSO: True in the dataset_conf dictionary, I'm getting the following error:

    line 64, in load_data_LOSO
        elif (X_train == []):
    ValueError: operands could not be broadcast together with shapes (576,22,1750) (0,)

tf version 2.14.0

Am I missing something?

Dataset download

Hello, the dataset download from the provided link does not seem to be working. Is there an alternative source from which we can download the .mat dataset? Thanks in advance.

Ensemble

Hi @Altaheri,

thanks for making your code public.
As far as I understand the code, the sliding-window procedure not only introduces n_windows more function calls but also n_windows more sets of model parameters.

This for-loop creates a new "AT-block" (consisting of an attention block, a TCN, and a dense layer) per window. This effectively makes your ATCNet an ensemble model where each window has its own "AT-block". Is this intended behaviour? From the paper, I understood that there is only one "AT-block" shared between all windows.

code

Regarding the paper "Attention-Inception and Long Short-Term Memory-Based Electroencephalography Classification for Motor Imagery Tasks in Rehabilitation": do you plan to publish the code?

version available?

Hello, I'm glad you shared the code. Is there a PyTorch version available?

pytorch

It is an honor to learn about this model, which has achieved great results. I would like to ask whether you are planning to release a PyTorch version, and if so, when it will be released.

Cannot find saved models

2024-01-30 20:37:59.364344: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
File "C:\EEG-ATCNet-main\main_TrainValTest.py", line 425, in
run()
File "C:\EEG-ATCNet-main\main_TrainValTest.py", line 421, in run
test(model, dataset_conf, results_path)
File "C:\EEG-ATCNet-main\main_TrainValTest.py", line 246, in test
runs = os.listdir(results_path+"\saved models")
FileNotFoundError: [WinError 3] The system cannot find the path specified: 'C:\EEG-ATCNet-main\results\saved models'
It is showing this error. I've been working on this for the past two days trying to make it work. Can someone provide their working code for Windows, share their GitHub repository, and tell me the exact procedure to run the code?

Reproducibility and Randomness

Hi @Altaheri,

I am trying to reproduce your results, which honestly looked too good to be true to me.
After implementing everything, I was not even close to your results. I then examined your training routine in detail and found three major reasons/flaws why your model performs "so well" (it doesn't). After I changed my pipeline to look like yours, I got similar results.
The problem with those results, however, is that they heavily rely on the randomness of the training routine and on the test set not being independent:

  1. Your validation set equals your test set, and you choose the best checkpoint based on val_acc = test_acc. This effectively makes your model dependent on the test set. Because the val_acc (= test_acc) fluctuates heavily, the specific choice of checkpoint has a significant impact on your "test" performance. Furthermore, if a separate validation split is used, one typically monitors the val_loss instead of the val_acc, because the accuracy can have lucky peaks and does not measure the uncertainty of the model (extreme case: a correct 1-0-0-0 prediction yields the same accuracy as a 0.26-0.25-0.25-0.24 prediction). If the validation and test sets are independent, the lowest val_loss typically yields the highest test_acc (see the sketch after this list).
  2. You make multiple runs to try out different random seeds. This is generally a good thing, but instead of taking the best run you have to average over all seeds; otherwise you are just exploiting the randomness of the process (more information). Additionally, you should (re-)set the random seed before every run so that you get the same results every time you run your code.
  3. Along the same lines as point 2, you not only exploit the randomness of the process by choosing the best seed over all subjects, you even do this independently per subject. That means you are effectively choosing the best random configuration out of 10^9 (one billion!) possible configurations.
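
For illustration only (this is not the repository's training code; names and values are placeholders), selecting checkpoints on the loss of an independent validation set and (re-)setting seeds before every run might look like:

    import random
    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

    def set_seed(seed):
        # (Re-)set all relevant random seeds before every run for reproducibility.
        random.seed(seed)
        np.random.seed(seed)
        tf.random.set_seed(seed)

    set_seed(42)
    callbacks = [
        # Keep the checkpoint with the lowest loss on an independent validation set.
        ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True),
        EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True),
    ]
    # model.fit(X_train, y_train, validation_data=(X_val, y_val),
    #           epochs=500, callbacks=callbacks)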

To back up my findings, I ran a few (subject-specific) experiments for subject 2 (a bad subject) and subject 3 (a good subject). I used EarlyStopping, chose either the last checkpoint or the best, and ran each experiment with 10 different random seeds.
Results:
subject 2:
  • Average accuracy (last ckpt): 63.3 ± 3.0
  • Optimal-seed accuracy (last ckpt): 67.4
  • Average accuracy (best ckpt): 67.0 ± 3.8
  • Optimal-seed accuracy (best ckpt): 71.9

subject 3:
  • Average accuracy (last ckpt): 90.6 ± 2.7
  • Optimal-seed accuracy (last ckpt): 94.8
  • Average accuracy (best ckpt): 94.6 ± 0.8
  • Optimal-seed accuracy (best ckpt): 95.8

  1. The average accuracy of the best checkpoint is around 4% better than that of the last checkpoint.
  2. The optimal choice of seed yields another 1-5% accuracy.
  3. If I choose the best random seed of subject 2 (seed=7) and use it for subject 3, I only get 94.8%, which confirms my third point. The other way round, it would be seed=0 and only 64.6% test_acc for subject 2 - a 7.3% decrease!

If you have any further questions, feel free to ask!

Process finished with exit code -1073740791 (0xC0000409)

After modifying the code, this line shows "Process finished with exit code -1073740791 (0xC0000409)". There is no problem with the environment configuration, and the GPU can be used normally. Is it a graphics card issue?

question about the setting of n_train when LOSO is used

Thanks for your excellent work!
I've got a question about the setting of n_train when LOSO is used.
In your code, you set n_train to 10 when the model is trained using the subject-dependent approach, which is the default setting. Should n_train still be set to 10 when the subject-independent approach is adopted (i.e., LOSO == True)? When LOSO is used, training with n_train = 10 is much more time-consuming than its subject-dependent counterpart.
Looking forward to your reply!

I want to add a program

Hello, thank you for sharing the code. I added a piece of denoising code to the ATCNet model, and it reduced the accuracy on the validation set. I don't understand why. Do you have any idea?
