Context-Aware Timewise VAEs for Real-Time Vehicle Trajectory Prediction

Official implementation for Context-Aware Timewise VAEs for Real-Time Vehicle Trajectory Prediction.

Accepted by IEEE Robotics and Automation Letters. [arXiv][YouTube]

Dependencies

  • PyTorch 1.11
  • NumPy 1.21

We recommend installing all requirements through Conda:

$ conda create --name <env> --file requirements.txt -c pytorch -c conda-forge
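After creating the environment, it can help to confirm that the installed versions match the ones listed above. The following is a minimal sketch using only the Python standard library. The distribution names `torch` and `numpy`, and the helper names `matches` and `check_environment`, are assumptions made for illustration and are not part of this repository.

```python
from importlib import metadata

# Versions listed in the Dependencies section above.
# The distribution names ("torch", "numpy") are assumptions about how the
# packages are registered in the active environment.
EXPECTED = {"torch": "1.11", "numpy": "1.21"}


def matches(installed: str, expected: str) -> bool:
    """Return True if `installed` has the same major.minor as `expected`."""
    return installed.split(".")[:2] == expected.split(".")[:2]


def check_environment(expected=EXPECTED):
    """Map each package name to (installed_version_or_None, ok_flag)."""
    report = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            have = None
        report[name] = (have, have is not None and matches(have, want))
    return report


if __name__ == "__main__":
    for name, (have, ok) in check_environment().items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: expected {EXPECTED[name]}.x, found {have} [{status}]")
```

Only the major and minor components are compared, so a build such as `1.11.0+cu113` is treated as matching `1.11`.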

Code Usage

Train a model from scratch

$ python main.py \
    --train <train_data_dir> --test <test_data_dir> \
    --config <train_config_file> --ckpt <checkpoint_dir>

Evaluate a trained model

$ python main.py \
    --test <test_data_dir> \
    --config <eval_config_file> --ckpt <checkpoint_dir>

Data Preprocessing

See scripts/README.md for details of data preprocessing.

Train from scratch

We provide our configuration files in the config folder. To reproduce the results, run

# nuScenes
$ python main.py \
    --train <data_dir>/train --test <data_dir>/val --map_dir <data_dir>/map \
    --config config/nuscenes_train.py --ckpt <checkpoint_dir>

# Lyft
$ python main.py \
    --train <data_dir>/train/0 --test <data_dir>/validate/0 --map_dir <data_dir>/map \
    --config config/lyft_train.py --ckpt <checkpoint_dir> \
    --rank 0 --workers 4
$ python main.py \
    --train <data_dir>/train/1 --test <data_dir>/validate/1 --map_dir <data_dir>/map \
    --config config/lyft_train.py --ckpt <checkpoint_dir> \
    --rank 1 --workers 4 --master_addr <master_addr>
$ python main.py \
    --train <data_dir>/train/2 --test <data_dir>/validate/2 --map_dir <data_dir>/map \
    --config config/lyft_train.py --ckpt <checkpoint_dir> \
    --rank 2 --workers 4 --master_addr <master_addr>
$ python main.py \
    --train <data_dir>/train/3 --test <data_dir>/validate/3 --map_dir <data_dir>/map \
    --config config/lyft_train.py --ckpt <checkpoint_dir> \
    --rank 3 --workers 4 --master_addr <master_addr>

# Waymo
$ python main.py \
    --train <data_dir>/training/0 --map_dir <data_dir>/map/training/0 \
    --test <data_dir>/validation/0 --test_map_dir <data_dir>/map/validation/0 \
    --config config/waymo_train.py --ckpt <checkpoint_dir> \
    --rank 0 --workers 8
...
$ python main.py \
    --train <data_dir>/training/7 --map_dir <data_dir>/map/training/7 \
    --test <data_dir>/validation/7 --test_map_dir <data_dir>/map/validation/7 \
    --config config/waymo_train.py --ckpt <checkpoint_dir> \
    --rank 7 --workers 8 --master_addr <master_addr>

We use distributed training for the Lyft and Waymo datasets, with 4 and 8 worker machines respectively. (See https://pytorch.org/docs/stable/distributed.html for PyTorch distributed training.)
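Because each rank differs only in its data shard, its `--rank` value, and whether `--master_addr` is passed, the per-machine Lyft commands can be generated rather than typed by hand. The sketch below only builds the command strings and does not launch anything. The function name `lyft_train_commands` and the example values in the usage note are hypothetical. The flags and directory layout are copied from the Lyft commands above.

```python
import shlex


def lyft_train_commands(data_dir, ckpt_dir, master_addr, workers=4):
    """Build one `python main.py ...` command per rank, following the Lyft
    training commands listed above. Nothing is executed here; the commands
    are only returned as strings."""
    commands = []
    for rank in range(workers):
        args = [
            "python", "main.py",
            "--train", f"{data_dir}/train/{rank}",
            "--test", f"{data_dir}/validate/{rank}",
            "--map_dir", f"{data_dir}/map",
            "--config", "config/lyft_train.py",
            "--ckpt", ckpt_dir,
            "--rank", str(rank),
            "--workers", str(workers),
        ]
        # As in the commands above, rank 0 is launched without --master_addr.
        if rank > 0:
            args += ["--master_addr", master_addr]
        # shlex.join quotes arguments where needed for a POSIX shell.
        commands.append(shlex.join(args))
    return commands
```

For example, `lyft_train_commands("/data/lyft", "ckpt", "10.0.0.1")` returns four strings, one to run on each worker machine. The path and address here are placeholders.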

All training was done on A100 GPUs for the Lyft and Waymo datasets and on a V100 GPU for nuScenes.

Evaluation

We also provide our pre-trained models in Release.

To reproduce the test results, run

# nuScenes
$ python main.py \
    --test <data_dir>/val --map_dir <data_dir>/map \
    --config config/nuscenes_eval.py --ckpt models/nuscenes_res18

# Lyft
$ python main.py \
    --test <data_dir>/validate/0 --map_dir <data_dir>/map \
    --config config/lyft_eval.py --ckpt models/lyft_res152 \
    --rank 0 --workers 4
$ python main.py \
    --test <data_dir>/validate/1 --map_dir <data_dir>/map \
    --config config/lyft_eval.py --ckpt models/lyft_res152 \
    --rank 1 --workers 4 --master_addr <master_addr>
$ python main.py \
    --test <data_dir>/validate/2 --map_dir <data_dir>/map \
    --config config/lyft_eval.py --ckpt models/lyft_res152 \
    --rank 2 --workers 4 --master_addr <master_addr>
$ python main.py \
    --test <data_dir>/validate/3 --map_dir <data_dir>/map \
    --config config/lyft_eval.py --ckpt models/lyft_res152 \
    --rank 3 --workers 4 --master_addr <master_addr>

# Waymo
$ python main.py \
    --test <data_dir>/validation/0 --map_dir <data_dir>/map/validation/0 \
    --config config/waymo_eval.py --ckpt models/waymo_m2 \
    --rank 0 --workers 8
...
$ python main.py \
    --test <data_dir>/validation/7 --map_dir <data_dir>/map/validation/7 \
    --config config/waymo_eval.py --ckpt models/waymo_m2 \
    --rank 7 --workers 8 --master_addr <master_addr>

Citation

@article{contextvae2023,
    title={Context-Aware Timewise {VAE}s for Real-Time Vehicle Trajectory Prediction},
    author={Xu, Pei and Hayet, Jean-Bernard and Karamouzas, Ioannis},
    journal={IEEE Robotics and Automation Letters},
    year={2023},
    volume={8},
    number={9},
    pages={5440-5447},
    doi={10.1109/LRA.2023.3295990}
}
