PandemicLLM: Adapting Large Language Models to Forecast Pandemics in Real-time: A COVID-19 Case Study

Hongru Du*, Jianan Zhao*, Yang Zhao*, Shaochong Xu, Xihong Lin, Yiran Chen#, Lauren M. Gardner#, and Hao (Frank) Yang#.

* Equal Contribution, # Corresponding Authors

Steps to reproduce

Python Environment

We ran our experiments with CUDA 12.1 on Ubuntu 22.04, using NVIDIA A100 hardware.

Create conda environment

conda create -n PLLM python=3.10
conda activate PLLM

Install the related packages.

pip install -r requirements.txt
conda install mpi4py

Set up Environment Variables

We recommend using wandb to monitor and manage the training process. Before running the training script, make sure the wandb environment is set up properly. You can also skip wandb by setting use_wandb=False.

An HF_ACCESS_TOKEN is also required to access the LLaMA models.

mkdir configs/user
echo "
# @package _global_
env:
  vars:
    WANDB_API_KEY: YOUR_WANDB_API_KEY
    HF_ACCESS_TOKEN: YOUR_HF_ACCESS_TOKEN
" >> configs/user/env.yaml
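Note that the shell snippet above appends with `>>`, so running it twice leaves two copies of the block in the file. As an alternative, here is a minimal Python sketch that writes the same content and overwrites any earlier copy. The helper name `write_env_config` and its parameters are hypothetical and not part of this repository; the file path and YAML keys are taken from the snippet above.

```python
from pathlib import Path

# Same content as the shell snippet; the two values are filled in by the caller.
ENV_TEMPLATE = """# @package _global_
env:
  vars:
    WANDB_API_KEY: {wandb_key}
    HF_ACCESS_TOKEN: {hf_token}
"""


def write_env_config(wandb_key: str, hf_token: str, root: str = "configs/user") -> Path:
    """Write <root>/env.yaml, replacing any earlier copy (hypothetical helper)."""
    target_dir = Path(root)
    target_dir.mkdir(parents=True, exist_ok=True)  # same effect as `mkdir -p`
    target = target_dir / "env.yaml"
    target.write_text(ENV_TEMPLATE.format(wandb_key=wandb_key, hf_token=hf_token))
    return target
```

Called from the repository root as `write_env_config("YOUR_WANDB_API_KEY", "YOUR_HF_ACCESS_TOKEN")`, this should produce the same `configs/user/env.yaml` as the shell version.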

Supervised Fine-tuning (SFT)

CovidLLM supports instruction fine-tuning of an LLM on graph data. An RNN maps the continuous sequences into the text space (as tokens). We recommend BF16 for stable training. To reproduce our results, run the following scripts in the same environment listed above:

PandemicLLM-7B 1-week

cd src/scripts
python run_covid_llm_sft.py data_file=processed_v5_4.pkl eq_batch_size=4 in_weeks=3 llm_base_model=llama2-7b lr=2e-05 save_model=False seed=2024 splits_type=base_splits target=t1 total_steps=1501 use_cont_fields=True use_deepspeed=True use_wandb=True

PandemicLLM-7B 3-week

cd src/scripts
python run_covid_llm_sft.py data_file=processed_v5_4.pkl eq_batch_size=4 in_weeks=3 llm_base_model=llama2-7b lr=2e-05 save_model=False seed=2024 splits_type=sta_aug_splits target=t3 total_steps=1501 use_cont_fields=True use_deepspeed=True use_wandb=True

PandemicLLM-13B 1-week

cd src/scripts
python run_covid_llm_sft.py data_file=processed_v5_4.pkl eq_batch_size=4 in_weeks=3 llm_base_model=llama2-13b lr=1e-05 save_model=False seed=2023 splits_type=sta_aug_splits target=t1 total_steps=1501 use_cont_fields=True use_deepspeed=True use_wandb=True

PandemicLLM-13B 3-week

cd src/scripts
python run_covid_llm_sft.py data_file=processed_v5_4.pkl eq_batch_size=4 in_weeks=3 llm_base_model=llama2-13b lr=2e-05 save_model=False seed=2023 splits_type=base_splits target=t3 total_steps=1501 use_cont_fields=True use_deepspeed=True use_wandb=True

PandemicLLM-70B 1-week

cd src/scripts
python run_covid_llm_sft.py data_file=processed_v5_4.pkl eq_batch_size=4 in_weeks=3 llm_base_model=llama2-70b lr=1e-05 save_model=False seed=2023 splits_type=sta_dy_aug_splits target=t1 total_steps=1501 use_cont_fields=True use_deepspeed=True use_int4=True use_wandb=True

PandemicLLM-70B 3-week

cd src/scripts
python run_covid_llm_sft.py data_file=processed_v5_4.pkl eq_batch_size=4 in_weeks=3 llm_base_model=llama2-70b lr=2e-05 save_model=False seed=2023 splits_type=sta_dy_aug_splits target=t3 total_steps=1501 use_cont_fields=True use_deepspeed=True use_int4=True use_wandb=True
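The six commands above share most of their overrides and differ only in model size, target, learning rate, seed, split type, and (for 70B) `use_int4`. The sketch below collects those differences in one table and rebuilds each command line from it. The names `SHARED`, `RUNS`, and `build_command` are hypothetical and not part of this repository; every value is copied from the commands listed above.

```python
# Overrides common to all six runs.
SHARED = {
    "data_file": "processed_v5_4.pkl",
    "eq_batch_size": 4,
    "in_weeks": 3,
    "save_model": False,
    "total_steps": 1501,
    "use_cont_fields": True,
    "use_deepspeed": True,
    "use_wandb": True,
}

# Per-run overrides, keyed by (model size, target). Learning rates are kept
# as strings so they print exactly as in the original commands.
RUNS = {
    ("7b", "t1"): {"lr": "2e-05", "seed": 2024, "splits_type": "base_splits"},
    ("7b", "t3"): {"lr": "2e-05", "seed": 2024, "splits_type": "sta_aug_splits"},
    ("13b", "t1"): {"lr": "1e-05", "seed": 2023, "splits_type": "sta_aug_splits"},
    ("13b", "t3"): {"lr": "2e-05", "seed": 2023, "splits_type": "base_splits"},
    ("70b", "t1"): {"lr": "1e-05", "seed": 2023,
                    "splits_type": "sta_dy_aug_splits", "use_int4": True},
    ("70b", "t3"): {"lr": "2e-05", "seed": 2023,
                    "splits_type": "sta_dy_aug_splits", "use_int4": True},
}


def build_command(size: str, target: str, **extra) -> str:
    """Return the command line for one run; `extra` overrides any default."""
    overrides = {
        **SHARED,
        "llm_base_model": f"llama2-{size}",
        "target": target,
        **RUNS[(size, target)],
        **extra,
    }
    # Sort keys alphabetically, matching the order used in the commands above.
    args = " ".join(f"{key}={value}" for key, value in sorted(overrides.items()))
    return f"python run_covid_llm_sft.py {args}"
```

For example, `build_command("7b", "t1")` reproduces the PandemicLLM-7B 1-week command, and `build_command("7b", "t1", use_wandb=False)` gives the same run with wandb disabled. As in the commands above, run the result from `src/scripts`.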

Dataset

PandemicLLM is fine-tuned using multiple disparate categories of data including spatial (static), epidemiological time series (dynamic), genomic surveillance (dynamic), and public health policy data (dynamic). Our data covers all 50 states in the United States, ensuring a comprehensive nationwide scope for our study. All the spatial data are available at the state resolution, while all the time-varying data are available at the weekly resolution. Please check our paper for more details about the data.

Load the Dataset

import pickle
path = '/data/processed_v5_4.pkl'
# Open in binary mode; the context manager closes the file after loading.
with open(path, 'rb') as f:
    data = pickle.load(f)

Data Columns

data: raw data

sta_dy_aug_data: data with static and dynamic augmentation.

base_splits, sta_aug_splits, dy_aug_splits, sta_dy_aug_splits: indexes into sta_dy_aug_data that select, respectively, the raw data, the data with static augmentation, the data with dynamic augmentation, and the data with both static and dynamic augmentation.

label_info, mse_val_map: information used for training
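To check which of the entries listed above are present after loading, a small inspection helper can be useful. The sketch below assumes the unpickled object behaves like a dictionary keyed by those names; this README does not state the container type, so treat that as an assumption. `summarize_entries` is a hypothetical helper, not part of this repository. Because unpickling can execute arbitrary code, only load pickle files from a source you trust.

```python
from typing import Any, Dict, Optional, Tuple


def summarize_entries(data: Dict[str, Any]) -> Dict[str, Tuple[str, Optional[int]]]:
    """Map each top-level key to (type name, length), or None if it has no length."""
    summary = {}
    for key, value in data.items():
        size = len(value) if hasattr(value, "__len__") else None
        summary[key] = (type(value).__name__, size)
    return summary
```

With `data` loaded as shown earlier, `for name, info in summarize_entries(data).items(): print(name, info)` prints one line per entry.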

Credits

This work is supported by NSF Award ID 2229996, NSF Award ID 2112562 and ARO W911NF-23-2-0224.
