
Interaction Modeling with Multiplex Attention

Authors: Fan-Yun Sun, Isaac Kauvar, Ruohan Zhang, Jiachen Li, Mykel Kochenderfer, Jiajun Wu, Nick Haber

Abstract: Modeling multi-agent systems requires understanding how agents interact. Such systems are often difficult to model because they can involve a variety of types of interactions that layer together to drive rich social behavioral dynamics. Here we introduce a method for accurately modeling multi-agent systems. We present Interaction Modeling with Multiplex Attention (IMMA), a forward prediction model that uses a multiplex latent graph to represent multiple independent types of interactions and attention to account for relations of different strengths. We also introduce Progressive Layer Training, a training strategy for this architecture. We show that our approach outperforms state-of-the-art models in trajectory forecasting and relation inference, spanning three multi-agent scenarios: social navigation, cooperative task achievement, and team sports. We further demonstrate that our approach can improve zero-shot generalization and allows us to probe how different interactions impact agent behavior.
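To make "a multiplex latent graph to represent multiple independent types of interactions and attention to account for relations of different strengths" more concrete, here is a toy, dependency-free sketch. It is not the IMMA implementation. The per-layer scoring rule (a diagonal weighting of the two agents' features), summing the per-layer messages, and the function names are all assumptions for illustration. Each layer stands for one interaction type. Within a layer, agent i softmax-normalizes its scores over the other agents, so each relation gets its own strength.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a non-empty list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def multiplex_attention(
    states: List[Vector], layer_weights: List[Vector]
) -> Tuple[List[Vector], List[List[Vector]]]:
    """Aggregate neighbor states over K independent interaction layers (toy sketch).

    states[i] is agent i's feature vector. layer_weights[k] is a hypothetical
    diagonal weighting used to score agent pairs on layer k.
    Returns (outputs, attention), where attention[k][i][j] is the weight agent i
    places on agent j in layer k (self-weight is 0).
    """
    n, dim = len(states), len(states[0])
    outputs = [[0.0] * dim for _ in range(n)]
    attention = []
    for w in layer_weights:
        layer_attn = []
        for i in range(n):
            others = [j for j in range(n) if j != i]  # requires at least 2 agents
            # Pairwise score on this layer: sum_d w[d] * x_i[d] * x_j[d]
            scores = [
                sum(wd * a * b for wd, a, b in zip(w, states[i], states[j]))
                for j in others
            ]
            row = [0.0] * n
            for j, alpha in zip(others, softmax(scores)):
                row[j] = alpha
                # Each layer adds its own attention-weighted message to agent i.
                for d in range(dim):
                    outputs[i][d] += alpha * states[j][d]
            layer_attn.append(row)
        attention.append(layer_attn)
    return outputs, attention
```

For example, `multiplex_attention([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])` scores pairs on the first feature in layer 0 and on the second feature in layer 1. Inspecting `attention[k]` then shows which agents each agent relies on for that interaction type.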

This repository contains the code for our paper, which was accepted at NeurIPS 2022. For more details, please refer to the paper (arXiv, OpenReview).

Environment Setup

  1. Install the Python-RVO2 library.
  2. Install the socialforce library.
  3. Install the remaining packages with pip:
     pip install -r requirements.txt
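Before generating data, you can quickly confirm that the two simulator dependencies are importable with a stdlib-only check like the one below. The import names `rvo2` (for Python-RVO2) and `socialforce` are assumptions about how these packages are usually imported. Adjust them if your installation differs. The helper `check_modules` is hypothetical and not part of the repository.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed import names for the two simulator dependencies.
REQUIRED_MODULES = ("rvo2", "socialforce")


def check_modules(names: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found on sys.path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_modules().items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
```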

Data Setup

Social Navigation Environment

This is a simulated environment inspired by https://github.com/vita-epfl/CrowdNav. After installing the necessary dependencies, use the following sample commands to start the simulation.

randomseed=17
dataset_size=100000
obs_frames=24
rollouts=10

cd data_utils/socialnav
python generate_dataset.py --dataset_size $dataset_size \
                           --randomseed $randomseed \
                           --obs_frames ${obs_frames} \
                           --rollouts ${rollouts}
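If you prefer to drive generation from Python (for example, from a notebook), here is a minimal sketch using `subprocess`. The helpers `build_command` and `generate` are hypothetical, not part of the repository. They pass only the four flags used in the sample commands, with the same default values, and assume they are called from the repository root.

```python
import subprocess
import sys
from typing import List


def build_command(dataset_size: int, randomseed: int,
                  obs_frames: int, rollouts: int) -> List[str]:
    """Assemble the generate_dataset.py call with the four flags shown above."""
    return [
        sys.executable, "generate_dataset.py",  # reuse the current interpreter
        "--dataset_size", str(dataset_size),
        "--randomseed", str(randomseed),
        "--obs_frames", str(obs_frames),
        "--rollouts", str(rollouts),
    ]


def generate(randomseed: int = 17, dataset_size: int = 100000,
             obs_frames: int = 24, rollouts: int = 10,
             cwd: str = "data_utils/socialnav") -> None:
    """Run the generator inside data_utils/socialnav; raises if it exits non-zero."""
    subprocess.run(build_command(dataset_size, randomseed, obs_frames, rollouts),
                   cwd=cwd, check=True)
```

Check which files `generate_dataset.py` writes before calling `generate` repeatedly with different seeds, since this sketch does not control output file names.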

The resulting dataset is stored at datasets/*.tensor. To change the simulation settings, modify the config file data_utils/socialnav/configs/default.py.

To inspect and interact with the environment (control the embodied agent with your arrow keys):

cd data_utils/socialnav
python human_play.py

PHASE

The preprocessed dataset is under datasets/phase/collab. To load the dataset, refer to the function prepare_dataset in data_utils/load_dataset.py.

NBA dataset

Download the preprocessed dataset here (or run gdown 1eJbDHy3fOHfzOStf-jSuYCz_YQloQU3s) and place it under datasets. Alternatively, you can create your own dataset from raw SportVU logs (refer to this repository or the code under data_utils/bball). To load the dataset, refer to the function prepare_dataset in data_utils/load_dataset.py.

Training and Evaluation

Find sample commands in run_socialnav.sh, run_phase.sh, and run_bball.sh.

Progressive Layer Training

Figure: Loss curve over the course of training IMMA with PLT on the NBA dataset. New layers are added after the model "converges".
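The caption says a new layer is added once the model "converges". One way to make that decision automatic is a loss-plateau rule. The sketch below (hypothetical function `layer_addition_epochs`, with assumed `patience` and `min_delta` values) decides only *when* a new layer would be added. It is not the repository's schedule, and it leaves out how the new layer is initialized and whether earlier layers keep training.

```python
from typing import List


def layer_addition_epochs(losses: List[float], patience: int = 3,
                          min_delta: float = 1e-3, max_layers: int = 3) -> List[int]:
    """Return the epochs at which a new layer would be added (plateau heuristic).

    The current layer counts as converged when the loss has not improved on its
    best value by more than min_delta for `patience` consecutive epochs.
    """
    additions: List[int] = []
    best = float("inf")
    stale = 0
    layers = 1  # training starts with a single layer
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best, stale = loss, 0
        else:
            stale += 1
        if stale >= patience and layers < max_layers:
            additions.append(epoch)
            layers += 1
            # Reset so the newly added layer gets its own plateau search.
            best, stale = float("inf"), 0
    return additions
```

With a loss curve that plateaus twice, such as `[1.0, 0.5, 0.4, 0.4, 0.4, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2]`, this rule adds a layer at each plateau until `max_layers` is reached.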

Citation

If you find the code or paper useful for your research, please cite our paper:

@article{sun2022interaction,
  title={Interaction Modeling with Multiplex Attention},
  author={Sun, Fan-Yun and Kauvar, Isaac and Zhang, Ruohan and Li, Jiachen and Kochenderfer, Mykel and Wu, Jiajun and Haber, Nick},
  journal={Advances in Neural Information Processing Systems},
  year={2022}
}

Acknowledgement

In this project, we use (parts of) the implementations from the following works:

We thank the authors for open sourcing their methods.
