ECGSurvNet

ECGSurvNet is a deep survival neural network for predicting mortality risk from electrocardiograms (ECGs). This repository demonstrates how to train and test ECGSurvNet on an open ECG dataset. ECGSurvNet predicts a patient's risk of death from the ECG waveform and is trained with a loss function derived from the Cox proportional hazards model. Please refer to our paper for more details:

  • C Lin, "Mortality risk prediction of electrocardiogram via deep survival neural network as an extensive long-term cardiovascular outcome predictor", submitted to a journal in 2022.
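The README does not spell out the loss formula. As a minimal base-R sketch, assuming the standard negative Cox log partial likelihood with the Breslow treatment of tied times (the actual loss in 'code/train.R' may differ, for example in how it is normalized), the training objective for a batch of predicted log-risk scores could look like this. `cox_neg_partial_loglik` and `log_sum_exp` are hypothetical helper names, not functions from the repository.

```r
# Hypothetical helper: numerically stable log(sum(exp(x))).
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Hypothetical helper (not from the repository): negative Cox log partial
# likelihood, Breslow handling of tied event times, averaged over events.
# risk:  predicted log-risk score per patient (e.g. the network output)
# time:  follow-up time per patient
# event: 1 = death observed, 0 = censored
cox_neg_partial_loglik <- function(risk, time, event) {
  stopifnot(length(risk) == length(time), length(time) == length(event))
  loglik <- 0
  for (i in which(event == 1)) {
    at_risk <- time >= time[i]  # risk set: patients still followed at time[i]
    loglik <- loglik + risk[i] - log_sum_exp(risk[at_risk])
  }
  -loglik / max(sum(event == 1), 1)
}

# Toy example: giving higher scores to earlier deaths lowers the loss.
follow_up <- c(1, 2, 3)
died      <- c(1, 1, 1)
cox_neg_partial_loglik(c(2, 1, 0), follow_up, died)
cox_neg_partial_loglik(c(0, 1, 2), follow_up, died)
```

Dividing by the number of observed events keeps the loss on a comparable scale across mini-batches of different sizes; whether the repository normalizes this way is an assumption.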

Requirements

You may need Rtools (on Windows) to compile the packages below from source. Rtools is available from CRAN.

You need MXNet to train the deep learning model and run inference. You can install the CPU version of MXNet by running the following lines in your R console:

cran <- getOption("repos")
cran["dmlc"] <- "https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/CRAN/"
options(repos = cran)
install.packages("mxnet")

You need rhdf5 and data.table to decode and read the ECG data from the SaMi-Trop dataset. You can install them by running the following lines in your R console:

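# rhdf5 (Bioconductor)
# For R >= 3.5, biocLite() is deprecated; use BiocManager instead:
if (!requireNamespace("BiocManager", quietly = TRUE))
  install.packages("BiocManager")
BiocManager::install("rhdf5")
# For older R versions, the legacy installer was:
# source("https://bioconductor.org/biocLite.R"); biocLite("rhdf5")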

# data.table
packageurl <- "https://cran.r-project.org/src/contrib/Archive/data.table/data.table_1.11.8.tar.gz"
install.packages(packageurl, repos=NULL, type="source")

You need ggplot2 and its dependencies to plot the loss during training. You can install these packages by running the following lines in your R console:

package_url <- "https://cran.r-project.org/src/contrib/Archive/pillar/pillar_1.4.4.tar.gz"
install.packages(package_url, repos = NULL, type="source")
package_url <- "https://cran.r-project.org/src/contrib/Archive/ggplot2/ggplot2_3.3.3.tar.gz"
install.packages(package_url,  repos = NULL, type = "source")  

You need survival version 3.2-7 to compute the C-index for validation. You can install this specific version by running the following lines in your R console:

packageurl <- "https://cran.r-project.org/src/contrib/Archive/survival/survival_3.2-7.tar.gz"
install.packages(packageurl, repos=NULL, type="source")

Data preparation

We use the SaMi-Trop dataset as the example data. The SaMi-Trop cohort is an open dataset with mortality annotations and the corresponding ECG traces. In this repository, we randomly divided the dataset into training (80%) and validation (20%) sets.
You can use the script 'code/1. processing data/1. download Sami-Trop.R' to download the SaMi-Trop dataset, and the script 'code/1. processing data/2. pre-processing data.R' to pre-process it for training and validating ECGSurvNet.
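The split itself is performed by the repository's pre-processing script. As a hedged illustration of a reproducible random 80/20 split in base R, assuming the records are identified by a vector of IDs (`split_train_valid` and its default seed are hypothetical, not taken from the repository):

```r
# Hypothetical helper (not from the repository): random 80/20 split of record IDs.
split_train_valid <- function(ids, train_frac = 0.8, seed = 1) {
  set.seed(seed)  # note: resets the global RNG state so the split is reproducible
  n_train <- round(length(ids) * train_frac)
  train_idx <- sample(seq_along(ids), n_train)
  in_train <- seq_along(ids) %in% train_idx  # logical mask of training records
  list(train = ids[in_train], valid = ids[!in_train])
}

# Usage with placeholder IDs; in practice these would be the dataset's exam IDs.
sets <- split_train_valid(1:1000)
lengths(sets)  # train: 800, valid: 200
```

Splitting with a logical mask rather than negative indexing keeps the function correct even when one of the two sets would be empty.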

Deep learning model: ECGSurvNet

The model can be trained with the script 'code/train.R' once the data have been prepared by 'code/1. processing data/2. pre-processing data.R'. Alternatively, pre-trained weights of ECGSurvNet are available at 'model/ECGSurvNet/ECGSurvNet-0000.params'.

A modified residual network (ResNet) with 1D convolutional layers is used in this repository; it is defined in the script 'code/train.R':

model_symbol <- ECGSurvNet(indata = var_list[["data"]], start_filter = 32, inverted_coef = 4,
                           num_filters = c(32, 64, 64, 128), num_unit = c(3, 3, 6, 4), end_filters = c(512))
  • input: dimension = (2800, 1, 12, N). The input tensor contains 2,800 samples from each of the 12 ECG leads for each of the N records. In the SaMi-Trop dataset, ECGs were sampled at 400 Hz, but some were recorded for 10 seconds and others for 7 seconds. The ECGs were zero-padded on both sides so that every trace has the same length of 4,096 points. For details of the ECG data, please refer to the SaMi-Trop dataset. We crop 2,800 points from the middle of each original ECG for model training and validation, so the final tensor consists of the signals from the 12 ECG leads.

  • output: shape = (N). One predicted mortality risk score per ECG.
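A quick arithmetic check explains the crop length: at 400 Hz, a 7-second trace has 7 × 400 = 2,800 samples and a 10-second trace has 4,000. If the zero padding to 4,096 points is symmetric, a 7-second trace occupies samples 649 to 3,448, which is exactly the middle 2,800-point window. The base-R sketch below assumes each ECG is already a 4,096 × 12 matrix (samples × leads); the layout produced by the repository's reading code may differ, and `crop_middle` and `to_input_array` are hypothetical helper names.

```r
# Hypothetical helper: keep the middle `crop_len` samples of one ECG,
# given as a samples x leads matrix (e.g. 4096 x 12).
crop_middle <- function(ecg, crop_len = 2800) {
  n <- nrow(ecg)
  stopifnot(n >= crop_len)
  start <- (n - crop_len) %/% 2 + 1
  ecg[start:(start + crop_len - 1), , drop = FALSE]
}

# Hypothetical helper: stack N cropped ECGs into a (crop_len, 1, leads, N) array,
# matching the input dimension described above.
to_input_array <- function(ecg_list, crop_len = 2800) {
  n_leads <- ncol(ecg_list[[1]])
  out <- array(0, dim = c(crop_len, 1, n_leads, length(ecg_list)))
  for (k in seq_along(ecg_list)) {
    out[, 1, , k] <- crop_middle(ecg_list[[k]], crop_len)
  }
  out
}

# Toy check: a 7-s trace (2,800 samples at 400 Hz) zero-padded symmetrically to 4,096.
trace_7s <- matrix(1, nrow = 2800, ncol = 12)
padded <- rbind(matrix(0, 648, 12), trace_7s, matrix(0, 648, 12))
x <- to_input_array(list(padded, padded))
dim(x)       # 2800 1 12 2
all(x == 1)  # TRUE: the crop recovers exactly the 7-s trace
```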

Performance

You can evaluate the model on the validation set. A traditional Cox regression model fitted with age and sex as covariates is used as the baseline. An example validation script can be found in 'code/3. evaluation/evaluation_ECGSurvNet.R', and the performance of the pre-trained ECGSurvNet is summarized as follows:

message("C-index of Cox model using age and sex as covariates: ", round(cox_age_sex[["concordance"]][6], digits = 4))
>> C-index of Cox model using age and sex as covariates: 0.6344

message("C-index of Cox model using the output of ECGSurvNet as covariates: ", round(cox_ecg[["concordance"]][6], digits = 4))
>> C-index of Cox model using the output of ECGSurvNet as covariates: 0.6553

message("C-index of Cox model using age, sex, and the output of ECGSurvNet as covariates: ", round(cox_age_sex_ecg[["concordance"]][6], digits = 4))
>> C-index of Cox model using age, sex, and the output of ECGSurvNet as covariates: 0.6754
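The values above are the concordance reported by survival for each fitted Cox model. To make the metric concrete, the following base-R sketch computes Harrell's C-index directly from predicted risk scores. `harrell_c` is a hypothetical helper, not part of the repository or of survival; it uses a simplified treatment of tied follow-up times, so its value may differ slightly from the one survival reports.

```r
# Hypothetical helper (not from the repository): Harrell's C-index.
# A pair (i, j) is comparable when patient i died and j was followed longer;
# it is concordant when i has the higher predicted risk (risk ties count 0.5).
harrell_c <- function(risk, time, event) {
  concordant <- 0
  comparable <- 0
  for (i in which(event == 1)) {
    for (j in which(time > time[i])) {
      comparable <- comparable + 1
      if (risk[i] > risk[j]) {
        concordant <- concordant + 1
      } else if (risk[i] == risk[j]) {
        concordant <- concordant + 0.5
      }
    }
  }
  concordant / comparable
}

# Toy data: every comparable pair is ordered correctly.
follow_up <- c(2, 5, 7, 9, 12)
died      <- c(1, 0, 1, 1, 0)
harrell_c(c(0.9, 0.3, 0.6, 0.4, 0.1), follow_up, died)  # 1
```

A C-index of 0.5 corresponds to random ordering and 1 to perfect ordering. Higher scores must mean higher risk; a score that increases with survival time would yield 1 minus the C-index.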

The performance of the pre-trained ECGSurvNet may vary on other datasets, because only about 1,200 ECG records were used to train ECGSurvNet in this repository.

How to cite

If you use this code in your work, please cite.

Contributors

imshepherd
