
QCNet's Introduction

QCNet: An Elegant, Performant, And Scalable Framework For Marginal/Joint Multi-Agent Trajectory Prediction

Paper | Technical Report | YouTube | Twitter | WeChat Official Account | 經濟日報 (Economic Daily)

This repository is the official implementation of the CVPR 2023 paper: Query-Centric Trajectory Prediction.

Authors: Zikang Zhou, Jianping Wang, Yung-Hui Li, Yu-Kai Huang

Rank 1st on Argoverse 1 Single-Agent Motion Forecasting Benchmark
Rank 1st on Argoverse 2 Single-Agent Motion Forecasting Benchmark
Rank 1st on Argoverse 2 Multi-Agent Motion Forecasting Benchmark
Champion of Argoverse 2 Multi-Agent Motion Forecasting Challenge at CVPR 2023 Workshop on Autonomous Driving (WAD)


News

[2023/07/04] The checkpoint for Argoverse 2 marginal prediction is released. Give it a try!
[2023/06/29] The code for Argoverse 2 marginal prediction is released. Enjoy it:)
[2023/06/18] QCNeXt, the extended version of QCNet, has won the championship of Argoverse 2 Multi-Agent Motion Forecasting Challenge at CVPR 2023 Workshop on Autonomous Driving (WAD).
[2023/02/28] QCNet is accepted by CVPR 2023.

Highlights

  • Scene encoder with roto-translation invariance in space: fundamentally enables accurate multi-agent prediction (see the sketch below)
  • Scene encoder with translation invariance in time: theoretically supports streaming processing
  • Two-stage DETR-like trajectory decoder: facilitates multimodal and long-term prediction
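
To make the first highlight concrete, here is a minimal, illustrative sketch (not the repository's implementation) of why query-centric relative encoding is roto-translation invariant: every pairwise feature is expressed in the query element's own local frame, so a global rigid transform of the scene cancels out.

import torch

def relative_pose_features(pos, heading):
    # pos: [N, 2] global positions; heading: [N] global yaw angles.
    # The returned pairwise features are unchanged if the same rotation and
    # translation are applied to every (pos, heading), which is the
    # invariance the scene encoder relies on.
    rel = pos[None, :, :] - pos[:, None, :]            # [N, N, 2] offsets
    cos, sin = heading.cos(), heading.sin()
    # rotate row i's offsets into element i's local frame
    x = cos[:, None] * rel[..., 0] + sin[:, None] * rel[..., 1]
    y = -sin[:, None] * rel[..., 0] + cos[:, None] * rel[..., 1]
    rel_heading = heading[None, :] - heading[:, None]  # [N, N] yaw differences
    return torch.stack([x, y, rel_heading], dim=-1)    # [N, N, 3]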

Getting Started

Step 1: clone this repository:

git clone https://github.com/ZikangZhou/QCNet.git && cd QCNet

Step 2: create a conda environment and install the dependencies:

conda env create -f environment.yml
conda activate QCNet

Alternatively, you can configure the environment in your favorite way. Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should work well.

Step 3: install the Argoverse 2 API and download the Argoverse 2 Motion Forecasting Dataset following the Argoverse 2 User Guide.

Training & Evaluation

Training

The training process consumes ~160 GB of GPU memory in total (about 20 GB per GPU across 8 GPUs). For example, you can train the model on 8 NVIDIA GeForce RTX 3090 GPUs:

python train_qcnet.py --root /path/to/dataset_root/ --train_batch_size 4 --val_batch_size 4 --test_batch_size 4 --devices 8 --dataset argoverse_v2 --num_historical_steps 50 --num_future_steps 60 --num_recurrent_steps 3 --pl2pl_radius 150 --time_span 10 --pl2a_radius 50 --a2a_radius 50 --num_t2m_steps 30 --pl2m_radius 150 --a2m_radius 150

Note 1: when running the training script for the first time, it will take several hours to preprocess the data.

Note 2: during training, the checkpoints will be saved in lightning_logs/ automatically.

Note 3: you can adjust the batch size and the number of devices. To reproduce the results, ensure that the total batch size (batch size per device × number of devices) is 32, e.g., 8 devices × batch size 4 or 4 devices × batch size 8. If you don't have sufficient computing resources for training, you can adjust some hyperparameters, e.g., reduce the radii and the number of layers.

Validation

To evaluate on the validation set:

python val.py --model QCNet --root /path/to/dataset_root/ --ckpt_path /path/to/your_checkpoint.ckpt

Testing

To generate the prediction results on the test set:

python test.py --model QCNet --root /path/to/dataset_root/ --ckpt_path /path/to/your_checkpoint.ckpt

Submit the generated .parquet file to the Argoverse 2 leaderboard and achieve SOTA immediately!

Pretrained Models & Results

Quantitative Results

Model | Dataset | Split | Checkpoint | minFDE (K=6) | minFDE (K=1) | minADE (K=6) | minADE (K=1) | MR (K=6) | MR (K=1) | brier-minFDE (K=6)
QCNet | AV2 | Val  | QCNet_AV2 | 1.25 | 4.32 | 0.72 | 1.69 | 0.16 | 0.58 | 1.87
QCNet | AV2 | Test | QCNet_AV2 | 1.24 | 4.31 | 0.64 | 1.70 | 0.15 | 0.58 | 1.86

The performance is slightly better than that reported in the paper due to some incremental updates made after the paper was finished. :)

Qualitative Results

Citation

If you find this repository useful, please consider citing our work:

@inproceedings{zhou2023query,
  title={Query-Centric Trajectory Prediction},
  author={Zhou, Zikang and Wang, Jianping and Li, Yung-Hui and Huang, Yu-Kai},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023}
}
@article{zhou2023qcnext,
  title={QCNeXt: A Next-Generation Framework For Joint Multi-Agent Trajectory Prediction},
  author={Zhou, Zikang and Wen, Zihao and Wang, Jianping and Li, Yung-Hui and Huang, Yu-Kai},
  journal={arXiv preprint arXiv:2306.10508},
  year={2023}
}

This repository is developed based on our previous codebase HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction. Please also consider citing:

@inproceedings{zhou2022hivt,
  title={HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction},
  author={Zhou, Zikang and Ye, Luyao and Wang, Jianping and Wu, Kui and Lu, Kejie},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}

License

This repository is licensed under Apache 2.0.


QCNet's Issues

Are the agent states provided directly by the Argoverse 2 dataset?

@ZikangZhou Hi, the paper says "During online running, the perception module supplies a stream of agent states to the prediction module at a fixed interval, where each agent state is associated with its spatial-temporal position and geometric attributes", but I could not find "the perception module" in the source code. Is the "stream of agent states" provided directly by the Argoverse 2 dataset?

Reusing agent state/map polygon encodings in subsequent observation windows

The paper states:

Benefiting from modeling in local reference frames, the embedding of each agent state/map polygon has only one instance and can be reused in the subsequent observation windows.

As the Argoverse 2 dataset contains independent scenarios as samples, the scene features are encoded from scratch for every sample, and it is not possible to reuse any embeddings. If I want to apply your model to a real-life use case where the next sample is the next observation window in time, how would I implement this reuse of the agent state/map polygon encodings?
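
The released code does not implement streaming, but for intuition, here is a hypothetical sketch of the reuse the paper describes, keyed by agent track ID (all names are illustrative, not from the QCNet codebase). Keying by track ID also sidesteps the agent-ordering problem raised in a later issue.

import torch

class StreamingEncoderCache:
    # Because each embedding is computed in the agent state's local frame,
    # it remains valid when the observation window slides, so only the
    # newest timestep needs to be encoded.
    def __init__(self, num_historical_steps: int) -> None:
        self.num_historical_steps = num_historical_steps
        self.cache: dict = {}  # agent track id -> [T', D] cached embeddings

    def step(self, agent_id: str, new_embedding: torch.Tensor) -> torch.Tensor:
        # new_embedding: [D] embedding of the agent's latest state
        past = self.cache.get(
            agent_id, new_embedding.new_zeros(0, new_embedding.size(-1)))
        hist = torch.cat([past, new_embedding.unsqueeze(0)], dim=0)
        hist = hist[-self.num_historical_steps:]  # keep a sliding window
        self.cache[agent_id] = hist
        return hist  # [min(T, num_historical_steps), D]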

Need a new conda environment?

@ZikangZhou Hello, thanks for your excellent work. I have a question about configuring the environment: in step 3 of "Getting Started", is the Argoverse 2 API installed directly into the conda environment created in step 2, or does it need a new conda environment? I ask because step 1 of installing the Argoverse 2 API is "bash conda/install.sh && conda activate av2", which creates and activates a new conda environment.
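
(For what it's worth, assuming the conda script is not strictly required: the av2 package is also published on PyPI, so one common alternative is to run pip install av2 inside the QCNet environment from step 2 instead of creating a second environment.)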

Slightly poorer results from scratch training

Thank you for your wonderful work and code sharing. When I train from scratch, my results are two points worse than yours on minFDE on both the evaluation and validation sets, even though my training parameters are the same as yours. Do you have any tips on the training process? Looking forward to your answer!

Visualization

Is there any implementation code for visualizing the predicted trajectories?

Questions about reusing the past "x_a" cache

Dr. Zhou,

Hello!

While reading your code, I noticed some techniques for speeding up inference, in particular the use of caches and the handling of agent encodings between frames. I have a few questions I would like to confirm:

  • KV cache in temporal attention: when processing a new frame, you use the historical agent encodings. I understand this as the KV-cache trick, with a lower-triangular mask guaranteeing consistent inference results. Is my understanding correct?

  • Handling inconsistent agent counts and orderings:

    • In consecutive frames the number of agents may differ, e.g., 5 agents in the previous frame but only 2 in the new one. If the past tensor has shape [A, T-1, D], say [5, T-1, D], and the new frame has shape [2, 1, D], how should these be merged?
    • If agents arrive in a different order in different frames, e.g., X, Y, Z in the previous frame but Y, Z, X in the new one, merging may go wrong. How should this change of order be handled?
  • Edge and r computation for a single-frame agent: when only one timestep of agent data is passed in, edge_index_t may be empty, so rel_pos_t and the other relative features cannot be computed, which breaks the subsequent temporal attention. I tried giving each agent an edge_index_t of [[0],[0]], but the attention results differ slightly. The difference may come from the position embedding of a single timestep differing from that of consecutive timesteps. How can I keep inference consistent with training in this case?

Attached are the code changes as I understand them; please check whether my understanding is correct:

for i in range(self.num_layers):
    x_a = x_a.reshape(-1, self.hidden_dim)
    # temporal attention over the cached history; the kv_cache argument is my
    # own addition and does not exist in the released attention layers
    x_a = self.t_attn_layers[i]((x_a, x_a), r_t, edge_index_t, kv_cache=kv_cache)
    x_a = x_a.reshape(-1, 1, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
    x_pl = x_pl.transpose(0, 1).reshape(-1, self.hidden_dim)
    x_a = self.pl2a_attn_layers[i]((x_pl, x_a), r_pl2a, edge_index_pl2a)
    x_a = self.a2a_attn_layers[i]((x_a, x_a), r_a2a, edge_index_a2a)
    x_a = x_a.reshape(1, -1, self.hidden_dim).transpose(0, 1)

# slide the temporal window: drop the oldest cached step along the time axis
# of [A, T, D] and append the encoding of the new frame
if x_a_past is not None:
    x_a = torch.cat([x_a_past[:, 1:, :], x_a], dim=1)
return x_a, kv_cache

Looking forward to your reply and guidance.

Thank you!

Results on Waymo dataset?

Hi, Zikang. Thanks for your nice work. May I ask whether you have tried the model on the Waymo dataset? If there is code for testing on the Waymo dataset, could you share it with us? Thanks a lot!

QCNet Encoder vs QCNeXt Encoder

I am really impressed by the work. However, reading through the QCNet paper and the QCNeXt technical report, you mention that the encoders used are the same, whereas Fig. 1 of the QCNeXt technical report shows an encoder that differs from the QCNet encoder. Moreover, the encoder shown in QCNeXt resembles what the QCNet video presentation depicts for existing, computationally expensive works. Can you please clarify? Also, is there an expected time frame for the release of the QCNeXt code?

How to load checkpoint and continue training?

I wonder how to load a checkpoint and continue training. Because of an OOM error (maybe at epoch 30), I have to load the checkpoint and continue training after the crash. I tried adding ckpt_path = '....ckpt' (a path from lightning_logs/checkpoints), but it seems useless. How should I solve this problem? Thanks a lot!
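
For reference, a minimal sketch of resuming with PyTorch Lightning, under the assumption that the model and data module are constructed as in train_qcnet.py (class names, constructor arguments, and paths below are illustrative, not guaranteed to match the script exactly):

import pytorch_lightning as pl

from predictors import QCNet
from datamodules import ArgoverseV2DataModule  # assumed import, mirroring train_qcnet.py

ckpt = 'lightning_logs/version_0/checkpoints/last.ckpt'  # illustrative path
model = QCNet.load_from_checkpoint(ckpt)
datamodule = ArgoverseV2DataModule(root='/path/to/dataset_root/',
                                   train_batch_size=4, val_batch_size=4,
                                   test_batch_size=4)  # assumed signature
trainer = pl.Trainer(accelerator='gpu', devices=8, max_epochs=64)
# Passing ckpt_path to fit() restores the optimizer, LR scheduler, and epoch
# counter, so training resumes rather than restarting from epoch 0. On older
# Lightning versions, pass resume_from_checkpoint=ckpt to the Trainer instead.
trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt)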

Multi-agent forecasting

I read your paper and I have some questions about multi-agent forecasting. I would like to apply your model to a real-life use case and forecast trajectories for multiple agents.
Is this already possible with the code you provide? To my understanding, the decoder needs to know which target agent it should predict trajectories for. Could I simply implement multiple decoder heads with shared encodings, and how would I tell the decoder which agent to predict trajectories for?

Thanks in advance!
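
For intuition (this is not the repository's actual decoder, just an illustrative sketch), a DETR-like decoder can condition on each target agent without per-agent heads: the K mode queries are shared across agents, and conditioning happens by cross-attending each agent's own encoded context, so all agents are decoded in one batched call.

import torch
import torch.nn as nn

class PerAgentModeDecoder(nn.Module):
    # K learnable mode queries, shared across agents; "which agent" is decided
    # purely by which context tensor the queries attend to.
    def __init__(self, hidden_dim: int, num_modes: int, num_future_steps: int):
        super().__init__()
        self.mode_query = nn.Parameter(torch.randn(num_modes, hidden_dim))
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads=8,
                                                batch_first=True)
        self.to_traj = nn.Linear(hidden_dim, num_future_steps * 2)

    def forward(self, agent_ctx: torch.Tensor) -> torch.Tensor:
        # agent_ctx: [A, L, D] per-agent context (scene tokens gathered per agent)
        q = self.mode_query.unsqueeze(0).expand(agent_ctx.size(0), -1, -1)
        m, _ = self.cross_attn(q, agent_ctx, agent_ctx)            # [A, K, D]
        return self.to_traj(m).view(m.size(0), m.size(1), -1, 2)  # [A, K, T, 2]

dec = PerAgentModeDecoder(hidden_dim=128, num_modes=6, num_future_steps=60)
print(dec(torch.randn(10, 32, 128)).shape)  # torch.Size([10, 6, 60, 2])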

Some questions about agent coordinates and the lane environment

Hello, Dr. Zhou. Two questions: 1. To deploy this model (or HiVT) in the real world, must the absolute spatial coordinates of all agents be provided, as in the Argoverse dataset?
2. Most existing vehicle trajectory prediction methods are developed on public datasets. To deploy such methods, does the environment information (lane lines, traffic lights, crosswalks, etc.) also have to be built in advance?
Looking forward to your answers, thank you!!!

I need some help with qualitative results and validation

Hello.

I have gone from training QCNet through to validation for trajectory prediction.

Could you please provide an ipynb file and code for visual testing, such as the qualitative results?

Also, I am trying to validate a model already trained with num_future_steps=60 and num_historical_steps=50.
Can you suggest code to get minADE and minFDE for future steps 0 to 10 or 0 to 20 during validation?
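
On the truncated-horizon question, a minimal sketch, assuming you can gather the model's predictions and the ground truth as tensors (the shapes below are assumptions, not the repository's API):

import torch

def truncated_metrics(pred: torch.Tensor, gt: torch.Tensor, horizon: int):
    # pred: [N, K, 60, 2] multimodal predictions; gt: [N, 60, 2] ground truth.
    pred_h, gt_h = pred[:, :, :horizon], gt[:, :horizon]
    dist = torch.norm(pred_h - gt_h[:, None], dim=-1)      # [N, K, horizon]
    min_ade = dist.mean(dim=-1).min(dim=-1).values.mean()  # best-of-K ADE
    min_fde = dist[..., -1].min(dim=-1).values.mean()      # best-of-K endpoint FDE
    return min_ade.item(), min_fde.item()

# e.g., evaluate the first 10 or 20 future steps instead of all 60:
# ade10, fde10 = truncated_metrics(pred, gt, horizon=10)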

conda version

Hi, when I run conda env create -f environment.yml, there are some conda version conflicts.
Versions 4.5.4 and 4.12.0 both have problems, but 4.10.1 succeeds.
Could you share your conda version so that I can reproduce your great work on my own computer?

test error

Hi, your work is great. I tried to generate the prediction results on the test set by following your instructions:

python test.py --model QCNet --root /path/to/dataset_root/ --ckpt_path /path/to/your_checkpoint.ckpt

But I get an error like this:

Global seed set to 2023
Processing...
  0%|          | 0/24984 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "test.py", line 40, in <module>
    test_dataset = {
  File "/content/QCNet/datasets/argoverse_v2_dataset.py", line 150, in __init__
    super(ArgoverseV2Dataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)
  File "/usr/local/envs/QCNet/lib/python3.8/site-packages/torch_geometric/data/dataset.py", line 97, in __init__
    self._process()
  File "/content/QCNet/datasets/argoverse_v2_dataset.py", line 534, in _process
    self.process()
  File "/content/QCNet/datasets/argoverse_v2_dataset.py", line 190, in process
    map_data = read_json_file(map_path)
TypeError: object() takes no arguments

I wonder what the reason is and how to deal with it. Thank you very much.

Questions about result reproducibility.

@ZikangZhou Hi, I followed the script provided in your README to reproduce the results but encountered two issues.
First, when I set the batch size to 4 on my 3090 GPU, I ran out of memory.
So I adjusted the batch size to 2 while keeping the other hyperparameters unchanged.
The resulting reproduction shows minADE = 0.732 and minFDE = 1.294, which deviates only slightly from the results reported in the paper.
However, there is a significant difference from the pretrained model you uploaded, where minFDE = 1.25.
I would like to ask whether the hyperparameters or other training configurations of this version of the pretrained model were changed.
Thank you very much.

Ubuntu or Windows

Hello, does this need to run under Ubuntu, or can it also run under Windows?

Is this what the training should look like?

Hi, I have been training the model for two days now on 2 x V100 GPUs (32 GB). This is what the output at my terminal looks like. Can you comment on whether the training is going in the right direction? I see some losses increasing here. Also, how many epochs is the training supposed to last?

Training cost too much

@ZikangZhou Hi Zhou! I noticed that when I train this code, it uses 32 GB of GPU memory per GPU instead of the 20 GB per GPU implied by the README. Could you please explain what might cause this discrepancy?

Inference time (ms) and GFLOPs of QCNet and QCNeXT

Hi, can you share the inference time (ms) and GFLOPs of the QCNet and QCNeXt projects? Something similar to the metrics reported for other frameworks in the paper titled:

ProphNet: Efficient Agent-Centric Motion Forecasting with Anchor-Informed Proposals

Brier score is oscillating

The Brier score in my training is not decreasing, while all the losses and the other scores seem to be converging to the reported values. Below are the corresponding values after 68 hours of training on an A100 with 100 GB of RAM.

The training hyperparameters are:
--num_workers=32 --train_batch_size=16 --val_batch_size=16 --test_batch_size=16 --max_epoch=50 --pl2pl_radius=150 --time_span=10 --pl2a_radius=50 --a2a_radius=50 --num_t2m_steps=30 --pl2m_radius=150 --a2m_radius=150

Current training status:
Epoch: 32
Training duration: 68 hours

Brier: 0.61 (oscillating near this value since the beginning of training)
MR (K=6): 0.1824
minADE (K=6): 0.751
minFDE (K=6): 1.359

I am also attaching the TensorBoard checkpoint for further analysis. Please let me know if anything else is required.
Checkpoint.zip
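
For reference, the Brier-style metrics add a mode-probability penalty on top of the displacement error; a minimal sketch, assuming the Argoverse 2 convention brier-minFDE = minFDE + (1 - p)^2 with p the predicted probability of the best mode. If the mode probabilities stay diffuse, the penalty term alone sits near 0.6-0.7, which would explain a flat Brier value even while minFDE converges.

def brier_min_fde(min_fde: float, best_mode_prob: float) -> float:
    # A confident, correct best mode (prob -> 1) adds ~0;
    # a diffuse one (prob ~ 1/6 over six modes) adds ~0.69.
    return min_fde + (1.0 - best_mode_prob) ** 2

print(brier_min_fde(1.359, 1 / 6))  # ~2.05 with diffuse mode probabilities
print(brier_min_fde(1.359, 0.8))    # ~1.40 with a confident best mode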

Using PyG TemporalDataLoader

This is more a question than an issue.

Was there any interest in trying the PyG TemporalDataLoader to improve performance?
I believe some information can be gained from feeding the temporally linear data together.

I don't know much about Argoverse 2, but I know the Argoverse 1 sequences were not temporally linear, so this was not possible in HiVT.

No map data demo

The paper claims that the model can predict agent trajectories without HD map data.
Could you provide a demo of how the model predicts trajectories in such a scenario?

Visualization methods

Hi Zikang,

Cheers for your wonderful work! Is it possible for you guys to release the visualization code as well? As I'm not quite familiar with the argoverse dataset and its toolkit, it would be great if you could share your visualization methods. Thanks!

Training time

Thank you for your amazing work and contribution to the trajectory prediction community. Since the model is big and requires large resources to train, how long did you train the model with those resources?

Is there an installation tutorial for av2-api?

Is there an installation tutorial for av2-api? When using it, I ran into a "upath not found" problem: from upath import UPath fails to find UPath, and I could not find a solution online.
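
(For what it's worth, assuming the failure is just a missing dependency: UPath comes from the universal_pathlib package on PyPI, so pip install universal-pathlib typically makes from upath import UPath resolve.)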


Problems about training

Hi, Dr. Zhou:
Well done! But I have some problems.
1. What causes the OOM problem during training? Recently, someone said that the Argoverse 2 dataset has map representation inputs of different sizes, so randomly combining several large samples in a batch may lead to OOM. How can I deal with this problem given limited GPU memory?
2. Regarding one of QCNet's improvements: after the roto-translation-invariant processing, does each agent at each timestep have a local reference frame, and can each agent at each timestep be regarded as a node in a GAT, with the edges being the relative position relations between two agents' local frames? I noticed that the relative relation between two local frames is represented by the relative location, relative direction, relative orientation, and relative time order. Do these four items uniquely determine the relative relation between two local references?
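
For reference on question 2, a sketch of why those quantities suffice, assuming the local frames live in the plane: the relative pose between frames i and j has exactly three spatial degrees of freedom,

$$\Delta_{ij} = \big(R(\theta_i)^\top (p_j - p_i),\; \theta_j - \theta_i\big) \in SE(2),$$

where $p_i$ and $\theta_i$ are the position and heading defining frame i. The relative location (distance plus direction expressed in frame i) fixes the translation part and the relative orientation fixes the rotation part, so together they uniquely determine the relative pose; the relative time order then adds the temporal offset between the two states.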
