
sgnn's Introduction

Learning Physical Dynamics with Subequivariant Graph Neural Networks (NeurIPS 2022)

Jiaqi Han, Wenbing Huang, Hengbo Ma, Jiachen Li, Joshua B. Tenenbaum, Chuang Gan

License: MIT

[Project Page] [Paper]

We propose the Subequivariant Graph Neural Network (SGNN), which jointly leverages object-aware information and subequivariance, a novel concept that relaxes the $E(3)$-equivariance constraint in the presence of external fields such as gravity. SGNN achieves promising performance in learning physical dynamics on the Physion and RigidFall datasets, while being generalizable and data-efficient.
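The following toy NumPy sketch illustrates what subequivariance means. It is not SGNN's actual layer. It assumes a simple construction:

- treat the gravity vector g as one more input vector;
- compute the Gram matrix of inner products, which is unchanged when all vectors are rotated together;
- use those invariants to weight the vectors.

A rotation about the gravity axis leaves g fixed, so the output rotates with the input. A rotation that tilts g does not have this property. `toy_subequivariant` and the random weights `W` are hypothetical stand-ins for a learned model.

```python
import numpy as np

def toy_subequivariant(Z, g, W):
    """Toy map from m vectors Z (3 x m) to k vectors (3 x k), conditioned on gravity g (3,)."""
    Zg = np.concatenate([Z, g[:, None]], axis=1)  # (3, m+1): treat g as one more vector
    inv = Zg.T @ Zg                               # (m+1, m+1) Gram matrix of inner products
    # Stand-in for an MLP: a fixed random linear map plus tanh, applied to the invariants.
    coeff = np.tanh(W @ inv.reshape(-1)).reshape(Zg.shape[1], -1)  # (m+1, k)
    return Zg @ coeff                             # linear combination of the vectors and g

def rot_z(t):
    # Rotation about the z axis (the gravity axis in this toy setup).
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    # Rotation about the x axis; this one tilts the gravity vector.
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

rng = np.random.default_rng(0)
m, k = 4, 2
Z = rng.normal(size=(3, m))
W = 0.1 * rng.normal(size=((m + 1) * k, (m + 1) ** 2))
g = np.array([0.0, 0.0, -1.0])  # gravity along -z (unit length for readability)

out = toy_subequivariant(Z, g, W)
Rz, Rx = rot_z(0.7), rot_x(0.7)
# Rotating about the gravity axis leaves g fixed, so the output rotates with the input.
print(np.allclose(toy_subequivariant(Rz @ Z, g, W), Rz @ out))  # True
# A rotation that moves g breaks the symmetry: the output no longer just rotates.
print(np.allclose(toy_subequivariant(Rx @ Z, g, W), Rx @ out))  # False
```

Full $E(3)$-equivariance would require the second check to pass as well. Feeding g in as a fixed input is what relaxes the symmetry to rotations that preserve the direction of gravity.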

[Figure: SGNN architecture]

Dependencies

python==3.8
scikit-learn==0.24.2
networkx==2.5.1
torch==1.8.0
torch-geometric==2.0.1

These are the basic requirements for running SGNN.
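Before training, it can save time to confirm that the installed versions match the pins listed under Dependencies. The sketch below uses only the standard library (`importlib.metadata`, available from Python 3.8). The helper names `installed_version` and `check_pins` are hypothetical. Both the PyPI-style name and its underscore form are tried, since older Python versions may not normalize distribution names.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Version pins copied from the Dependencies list.
PINNED = {
    "scikit-learn": "0.24.2",
    "networkx": "2.5.1",
    "torch": "1.8.0",
    "torch-geometric": "2.0.1",
}

def installed_version(name):
    """Return the installed version of a distribution, or None if it is not found."""
    for candidate in (name, name.replace("-", "_")):
        try:
            return version(candidate)
        except PackageNotFoundError:
            continue
    return None

def check_pins(pins=PINNED):
    """Map each package to (installed_version_or_None, matches_pin)."""
    report = {}
    for name, wanted in pins.items():
        installed = installed_version(name)
        # Ignore local build tags such as "+cu111" when comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, matches)
    return report

if __name__ == "__main__":
    print("python", sys.version.split()[0], "(pinned: 3.8)")
    for name, (installed, ok) in check_pins().items():
        print(f"{name:16s} {installed or 'missing':12s} {'ok' if ok else 'MISMATCH'}")
```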

Acknowledgment: This repository is built on Physion and RigidFall. Running the experiments or visualizations on these datasets requires additional packages; please refer to Physion and RigidFall for details.

Data Preparation

Physion

We use the official Physion code for the particle-based methods. That repository provides detailed instructions for retrieving the Physion dataset and preprocessing it into the format required for model training and evaluation.

RigidFall

We use the official code of VGPL, which includes instructions for downloading RigidFall and using the data to train and test the model.

Training and Evaluation

First, make sure the dataset is ready and placed as instructed by the official Physion or RigidFall repository. Then replace all placeholders in this repository's code, such as YOUR_DATA_DIR, with the directory that holds your preprocessed data.
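Because the placeholders may appear in several files, a small script can make the substitution in one pass. This is a minimal sketch, assuming the placeholders are plain-text occurrences of YOUR_DATA_DIR in `.py` and `.sh` files. Other placeholder names, if any, would need their own pass. The function name `replace_placeholder` is hypothetical. It rewrites files in place, so running it on a clean git checkout makes the change easy to review or revert.

```python
from pathlib import Path

def replace_placeholder(repo_root, data_dir, placeholder="YOUR_DATA_DIR",
                        suffixes=(".py", ".sh")):
    """Replace `placeholder` with `data_dir` in matching files; return the files changed."""
    changed = []
    for path in sorted(Path(repo_root).rglob("*")):
        # Only touch regular files with one of the given extensions.
        if not path.is_file() or path.suffix not in suffixes:
            continue
        text = path.read_text(encoding="utf-8")
        if placeholder in text:
            path.write_text(text.replace(placeholder, str(data_dir)), encoding="utf-8")
            changed.append(path)
    return changed
```

For example, `replace_placeholder(".", "/path/to/preprocessed")` run from the repository root returns the list of files it edited, which is worth inspecting with `git diff` afterwards.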

Physion

1. Training.

In the Physion directory (cd Physion), run

bash scripts/train_sgnn.sh [SCENARIO] [GPU_ID]

where [SCENARIO] is one of Dominoes, Contain, Link, Drape, Support, Drop, Collide, and Roll, and [GPU_ID] is the ID of an available GPU, such as 0.

2. Evaluation.

bash scripts/eval_sgnn.sh [SCENARIO] [EPOCH] [ITER] [SCENARIO] [GPU_ID]

where [EPOCH] and [ITER] select the checkpoint; setting both to 0 automatically selects the checkpoint with the best validation loss.

3. Visualization.

As with evaluation, run

bash scripts/vis_sgnn.sh [SCENARIO] [EPOCH] [ITER] [SCENARIO] [GPU_ID]
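To train every Physion scenario in turn on a single GPU, a small Python wrapper around the training step can help. This is a minimal sketch with the following assumptions:

- it is run from the Physion directory;
- scripts/train_sgnn.sh takes exactly the two arguments described in step 1;
- the helper names `train_commands` and `train_all` are hypothetical.

```python
import subprocess

# Scenario names as listed in the training step.
SCENARIOS = ["Dominoes", "Contain", "Link", "Drape", "Support", "Drop", "Collide", "Roll"]

def train_commands(gpu_id=0, scenarios=SCENARIOS):
    """Build one `bash scripts/train_sgnn.sh SCENARIO GPU_ID` command per scenario."""
    return [["bash", "scripts/train_sgnn.sh", s, str(gpu_id)] for s in scenarios]

def train_all(gpu_id=0):
    """Run the trainings back to back; check=True stops at the first failing script."""
    for cmd in train_commands(gpu_id):
        print("running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```

Calling `train_all(gpu_id=0)` runs the eight trainings sequentially. Stopping on the first non-zero exit status avoids silently continuing after a misconfigured run.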

RigidFall

1. Training.

In the RigidFall directory (cd RigidFall), run

bash scripts/dynamics/train_RigidFall_SGNN.sh

2. Evaluation and visualization.

bash scripts/dynamics/eval_RigidFall_SGNN.sh

Set the --vispy option in eval_RigidFall_SGNN.sh to 1 to enable visualization, or to 0 to disable it.
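If you switch often between evaluation-only and visualization runs, the flag can also be flipped programmatically. This minimal sketch assumes the script passes the option literally as `--vispy 0` or `--vispy 1` (or `--vispy=0`/`--vispy=1`). The function name `set_vispy` is hypothetical, and editing the script by hand works just as well.

```python
import re

def set_vispy(script_text, enable):
    """Return script_text with every `--vispy <0|1>` set to 1 (enable) or 0 (disable)."""
    value = "1" if enable else "0"
    # Keep the flag and its separator (space or '='), replace only the 0/1 value.
    new_text, n = re.subn(r"(--vispy[ =])[01]\b", r"\g<1>" + value, script_text)
    if n == 0:
        raise ValueError("no `--vispy 0/1` argument found")
    return new_text
```

For example, with `path = Path("scripts/dynamics/eval_RigidFall_SGNN.sh")`, running `path.write_text(set_vispy(path.read_text(), True))` enables visualization before calling the script.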

Visualization

An example comparison with the baselines:


Generalization to random rotations about the gravity axis:


More visualizations are available on our project page.

Citation

If you find our work helpful, please consider citing it:

@article{han2022learning,
  title={Learning Physical Dynamics with Subequivariant Graph Neural Networks},
  author={Han, Jiaqi and Huang, Wenbing and Ma, Hengbo and Li, Jiachen and Tenenbaum, Joshua B and Gan, Chuang},
  journal={arXiv preprint arXiv:2210.06876},
  year={2022}
}

Contact

If you have any questions, feel free to contact me at:

Jiaqi Han: [email protected]


sgnn's Issues

IndexOutOfBoundError When Executing RigidFall Evaluation

Hi hanjq,

Thank you for open-sourcing your code.
I am trying to reproduce the experimental results on the RigidFall dataset.
However, after training the model, I encountered the following error when evaluating on the validation set:

model_kp #params: 1550625
Loading network from /home1/user/gnn_exp/exp_logs/SGNN/RigidFall/SGNN_5000_dy_nHis2_aug0.05/net_best.pth
Rollout 0 / 500
Traceback (most recent call last):
  File "/home1/user/gnn_exp/SGNN/RigidFall/eval_new.py", line 113, in <module>
    group_gt = get_env_group(args, n_particle, scene_params, use_gpu=use_gpu)
  File "/home1/user/gnn_exp/SGNN/RigidFall/data_new.py", line 455, in get_env_group
    norm_g = normalize_scene_param(scene_params, 1, args.physics_param_range)
  File "/home1/user/gnn_exp/SGNN/RigidFall/data_new.py", line 122, in normalize_scene_param
    normalized = np.copy(scene_params[param_idx])
IndexError: index 1 is out of bounds for dimension 0 with size 1
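The last line reads "for dimension 0", which is PyTorch's wording for indexing errors; NumPy would say "for axis 0". This suggests `scene_params` is a tensor whose first dimension has size 1, while `normalize_scene_param` is asked for parameter index 1. The hypothetical helper below is a minimal diagnostic sketch, not a fix. Calling it on `scene_params` just before the failing line in data_new.py would show which of two cases applies:

- the parameters sit behind an extra leading dimension;
- fewer parameters were stored than the evaluation expects.

```python
import numpy as np

def diagnose_scene_params(scene_params, param_idx):
    """Explain why scene_params[param_idx] would fail along dimension 0 (hypothetical helper)."""
    arr = np.asarray(scene_params)  # accepts lists, NumPy arrays and CPU torch tensors
    if arr.ndim == 0:
        return "scalar: scene_params holds one value, not a vector of parameters"
    if param_idx < arr.shape[0]:
        return "ok"
    if arr.shape[0] == 1 and arr.ndim >= 2 and param_idx < arr.shape[1]:
        return "extra leading dimension: scene_params[0] holds the parameters"
    return f"only {arr.shape[0]} parameter(s) stored, but index {param_idx} was requested"

print(diagnose_scene_params([[0.1, 0.2, 0.3]], 1))
print(diagnose_scene_params([0.5], 1))
```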
