RETAIN

RETAIN is an interpretable predictive model for healthcare applications. Given patient records, it makes predictions while explaining how each medical code (diagnosis codes, medication codes, or procedure codes) at each visit contributes to the prediction. This interpretability comes from its neural attention mechanism.

RETAIN Interpretation Demo

Using RETAIN, you can calculate how positively or negatively each medical code (diagnosis, medication, or procedure code) at different visits contributes to the final score. In this demo, we predict whether the given patient will be diagnosed with Heart Failure (HF). Codes that are highly related to HF make positive contributions. RETAIN also learns to pay more attention to new information than to old information: Cardiac Dysrhythmia (CD) makes a bigger contribution when it occurs in the more recent visit.

Relevant Publications

RETAIN implements an algorithm introduced in the following paper:

RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism
Edward Choi, Mohammad Taha Bahadori, Joshua A. Kulas, Andy Schuetz, Walter F. Stewart, Jimeng Sun,
NIPS 2016, pp.3504-3512

Notice

The RETAIN paper formulates the model as able to make a prediction at each timestep (e.g. predicting which diagnoses the patient will receive at each visit), and treats sequence classification (e.g. given a patient record, will the patient be diagnosed with heart failure in the future?) as a special case, since sequence classification makes a prediction at the last timestep only.

This code, however, is implemented for the sequence classification task. For example, you can use it to predict whether a given patient is a heart failure patient, or whether the patient will be readmitted in the future. A more general version of RETAIN will be released in the future.

Running RETAIN

STEP 1: Installation

  1. Install Python and Theano. We use Python 2.7 and Theano 0.8. Theano can easily be installed on Ubuntu by following Theano's installation instructions.

  2. If you plan to use GPU computation, install CUDA

  3. Download/clone the RETAIN code

STEP 2: Fast way to test RETAIN with MIMIC-III
This step describes how to train RETAIN on MIMIC-III with the minimum number of steps, to predict patients' mortality from their visit records.

  1. You will first need to request access to MIMIC-III, a publicly available database of electronic health records collected from ICU patients over 11 years.

  2. You can use "process_mimic.py" to process the MIMIC-III dataset and generate a suitable training dataset for RETAIN. Place the script in the same directory as the MIMIC-III CSV files, and run it. The command is:
    python process_mimic.py ADMISSIONS.csv DIAGNOSES_ICD.csv PATIENTS.csv <output file>

  3. Run RETAIN using the ".seqs" and ".morts" files generated by process_mimic.py. The ".seqs" file contains the sequence of visits for each patient, where each visit consists of multiple diagnosis codes. However, we recommend using the ".3digitICD9.seqs" file instead, as the results will be much more interpretable. (Alternatively, you could use the Single-level Clinical Classifications Software for ICD9 to reduce the number of codes to a couple of hundred, which may improve performance even further.) The ".morts" file contains the mortality label of each patient. The command is:
    python retain.py <3digitICD9.seqs file> 942 <morts file> <output path> --simple_load --n_epochs 100 --keep_prob_context 0.8 --keep_prob_emb 0.5
    Here 942 is the total number of unique 3-digit ICD9 codes used in the dataset.

  4. To test the model for interpretation, please refer to Step 6. I personally found that perinatal jaundice (ICD9 774) has high correlation with mortality.

  5. The model reaches AUC above 0.8 with the above command, but the interpretations are not super clear. You could tune the hyper-parameters, but I doubt things will dramatically improve. After all, only 7,500 patients made more than a single hospital visit, and most of them have only two visits.
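
Before training, it can help to check the processed files: how many patients have more than one visit (the limitation mentioned above) and which number to pass as the code count (942 in the command above). The sketch below is a minimal illustration, assuming the ".seqs" file is a pickled list (patients) of lists (visits) of lists of integer codes numbered from 0, as described in STEP 3. The helper names and the file name in the usage line are hypothetical. It is written for Python 3, while RETAIN itself targets Python 2.7.

```python
import pickle


def load_sequences(path):
    """Load a pickled list (patients) of lists (visits) of lists (integer codes)."""
    with open(path, "rb") as f:
        return pickle.load(f)


def summarize_sequences(seqs):
    """Return (n_patients, n_multi_visit, n_unique_codes, min_num_codes_arg)."""
    n_patients = len(seqs)
    # Patients with two or more visits are the ones for which visit order matters.
    n_multi_visit = sum(1 for patient in seqs if len(patient) > 1)
    codes = {code for patient in seqs for visit in patient for code in visit}
    # Assumption: codes are used as indices, so the code-count argument must be
    # at least max(code) + 1. With gap-free 0..N-1 numbering this equals len(codes).
    min_num_codes_arg = max(codes) + 1 if codes else 0
    return n_patients, n_multi_visit, len(codes), min_num_codes_arg
```

For example, `summarize_sequences(load_sequences("mimic.3digitICD9.seqs"))` (hypothetical file name) returns the four counts. If the unique-code count and the minimum argument differ, the integer codes have gaps.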

STEP 3: How to prepare your own dataset

  1. RETAIN's training dataset needs to be a Python cPickled list of lists of lists. The outermost list corresponds to patients, the intermediate list to the sequence of visits each patient made, and the innermost list to the medical codes (e.g. diagnosis codes, medication codes, procedure codes) that occurred within each visit. First, each medical code needs to be converted to an integer. A single visit can then be seen as a list of integers, and a patient as a list of visits. For example, [5,8,15] means the patient was assigned codes 5, 8, and 15 at a certain visit. If a patient made two visits, [1,2,3] and [4,5,6,7], they can be represented as the list of lists [[1,2,3], [4,5,6,7]]. Multiple patients can be represented as [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]], meaning there are two patients: the first made two visits and the second made three. This list of lists of lists needs to be pickled using cPickle. We will refer to this file as the "visit file".

  2. The total number of unique medical codes is required to run RETAIN. For example, if the dataset is using 14,000 diagnosis codes and 11,000 procedure codes, the total number is 25,000.

  3. The label dataset (let us call this "label file") needs to be a Python cPickled list. Each element corresponds to the true label of each patient. For example, 1 can be the case patient and 0 can be the control patient. If there are two patients where only the first patient is a case, then we should have [1,0].

  4. The "visit file" and "label file" must each be split into three sets: a training set, a validation set, and a test set, with the file extensions ".train", ".valid", and ".test" respectively.
    For example, if you want to use a file named "my_visit_sequences" as the "visit file", RETAIN will try to load "my_visit_sequences.train", "my_visit_sequences.valid", and "my_visit_sequences.test".
    The same rule applies to the "label file".

  5. You can use time information about the visits as an additional source of information. Let us call this the "time file". The time information could be anything: the duration between consecutive visits, the cumulative number of days since the first visit, etc. The "time file" needs to be prepared as a Python cPickled list of lists. The outer list corresponds to patients, and each inner list to the time information of that patient's visits. For example, given the "visit file" [[[1,2,3], [4,5,6,7]], [[2,4], [8,3,1], [3]]], its corresponding "time file" could look like [[0, 15], [0, 45, 23]] if we use the duration between consecutive visits. (Of course the numbers are fake, and the duration for the first visit is set to zero.) Use the --time_file <path to time file> option to use the "time file". Remember that the ".train", ".valid", ".test" rule also applies to the "time file".
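
The format rules above can be combined into one preprocessing script. The sketch below is one possible way to do it, assuming you already have integer-coded visits, one 0/1 label per patient, and visit dates expressed as day numbers. The function names, the 75/10/15 split ratios, the random seed, and the output prefixes are hypothetical choices, not something RETAIN prescribes. It uses Python 3's `pickle` with protocol 2 so the files can also be read by Python 2.7's `cPickle`.

```python
import pickle
import random


def durations_since_previous(visit_days):
    """Time-file entries for one patient: 0 for the first visit, then gaps in days."""
    return [0] + [b - a for a, b in zip(visit_days, visit_days[1:])]


def split_indices(n, ratios=(0.75, 0.10, 0.15), seed=1234):
    """Shuffle patient indices and cut them into train/valid/test (hypothetical ratios)."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = int(n * ratios[0])
    n_valid = int(n * ratios[1])
    return {
        "train": idx[:n_train],
        "valid": idx[n_train:n_train + n_valid],
        "test": idx[n_train + n_valid:],
    }


def write_retain_files(prefix_visit, prefix_label, prefix_time, visits, labels, visit_days):
    """Write <prefix>.train/.valid/.test for the visit, label, and time files."""
    if not (len(visits) == len(labels) == len(visit_days)):
        raise ValueError("visits, labels, and visit_days must have one entry per patient")
    times = [durations_since_previous(days) for days in visit_days]
    for split, idx in split_indices(len(visits)).items():
        # The same index list is used for all three files, so patients stay aligned.
        for prefix, data in ((prefix_visit, visits), (prefix_label, labels), (prefix_time, times)):
            with open(prefix + "." + split, "wb") as f:
                # Protocol 2 keeps the files readable by Python 2.7's cPickle.
                pickle.dump([data[i] for i in idx], f, protocol=2)
```

Using one shared index list for all three files keeps the i-th visit sequence, label, and time entry in every split pointing at the same patient, which the loader relies on.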

Additional: Using your own medical code representations
RETAIN internally learns the vector representation of medical codes while training. These vectors are initialized with random values of course.
You can, however, also use your own medical code representations if you have them. (They can be trained using Skip-gram-like algorithms. Refer to Med2Vec or this for further details.) If you want to provide your own medical code representations, they must be a list of lists (basically a matrix) with N rows and M columns, where N is the number of unique codes in your "visit file" and M is the size of the code representations. Specify the path to your code representation file using --embed_file <path to embedding file>. Even if you use your own medical code representations, you can re-train (a.k.a. fine-tune) them while training RETAIN. Use the --embed_finetune option to do this. If you are not providing your own medical code representations, RETAIN will use randomly initialized ones, which obviously require this fine-tuning process. Since fine-tuning is the default, you do not need to worry about this.
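
If you bring your own code representations, it is worth checking the matrix shape before training. The sketch below assumes that the embedding file is pickled like the other input files, and that row k holds the vector for integer code k. Both are assumptions based on the description above, so check how retain.py loads --embed_file before relying on them. The function name is hypothetical.

```python
import pickle


def save_embedding(matrix, n_codes, path):
    """Check that `matrix` is N x M with N == n_codes, then pickle it as a list of lists."""
    if n_codes <= 0 or len(matrix) != n_codes:
        raise ValueError("expected %d rows (one per code), got %d" % (n_codes, len(matrix)))
    width = len(matrix[0])
    if width == 0 or any(len(row) != width for row in matrix):
        raise ValueError("every row must have the same, non-zero length M")
    # Assumed layout: row k holds the vector for integer code k in the visit file.
    rows = [[float(v) for v in row] for row in matrix]
    with open(path, "wb") as f:
        pickle.dump(rows, f, protocol=2)  # protocol 2 is readable by Python 2.7's cPickle
    return n_codes, width
```

The returned pair (N, M) can be compared with the code count you pass to retain.py and with --embed_size.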

STEP 4: Running RETAIN

  1. The minimum input you need to run RETAIN is the "visit file", the number of unique medical codes in the "visit file", the "label file", and the output path. The output path is where the learned weights and the log will be saved.
    python retain.py <visit file> <# codes in the visit file> <label file> <output path>

  2. Specifying the --verbose option will print the training progress every 10 mini-batches.

  3. You can specify the size of the embedding W_emb, the size of the hidden layer of the GRU that generates alpha, and the size of the hidden layer of the GRU that generates beta. The respective commands are --embed_size <integer>, --alpha_hidden_dim_size <integer>, and --beta_hidden_dim_size <integer>. For example --alpha_hidden_dim_size 128 will tell RETAIN to use a GRU with 128-dimensional hidden layer for generating alpha.

  4. Dropout is applied in two places: 1) to the input embedding, and 2) to the context vector c_i. The respective keep probabilities (values between 0.0 and 1.0) can be adjusted using --keep_prob_emb <float> and --keep_prob_context <float>. Dropout values affect performance, so it is recommended to tune them for your data.

  5. L2 regularization can be applied to W_emb, w_alpha, W_beta, and w_output.

  6. Additional options, such as the batch size and the number of epochs, can also be specified. Run python retain.py --help for detailed information.

  7. My personal recommendation: use mild regularization (0.0001 ~ 0.001) on all four weights, and use moderate dropout on the context vector only. But this entirely depends on your data, so you should always tune the hyperparameters for yourself.
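
Because the hyperparameters need tuning for your own data, one simple approach is to generate one retain.py command per setting and run them one after another. The sketch below only uses options that appear in the MIMIC-III command of STEP 2 (--n_epochs, --keep_prob_context, --keep_prob_emb). The grid values, the output-path naming, and the helper name are hypothetical. It only builds argument lists; each list could then be passed to Python's `subprocess.run`.

```python
import itertools


def build_commands(visit_file, n_codes, label_file, out_dir,
                   keep_prob_context=(0.5, 0.8), keep_prob_emb=(0.5, 1.0), n_epochs=100):
    """Return one retain.py argument list per (keep_prob_context, keep_prob_emb) pair."""
    commands = []
    for kc, ke in itertools.product(keep_prob_context, keep_prob_emb):
        # Give every run its own output path so saved models and logs do not collide.
        out_path = "%s/kc%s_ke%s" % (out_dir, kc, ke)
        commands.append([
            "python", "retain.py", visit_file, str(n_codes), label_file, out_path,
            "--n_epochs", str(n_epochs),
            "--keep_prob_context", str(kc),
            "--keep_prob_emb", str(ke),
        ])
    return commands
```

Comparing the best validation AUC across the runs' logs then tells you which setting to keep.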

STEP 5: Getting your results

RETAIN checks the AUC on the validation set after each epoch, and if it is higher than all previous values, it saves the current model. The model file is saved with numpy.savez_compressed.
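
The model-selection rule above (save only when the validation AUC beats every earlier epoch) can be illustrated without Theano. The sketch below computes AUC with the pairwise (Mann-Whitney) formulation and reports the epochs at which a checkpoint would be written. It illustrates the rule as described here and is not RETAIN's actual code, which may compute AUC differently (for example with scikit-learn). The function names are hypothetical, and the pairwise loop is quadratic, so it is meant for small examples only.

```python
def auc(labels, scores):
    """ROC AUC: probability that a random positive outscores a random negative (ties = 1/2)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def select_checkpoints(valid_labels, scores_per_epoch):
    """Return (epoch, auc) pairs at which a new best validation AUC would trigger a save."""
    best = float("-inf")
    saved = []
    for epoch, scores in enumerate(scores_per_epoch):
        current = auc(valid_labels, scores)
        if current > best:  # strictly higher than all previous epochs
            best = current
            saved.append((epoch, current))
    return saved
```

Because the comparison is strict, an epoch that only ties the best AUC so far does not overwrite the saved model.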

STEP 6: Testing your model

  1. Using the file "test_retain.py", you can calculate the contributions of each medical code at each visit. First you need to have a trained model that was saved by numpy.savez_compressed. Note that you need to know the configuration with which you trained RETAIN (e.g. use of --time_file, use of --use_log_time.)

  2. Again, you need the "visit file" and "label file" prepared in the same way. This time, however, you do not need to follow the ".train", ".valid", ".test" rule. The testing script will try to load the file name as given.

  3. You also need the mapping between the actual string medical codes and their integer codes (e.g. "Hypertension" is mapped to 24). This file (let's call it the "mapping file") needs to be a Python cPickled dictionary whose keys are the string medical codes and whose values are the corresponding integers. (For example, the mapping file generated by process_mimic.py is the ".types" file.) This file is required to print the contributions of each medical code in a user-friendly format.

  4. For the additional options such as --time_file or --use_log_time, you should use exactly the same configuration with which you trained the model. For more detailed information, use "--help" option.

  5. The minimum input to run the testing script is the "model file", "visit file", "label file", "mapping file", and "output file". The "output file" is where the contributions will be stored.
    python test_retain.py <model file> <visit file> <label file> <mapping file> <output file>
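
To make the contribution scores concrete: for a present code k at visit i, the RETAIN paper defines the contribution as alpha_i * w_output^T (beta_i ⊙ W_emb[:, k]), where alpha_i is the scalar visit-level attention and beta_i the variable-level attention vector. The sketch below computes these scores in plain Python from already-extracted alpha_i, beta_i, embedding vectors, and output weights, and uses the mapping file's dictionary to print code names. It assumes a linear embedding, no time file, and embedding vectors stored one row per code. The function names and toy values are hypothetical, and test_retain.py may organize the computation differently.

```python
def contributions(visits, alphas, betas, emb_rows, w_output):
    """Score of code k at visit i: alpha_i * sum_j w_output[j] * beta_i[j] * emb_rows[k][j]."""
    result = []
    for visit, alpha, beta in zip(visits, alphas, betas):
        visit_scores = []
        for code in visit:
            emb = emb_rows[code]  # assumed: row k is the embedding of integer code k
            score = alpha * sum(w * b * e for w, b, e in zip(w_output, beta, emb))
            visit_scores.append((code, score))
        result.append(visit_scores)
    return result


def format_contributions(scores, mapping):
    """Replace integer codes with names using the mapping file's dict (name -> integer)."""
    id_to_name = {v: k for k, v in mapping.items()}
    lines = []
    for i, visit_scores in enumerate(scores):
        for code, score in visit_scores:
            lines.append("visit %d\t%s\t%.4f" % (i, id_to_name.get(code, str(code)), score))
    return lines
```

Under these assumptions, adding the output bias to the sum of all contributions reproduces the model's logit, which is a useful sanity check when comparing against the scores written by test_retain.py.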

Contributors

mp2893, davedecaprio
