
[ECCV 2022] Official implementation of the paper "DeciWatch: A Simple Baseline for 10x Efficient 2D and 3D Pose Estimation"

License: Apache License 2.0



DeciWatch: A Simple Baseline for 10× Efficient 2D and 3D Pose Estimation (ECCV 2022)

This repo is the official implementation of "DeciWatch: A Simple Baseline for 10× Efficient 2D and 3D Pose Estimation". [Paper] [Project]

Update

  • Added failure cases and more analyses to the project page

  • Provided checkpoints for different sample intervals

  • DeciWatch is supported in MMHuman3D Release v0.7.0 as a 10x speed-up strategy!

  • The clean version is released! It currently includes code, data, logs, and models for the following tasks:

      • 2D human pose estimation
      • 3D human pose estimation
      • Body recovery via a SMPL model

TODO

Description

This paper proposes DeciWatch, a simple baseline framework for video-based 2D/3D human pose estimation that achieves a 10× efficiency improvement over existing works without any performance degradation. Unlike current solutions that estimate each frame in a video, DeciWatch introduces a simple yet effective sample-denoise-recover framework that only watches sparsely sampled frames, taking advantage of the continuity of human motions and the lightweight pose representation. Specifically, DeciWatch uniformly samples fewer than 10% of the video frames for detailed estimation, denoises the estimated 2D/3D poses with an efficient Transformer architecture, and then accurately recovers the remaining frames using another Transformer-based network. Comprehensive experimental results on three video-based human pose estimation and body mesh recovery tasks, as well as efficient labeling in videos, with four datasets validate the efficiency and effectiveness of DeciWatch.
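To make the sample-and-recover idea concrete, here is a minimal sketch that keeps every N-th frame of a pose sequence and rebuilds the rest. It is not the repository's code: plain linear interpolation stands in for the Transformer-based recovery network, the denoising step is omitted, and the function names `sampled_indices` and `recover_linear` are hypothetical.

```python
import numpy as np


def sampled_indices(num_frames: int, interval: int) -> np.ndarray:
    """Indices of the frames that would be sent to the pose estimator."""
    return np.arange(0, num_frames, interval)


def recover_linear(poses: np.ndarray, interval: int) -> np.ndarray:
    """Rebuild a (T, J, C) pose sequence from every `interval`-th frame.

    Assumption: linear interpolation replaces the learned recovery network.
    Frames after the last sampled index are held at the last sampled value,
    because np.interp clamps outside the sampled range.
    """
    num_frames = poses.shape[0]
    visible = sampled_indices(num_frames, interval)
    flat = poses.reshape(num_frames, -1)  # (T, J*C)
    timeline = np.arange(num_frames)
    # Interpolate each coordinate channel independently over time.
    channels = [
        np.interp(timeline, visible, flat[visible, d])
        for d in range(flat.shape[1])
    ]
    return np.stack(channels, axis=1).reshape(poses.shape)


if __name__ == "__main__":
    # Toy sequence: one joint moving along a smooth 3D curve for 101 frames.
    t = np.linspace(0.0, 1.0, 101)
    poses = np.stack(
        [np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), t], axis=-1
    )[:, None, :]
    recovered = recover_linear(poses, interval=10)
    print(len(sampled_indices(101, 10)), "of 101 frames estimated")
    print("max abs error:", np.abs(recovered - poses).max())
```

Because human motion is smooth, even this naive baseline recovers most of the sequence. The learned networks are meant to do better than interpolation and also remove estimator noise.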

Major Features

  • Model training and evaluation for 2D pose, 3D pose, and SMPL body representation
  • Support for four popular datasets (Human3.6M, 3DPW, AIST++, Sub-JHMDB) and cleaned data from five popular pose estimation backbones (FCN, SPIN, EFT, PARE, SimplePose)
  • Versatile visualization toolbox with comparison of input (backbone estimator results) and output (DeciWatch results)

Visualize 2D poses on the Sub-JHMDB dataset: [figure: Sub-JHMDB 2D poses, SimplePose]

Visualize 3D poses on the AIST++ dataset: [figure: AIST++ 3D poses, SPIN]

Visualize SMPL on the 3DPW dataset: [figure: 3DPW SMPL, PARE]

Getting Started

Environment Requirement

DeciWatch has been implemented and tested on PyTorch 1.10.1 with Python >= 3.6. It supports both GPU and CPU inference.

Clone the repo:

git clone https://github.com/cure-lab/DeciWatch.git

We recommend you install the requirements using conda:

# conda
source scripts/install_conda.sh

Prepare Data

All the data used in our experiments can be downloaded from either of the following:

Google Drive

Baidu Netdisk

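Available data includes:

| Dataset | Pose Estimator | 3D Pose | 2D Pose | SMPL |
| --- | --- | --- | --- | --- |
| Sub-JHMDB | SimplePose | | ✓ | |
| 3DPW | EFT | ✓ | | ✓ |
| 3DPW | PARE | ✓ | | ✓ |
| 3DPW | SPIN | ✓ | | ✓ |
| Human3.6M | FCN | ✓ | | |
| AIST++ | SPIN | ✓ | | ✓ |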

Please refer to doc/data.md for detailed data information and data preparation.

Training

Note that the training and testing datasets should be downloaded and prepared before training.

You may refer to doc/training.md for more training details.

Run the commands below to start training:

python train.py --cfg [config file] --dataset_name [dataset name] --estimator [backbone estimator you use] --body_representation [smpl/3D/2D] --sample_interval [sample interval N]

For example, you can train on 3D position representation of the 3DPW dataset using the backbone estimator SPIN with a sample interval N=10 (sampling ratio=10%) by:

python train.py --cfg configs/config_pw3d_spin.yaml --dataset_name pw3d --estimator spin --body_representation 3D --sample_interval 10
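The relation between the sample interval N and the sampling ratio is simple arithmetic, sketched below. The window-size formula `N * Q + 1` is an assumption inferred from a single config dump quoted in an issue further down this page (`INTERVAL_N` 10, `SLIDE_WINDOW_Q` 10, `SLIDE_WINDOW_SIZE` 101). The helper name `sampling_summary` is hypothetical and not part of the repository.

```python
def sampling_summary(interval_n: int, window_q: int) -> dict:
    """Summarize what a sample interval N implies for one sliding window.

    Assumption: the window spans N * Q + 1 frames, so that it starts and
    ends on a sampled frame (matches the one config dump available here).
    """
    window_size = interval_n * window_q + 1
    visible = list(range(0, window_size, interval_n))
    return {
        "sampling_ratio": 1.0 / interval_n,  # N=10 -> 10% of frames
        "window_size": window_size,
        "visible_frames": visible,
    }


if __name__ == "__main__":
    for n in (5, 10):
        print(n, sampling_summary(n, window_q=10))
```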

Evaluation (Take a 10% sampling ratio as an example)

Note that although our main contribution is the large efficiency improvement, using DeciWatch as a post-processing step also improves accuracy and smoothness.

You may refer to doc/evaluate.md for evaluation details on all sampling ratios.

Results on 2D Pose:

| Dataset | Estimator | PCK 0.05 (Input/Output) ↑ | PCK 0.1 (Input/Output) ↑ | PCK 0.2 (Input/Output) ↑ | Checkpoint |
| --- | --- | --- | --- | --- | --- |
| Sub-JHMDB | SimplePose | 57.30%/79.44% | 81.61%/94.05% | 93.94%/98.75% | Baidu Netdisk / Google Drive |
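For readers unfamiliar with PCK (Percentage of Correct Keypoints), the sketch below shows one common way to compute it. The normalization is an assumption: each frame's threshold is taken as a fraction of a per-frame reference length such as the person's bounding-box size. The repository's exact definition may differ, and the function name `pck` is hypothetical.

```python
import numpy as np


def pck(pred: np.ndarray, gt: np.ndarray, ref_size: np.ndarray,
        threshold: float) -> float:
    """Fraction of keypoints within `threshold * ref_size` of ground truth.

    pred, gt: (T, J, 2) 2D keypoints in pixels.
    ref_size: (T,) per-frame reference length (assumed: bounding-box size).
    """
    dist = np.linalg.norm(pred - gt, axis=-1)       # (T, J) pixel errors
    correct = dist <= threshold * ref_size[:, None]  # per-keypoint hit/miss
    return float(correct.mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gt = rng.uniform(0, 200, size=(50, 15, 2))
    pred = gt + rng.normal(scale=5.0, size=gt.shape)
    ref = np.full(50, 100.0)
    for thr in (0.05, 0.1, 0.2):
        print(f"PCK@{thr}: {pck(pred, gt, ref, thr):.3f}")
```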

Results on 3D Pose:

| Dataset | Estimator | MPJPE (Input/Output) ↓ | Accel (Input/Output) ↓ | Checkpoint |
| --- | --- | --- | --- | --- |
| 3DPW | SPIN | 96.92/93.34 | 34.68/7.06 | Baidu Netdisk / Google Drive |
| 3DPW | EFT | 90.34/89.02 | 32.83/6.84 | Baidu Netdisk / Google Drive |
| 3DPW | PARE | 78.98/77.16 | 25.75/6.90 | Baidu Netdisk / Google Drive |
| AIST++ | SPIN | 107.26/71.27 | 33.37/5.68 | Baidu Netdisk / Google Drive |
| Human3.6M | FCN | 54.56/52.83 | 19.18/1.47 | Baidu Netdisk / Google Drive |
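The two metrics in this table can be sketched as follows, using their commonly used definitions: MPJPE is the mean Euclidean distance between predicted and ground-truth joints, and Accel is the mean difference between their second-order finite differences over time. This is an illustrative version, not the repository's evaluation code. The function names are hypothetical, and results are in whatever units the inputs use.

```python
import numpy as np


def mpjpe(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean per-joint position error for (T, J, 3) sequences."""
    return float(np.linalg.norm(pred - gt, axis=-1).mean())


def accel_error(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean acceleration error for (T, J, 3) sequences with T >= 3.

    Acceleration is approximated per joint by the second finite
    difference x[t-1] - 2 * x[t] + x[t+1].
    """
    accel_pred = pred[:-2] - 2 * pred[1:-1] + pred[2:]
    accel_gt = gt[:-2] - 2 * gt[1:-1] + gt[2:]
    return float(np.linalg.norm(accel_pred - accel_gt, axis=-1).mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    t = np.linspace(0.0, 1.0, 100)[:, None, None]
    gt = np.concatenate([np.sin(t), np.cos(t), t], axis=-1) * 1000.0
    noisy = gt + rng.normal(scale=20.0, size=gt.shape)  # jittery estimate
    print("MPJPE:", mpjpe(noisy, gt))
    print("Accel:", accel_error(noisy, gt))
```

Frame-wise jitter inflates Accel much more than MPJPE, which is why temporal denoising can reduce Accel sharply while MPJPE changes only modestly.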

Results on SMPL-based Body Recovery:

| Dataset | Estimator | MPJPE (Input/Output) ↓ | Accel (Input/Output) ↓ | MPVPE (Input/Output) ↓ | Checkpoint |
| --- | --- | --- | --- | --- | --- |
| 3DPW | SPIN | 100.13/97.53 | 35.53/8.38 | 114.39/112.84 | Baidu Netdisk / Google Drive |
| 3DPW | EFT | 91.60/92.56 | 33.57/8.75 | 110.34/109.27 | Baidu Netdisk / Google Drive |
| 3DPW | PARE | 80.44/81.76 | 26.77/7.24 | 94.88/95.68 | Baidu Netdisk / Google Drive |
| AIST++ | SPIN | 108.25/82.10 | 33.83/7.27 | 137.51/106.08 | Baidu Netdisk / Google Drive |

Quick Demo

Here, we only provide demo visualization based on offline-processed detected poses of specific datasets (e.g., AIST++, Human3.6M, 3DPW, and Sub-JHMDB). To visualize an arbitrary video, please refer to the inference/demo of MMHuman3D.

Run the commands below to visualize demo:

python demo.py --cfg [config file] --dataset_name [dataset name] --estimator [backbone estimator you use] --body_representation [smpl/3D/2D] --sample_interval [sample interval N]

Place the corresponding images or videos in the following directory structure:

|-- data
    |-- videos
        |-- pw3d 
            |-- downtown_enterShop_00
                |-- image_00000.jpg
                |-- ...
            |-- ...
        |-- jhmdb
            |-- catch
            |-- ...
        |-- aist
            |-- gWA_sFM_c01_d27_mWA2_ch21.mp4
            |-- ...
        |-- ...

For example, you can visualize results of 3D position representation on the 3DPW dataset using the backbone estimator SPIN with a sample interval N=10 (sampling ratio=10%) by:

python demo.py --cfg configs/config_pw3d_spin.yaml --dataset_name pw3d --estimator spin --body_representation 3D --sample_interval 10

Please refer to the dataset website for the raw images. You may change the config in lib/core/config.py for different visualization parameters.

You can refer to doc/visualize.md for visualization details.

Citing DeciWatch

If you find this repository useful for your work, please consider citing it as follows:

@inproceedings{zeng2022deciwatch,
  title={DeciWatch: A Simple Baseline for 10x Efficient 2D and 3D Pose Estimation},
  author={Zeng, Ailing and Ju, Xuan and Yang, Lei and Gao, Ruiyuan and Zhu, Xizhou and Dai, Bo and Xu, Qiang},
  booktitle={European Conference on Computer Vision},
  year={2022},
  organization={Springer}
}

Please remember to cite all the datasets and backbone estimators if you use them in your experiments.

Acknowledgement

Many thanks to Xuan Ju for her great efforts in cleaning almost all of the original code!

License

This code is available for non-commercial scientific research purposes as defined in the LICENSE file. By downloading and using this code you agree to the terms in the LICENSE. Third-party datasets and software are subject to their respective licenses.

deciwatch's People

Contributors

ailingzengzzz, juxuan27


deciwatch's Issues

How to add custom data?

Thanks for your great work!
For the human3.6M dataset, my goal is to change the backbone network fcn (to e.g. videopose 3d), how to get the .npz files like yours?

RuntimeError: "baddbmm_cuda" not implemented for 'Int' (cannot run on GPU)

(base) sujia@cupt-System-Product-Name:~/hf/DeciWatch$ cd /home/sujia/hf/DeciWatch ; /usr/bin/env /home/sujia/anaconda3/bin/python /home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher 43147 -- /home/sujia/hf/DeciWatch/train.py
Namespace(cfg='/home/sujia/hf/DeciWatch/configs/config_h36m_fcn_3D.yaml', dataset_name='h36m', estimator='fcn', body_representation='3D', sample_interval=10)

Seed value for the experiment is 4321
GPU name -> NVIDIA GeForce RTX 3090
GPU feat -> _CudaDeviceProperties(name='NVIDIA GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82)
{'BODY_REPRESENTATION': '3D',
'CUDNN': CfgNode({'BENCHMARK': True, 'DETERMINISTIC': False, 'ENABLED': True}),
'DATASET': {'AIST': {'DETECTED_PATH': 'data/detected_poses/aist',
'GROUND_TRUTH_PATH': 'data/groundtruth_poses/aist',
'KEYPOINT_NUM': 14,
'KEYPOINT_ROOT': [2, 3]},
'H36M': {'DETECTED_PATH': 'data/detected_poses/h36m',
'GROUND_TRUTH_PATH': 'data/groundtruth_poses/h36m',
'KEYPOINT_NUM': 17,
'KEYPOINT_ROOT': [0]},
'JHMDB': {'DETECTED_PATH': 'data/detected_poses/jhmdb',
'GROUND_TRUTH_PATH': 'data/groundtruth_poses/jhmdb',
'KEYPOINT_NUM': 15,
'KEYPOINT_ROOT': [2]},
'PW3D': {'DETECTED_PATH': 'data/detected_poses/pw3d',
'GROUND_TRUTH_PATH': 'data/groundtruth_poses/pw3d',
'KEYPOINT_NUM': 14,
'KEYPOINT_ROOT': [2, 3]}},
'DATASET_NAME': 'h36m',
'DEBUG': True,
'DEVICE': 'cuda',
'ESTIMATOR': 'fcn',
'EVALUATE': {'DENOISE': False,
'INTERP': 'linear',
'PRETRAINED': 'data/checkpoints/h36m_fcn_3d/checkpoint.pth.tar',
'RELATIVE_IMPROVEMENT': False,
'ROOT_RELATIVE': True,
'SLIDE_WINDOW_STEP_Q': 1,
'SLIDE_WINDOW_STEP_SIZE': 10},
'EXP_NAME': 'h36m_fcn',
'LOG': CfgNode({'NAME': ''}),
'LOGDIR': 'results/08-08-2022_18-44-22_h36m_fcn',
'LOSS': CfgNode({'LAMADA': 1.0, 'W_DENOISE': 1.0}),
'MODEL': {'DECODER': 'transformer',
'DECODER_EMBEDDING_DIMENSION': 128,
'DECODER_HEAD': 4,
'DECODER_INTERP': 'linear',
'DECODER_RESIDUAL': True,
'DECODER_TOKEN_WINDOW': 5,
'DECODER_TRANSFORMER_BLOCK': 5,
'DROPOUT': 0.1,
'ENCODER_EMBEDDING_DIMENSION': 128,
'ENCODER_HEAD': 4,
'ENCODER_RESIDUAL': True,
'ENCODER_TRANSFORMER_BLOCK': 5,
'INTERVAL_N': 10,
'NAME': '',
'SAMPLE_TYPE': 'uniform',
'SLIDE_WINDOW': True,
'SLIDE_WINDOW_Q': 10,
'SLIDE_WINDOW_SIZE': 101,
'TYPE': 'network'},
'OUTPUT_DIR': 'results',
'SAMPLE_INTERVAL': 10,
'SEED_VALUE': 4321,
'SMPL_MODEL_DIR': 'data/smpl/',
'TRAIN': {'BATCH_SIZE': 1024,
'EPOCH': 20,
'LR': 0.001,
'LRDECAY': 0.95,
'PRE_NORM': True,
'RESUME': None,
'USE_6D_SMPL': False,
'USE_SMPL_LOSS': False,
'VALIDATE': True,
'WORKERS_NUM': 0},
'VIS': {'END': 1000,
'INPUT_VIDEO_NUMBER': 143,
'INPUT_VIDEO_PATH': 'data/videos/',
'OUTPUT_VIDEO_PATH': 'demo/',
'START': 0}}
#############################################################
You are loading the [training set] of dataset [h36m]
You are using pose esimator [fcn]
The type of the data is [3D]
The frame number is [1559752]
The sequence number is [600]
#############################################################
#############################################################
You are loading the [testing set] of dataset [h36m]
You are using pose esimator [fcn]
The type of the data is [3D]
The frame number is [543344]
The sequence number is [236]
#############################################################

Traceback (most recent call last):
File "/home/sujia/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/sujia/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/sujia/.vscode-server/extensions/ms-python.python-2022.12.0/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/home/sujia/hf/DeciWatch/train.py", line 108, in <module>
main(cfg)
File "/home/sujia/hf/DeciWatch/train.py", line 95, in main
Trainer(train_dataloader=train_loader,
File "/home/sujia/hf/DeciWatch/lib/core/trainer.py", line 67, in run
self.train()
File "/home/sujia/hf/DeciWatch/lib/core/trainer.py", line 113, in train
predicted_3d_pos, denoised_3d_pos = self.model(
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sujia/hf/DeciWatch/lib/models/deciwatch.py", line 158, in forward
self.recover, self.denoise = self.transformer.forward(
File "/home/sujia/hf/DeciWatch/lib/models/deciwatch.py", line 267, in forward
output = self.decode(mem, encoder_mask, encoder_pos_embed, trans_tgt,
File "/home/sujia/hf/DeciWatch/lib/models/deciwatch.py", line 287, in decode
hs = self.decoder(tgt,
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sujia/hf/DeciWatch/lib/models/deciwatch.py", line 343, in forward
output = layer(output,
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sujia/hf/DeciWatch/lib/models/deciwatch.py", line 536, in forward
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
File "/home/sujia/hf/DeciWatch/lib/models/deciwatch.py", line 507, in forward_pre
tgt2 = self.self_attn(q,
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1153, in forward
attn_output, attn_output_weights = F.multi_head_attention_forward(
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 5179, in multi_head_attention_forward
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
File "/home/sujia/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 4852, in _scaled_dot_product_attention
attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
RuntimeError: "baddbmm_cuda" not implemented for 'Int'

3D custom data format

Hi,
I am trying to run inference on 3D detection data, I have the body coordinates of people. What preprocessing do I need to do for inference? I normalised the data wrt the hip centre. I tried to visualise prepared data and the person was upside down, I am little confused.

MMpose integration

Great work!

Just wanted to ask if it's still planned to integrate DeciWatch in mmpose? There is a stale PR there for some months now.

Train on custom data?

Is there a way for me to train on custom data? What format would it need to be in?

Questions about network structure

According to the setting of the paper, for 3D human pose estimation we can only input 3D poses and output smooth 3D poses through the DeciWatch network. Can we directly input 2D poses and use the network to both lift the dimension and reduce the noise? Looking forward to your reply :)

Can Deciwatch be executed online?

Both SmoothNet and DeciWatch perform offline pose estimation, and their design means they cannot run in real time.

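Data processing question

Hello, I would like to use a custom Human3.6M npz file. The README says that root-relative data is required, but when I debug the original code, the data inside is not root-relative.
[image]
Which format should be used in the end?
[image]
The upper part of this figure shows the root-relative joint data that I have already processed.
Is the lower part the data from your npz? Thanks for your answer.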
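For reference, "root-relative" usually means subtracting a root joint position from every joint in each frame. The sketch below illustrates that operation only; it does not answer which format the released npz files use. The root indices follow the `KEYPOINT_ROOT` values in the config dump quoted in an earlier issue (`[0]` for H36M, `[2, 3]` for 3DPW and AIST). Averaging multiple root indices is an assumption, and the function name `to_root_relative` is hypothetical.

```python
import numpy as np


def to_root_relative(joints: np.ndarray, root_indices: list) -> np.ndarray:
    """Express (T, J, 3) joints relative to a root position per frame.

    Assumption: when several root indices are given (e.g. both hips),
    their mean position is used as the root.
    """
    root = joints[:, root_indices, :].mean(axis=1, keepdims=True)  # (T, 1, 3)
    return joints - root


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h36m_like = rng.normal(size=(8, 17, 3))  # 17 keypoints, root index 0
    rel = to_root_relative(h36m_like, [0])
    print("root joint after conversion:", rel[0, 0])
```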

Bad deciwatch output

Hi, I have 3D joint data of a person in inches. I converted this data to meters and am trying to run DeciWatch. All results are good except the DeciWatch output. I also tried normalizing the data with respect to the hip midpoint, but the results look the same. Any idea what I am doing wrong?

SmoothNet vs. Transformer as the denoise and recover net

Dear authors,

Thanks for your amazing work and releasing the code!
In your other work SmoothNet, you showed that temporal-only network is superior to a transformer. However, here you use a vanilla transformer module as the denoise and recover net. In theory, these two networks can also be simply replaced by two SmoothNets. I am wondering have you done these experiments before? And what is your insight into these?
Thanks!

Best,
Xianghui
