saits's Introduction

The official code repository for paper SAITS: Self-Attention-based Imputation for Time Series.

‼️Kind reminder: This document answers many common questions, so please read it carefully before you run the code.

📣 Attention please
SAITS is now available in PyPOTS, a Python toolbox for data mining on POTS (Partially-Observed Time Series). An example of training SAITS to impute the PhysioNet-2012 dataset is shown below. With PyPOTS, easy peasy! 😉

# Install PyPOTS first: pip install pypots
import numpy as np
from sklearn.preprocessing import StandardScaler
from pypots.data import load_specific_dataset, mcar, masked_fill
from pypots.imputation import SAITS
from pypots.utils.metrics import cal_mae
# Data preprocessing. Tedious, but PyPOTS can help. 🤓
data = load_specific_dataset('physionet_2012')  # PyPOTS automatically downloads and extracts the datasets in its database.
X = data['X']
num_samples = len(X['RecordID'].unique())
X = X.drop('RecordID', axis=1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)
X_intact, X, missing_mask, indicating_mask = mcar(X, 0.1)  # hold out 10% of the observed values as ground truth
X = masked_fill(X, 1 - missing_mask, np.nan)
# Model training. This is PyPOTS showtime. 💪
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_inner=128, n_head=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
saits.fit(X)  # train the model. The whole dataset serves as the training set because the ground truth is not visible to the model.
imputation = saits.impute(X)  # impute the originally-missing values and artificially-missing values
mae = cal_mae(imputation, X_intact, indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)
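To make the roles of the two masks in the example above more concrete, here is a minimal NumPy sketch of the same hold-out idea. It is not PyPOTS's implementation: `hold_out_mcar` and `masked_mae` are hypothetical helpers written for illustration, and the random toy array only stands in for the preprocessed PhysioNet-2012 data.

```python
import numpy as np


def hold_out_mcar(X, rate, seed=0):
    """Hide roughly `rate` of the observed entries completely at random.

    Returns (X_intact, X_corrupted, missing_mask, indicating_mask), where
    missing_mask is 1 for entries still observed after masking and
    indicating_mask is 1 for entries hidden by this function.
    """
    rng = np.random.default_rng(seed)
    X_intact = X.copy()
    observed = ~np.isnan(X)
    # Draw one uniform number per entry; observed entries below `rate` get hidden.
    hide = observed & (rng.random(X.shape) < rate)
    X_corrupted = X.copy()
    X_corrupted[hide] = np.nan
    missing_mask = (observed & ~hide).astype(np.float32)
    indicating_mask = hide.astype(np.float32)
    return X_intact, X_corrupted, missing_mask, indicating_mask


def masked_mae(prediction, target, mask):
    """Mean absolute error over the entries where mask == 1."""
    # NaNs are zeroed first so that masked-out NaN entries do not poison the sum.
    error = np.abs(np.nan_to_num(prediction) - np.nan_to_num(target)) * mask
    return error.sum() / (mask.sum() + 1e-9)


# Toy data: 4 samples, 48 time steps, 3 features, with a few originally-missing values.
X = np.random.default_rng(1).normal(size=(4, 48, 3))
X[0, :5, 0] = np.nan
X_intact, X_corrupted, missing_mask, indicating_mask = hold_out_mcar(X, 0.1)

# A naive baseline that fills every gap with 0, scored only on the held-out values.
naive_imputation = np.nan_to_num(X_corrupted)
print(masked_mae(naive_imputation, X_intact, indicating_mask))
```

Only the artificially hidden entries enter the error, which is why the whole dataset can be used for training without leaking the evaluation targets.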

⦿ Motivation: SAITS was developed primarily to overcome the drawbacks (slow speed, memory constraints, and compounding error) of RNN-based imputation models and to achieve state-of-the-art (SOTA) imputation accuracy on partially-observed time series.

⦿ Performance: SAITS outperforms BRITS by 12% ∼ 38% in MAE (mean absolute error) and trains 2.0 ∼ 2.6 times faster. Furthermore, SAITS outperforms Transformer (trained with our joint-optimization approach) by 2% ∼ 13% in MAE with a more efficient model structure (to obtain comparable performance, SAITS needs only 15% ∼ 30% of Transformer's parameters). Compared to NRTSI, another SOTA self-attention imputation model, SAITS achieves 7% ∼ 39% smaller mean squared error (above 20% in nine out of sixteen cases) while needing far fewer parameters and less imputation time in practice. Please refer to our full paper for more details about SAITS' performance.

❖ Brief Graphical Illustration of Our Methodology

Here we show only the two main components of our method: the joint-optimization training approach and the SAITS structure. For a detailed description and explanation, please read our full paper.

Fig. 1: Training approach

Fig. 2: SAITS structure
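As a rough illustration of the joint-optimization idea, the sketch below adds two masked-MAE terms: a reconstruction term on values the model received as input, and an imputation term on known values that were hidden from it (the paper calls these tasks ORT and MIT). The function names, the equal default weights, and the toy arrays are assumptions made for this sketch, not this repository's code; the full paper gives the exact formulation.

```python
import numpy as np


def masked_mae(prediction, target, mask):
    """Mean absolute error over the entries where mask == 1."""
    return (np.abs(prediction - target) * mask).sum() / (mask.sum() + 1e-9)


def joint_loss(prediction, target, observed_mask, held_out_mask,
               reconstruction_weight=1.0, imputation_weight=1.0):
    """Weighted sum of a reconstruction term and an imputation term.

    observed_mask: 1 where the model saw the true value as input.
    held_out_mask: 1 where a known value was hidden from the model.
    The weights are hypothetical knobs for this sketch.
    """
    reconstruction = masked_mae(prediction, target, observed_mask)
    imputation = masked_mae(prediction, target, held_out_mask)
    return reconstruction_weight * reconstruction + imputation_weight * imputation


# Toy example: 1 series, 4 time steps, 1 feature.
target = np.array([[[1.0], [2.0], [3.0], [4.0]]])
prediction = np.array([[[1.5], [2.0], [2.0], [4.0]]])
observed_mask = np.array([[[1.0], [1.0], [0.0], [1.0]]])
held_out_mask = np.array([[[0.0], [0.0], [1.0], [0.0]]])

loss = joint_loss(prediction, target, observed_mask, held_out_mask)
print(round(float(loss), 4))  # 0.5 / 3 + 1.0 / 1 -> 1.1667
```

Optimizing both terms at once pushes the model to stay faithful to what it observed while also learning to fill in what it did not.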

❖ Repository Structure

The implementation of SAITS is in the directory modeling. The model configurations are in configs, and the dataset links and preprocessing scripts are in dataset_generating_scripts. The directory NNI_tuning contains the hyper-parameter search configurations.

❖ Development Environment

All dependencies of our development environment are listed in the file conda_env_dependencies.yml. You can quickly create a usable Python environment with the Anaconda command conda env create -f conda_env_dependencies.yml. ❗️Note that this file targets Linux with a GPU, but you can still use it as a reference for the required libraries.

❖ Datasets

To download and generate the datasets, please check out the scripts in the directory dataset_generating_scripts.

❖ Quick Run

Generate the dataset you need first. To do so, please check out the generating scripts in dir dataset_generating_scripts.

After data generation, train and test your model, for example:

# create a dir to save logs and results
mkdir NIPS_results

# train a model
nohup python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    > NIPS_results/PhysioNet2012_SAITS_best.out &

# during training, you can run the command below to read the training log
less NIPS_results/PhysioNet2012_SAITS_best.out

# after training, pick the best model, update the model path used for testing in the config file, then run the command below to test the model
python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    --test_mode

❗️Note that dataset paths and saving directories may differ on your machine, so please check them in the configuration files.
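One way to review those paths before launching a run is to list every path-like option in a config file. The following is a minimal sketch using Python's standard configparser; the helper `find_path_options`, the keyword heuristic, and the section and option names in the example string are assumptions for illustration, so the real files in configs may be organized differently.

```python
import configparser


def find_path_options(config_text, keywords=("path", "dir")):
    """Return {(section, option): value} for options whose name hints at a path."""
    # Interpolation is disabled so values containing '%' are read verbatim.
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(config_text)
    found = {}
    for section in parser.sections():
        for option, value in parser.items(section):
            if any(keyword in option.lower() for keyword in keywords):
                found[(section, option)] = value
    return found


# Hypothetical config content; the real files in configs may use other names.
example = """
[file_path]
dataset_base_dir = generated_datasets
result_saving_base_dir = NIPS_results

[model]
n_layers = 2
"""

for (section, option), value in find_path_options(example).items():
    print(f"[{section}] {option} = {value}")
```

To check a real file, pass its text instead of the example string, for instance the content of configs/PhysioNet2012_SAITS_best.ini read with pathlib.Path(...).read_text().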

❖ Reference

If you find SAITS helpful to your research, please cite our paper as below and ⭐️star this repository so that others notice our work. 🤗

@article{Du2022SAITS,
      title={{SAITS: Self-Attention-based Imputation for Time Series}}, 
      author={Wenjie Du and David Cote and Yan Liu},
      year={2022},
      eprint={2202.08516},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

or

Wenjie Du, David Cote, and Yan Liu. "SAITS: Self-Attention-based Imputation for Time Series." ArXiv abs/2202.08516

❖ Acknowledgments

Thanks to Mitacs and NSERC (Natural Sciences and Engineering Research Council of Canada) for funding support. Thanks to Ciena for providing computing resources. Thanks to all reviewers for helping improve the quality of this paper. And thank you all for your attention to this work!

✨ Stars/forks/issues/PRs are all welcome!

❖ Last but Not Least

Take a look at my GitHub homepage to learn more about me. 😃

If you have any additional questions or are interested in collaboration, please feel free to drop me an email or send me a private message!
