We present iSEE, a framework for interpreting the dynamic representations of navigation agents in terms of human-interpretable concepts. The navigation agents were trained with the AllenAct framework. In this repository, we provide:
- Dataset of RNN activations and concepts
- Code to evaluate how well concepts can be predicted from the RNN activations of trained agents
- Code to find the top-K neurons most relevant for predicting a given concept
If you find this project useful in your research, please consider citing:
@InProceedings{Dwivedi_2022_CVPR,
author = {Dwivedi, Kshitij and Roig, Gemma and Kembhavi, Aniruddha and Mottaghi, Roozbeh},
title = {What Do Navigation Agents Learn About Their Environment?},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {10276-10285}
}
To begin, clone this repository locally:
git clone https://github.com/allenai/iSEE.git
Install Anaconda and create a new conda environment:
conda create -n iSEE
conda activate iSEE
Install the GPU build of XGBoost with the following command:
conda install -c anaconda py-xgboost-gpu
Install the remaining requirements:
pip install -r requirements.txt
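After installing, a quick sanity check can confirm that the active environment actually sees XGBoost. This is a minimal sketch that is not part of the repository: `check_imports` is a hypothetical helper, and the module list is an assumption that you would extend with whatever `requirements.txt` installs.

```python
"""Post-install sanity check (a sketch, not part of the iSEE repository)."""
import importlib
from typing import Dict, Iterable, Optional, Tuple


def check_imports(modules: Iterable[str]) -> Dict[str, Tuple[bool, Optional[str]]]:
    """Try to import each module and report (imported?, __version__ if present)."""
    report: Dict[str, Tuple[bool, Optional[str]]] = {}
    for name in modules:
        try:
            module = importlib.import_module(name)
        except ImportError:
            # Not installed (or not importable) in the active environment.
            report[name] = (False, None)
            continue
        report[name] = (True, getattr(module, "__version__", None))
    return report


if __name__ == "__main__":
    # xgboost comes from the conda step above; add other modules as needed.
    for name, (ok, version) in check_imports(["xgboost"]).items():
        print(f"{name}: {'OK (' + str(version) + ')' if ok else 'MISSING'}")
```

Run it with the `iSEE` environment activated; a `MISSING` line means the corresponding install step did not take effect in that environment.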
Download the dataset from here, then unzip it inside the data directory.
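If you prefer to unpack the archive from Python, the sketch below extracts it into `data/` and checks that the metadata file referenced later in this README ended up where the scripts expect it. It assumes the archive unpacks to `trajectory_dataset/...` at its top level; that layout, the helper `unzip_dataset`, and any archive filename you pass are assumptions, not part of the repository.

```python
"""Sketch: unzip the downloaded dataset into data/ and verify the layout."""
import zipfile
from pathlib import Path

# Relative path of the metadata file inside data/ (taken from this README).
# ASSUMPTION: the archive's top-level folder is trajectory_dataset/.
EXPECTED = Path("trajectory_dataset/train/"
                "objectnav_ithor_default_resnet_pretrained/metadata.pkl")


def unzip_dataset(archive: str, data_dir: str = "data") -> bool:
    """Extract `archive` into `data_dir`; return True if metadata.pkl is in place."""
    dest = Path(data_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return (dest / EXPECTED).is_file()
```

A `False` return usually means the archive has an extra or missing top-level folder, in which case moving the extracted `trajectory_dataset` directory directly under `data/` should fix the layout.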
Run the following script to evaluate how well concepts can be predicted from the activations of a trained agent (ResNet-ObjectNav) and to compare the results with the corresponding untrained baseline:
python predict_concepts.py --model resnet --task objectnav
Arguments:
--model: We used two architectures, ResNet and SimpleConv. Options are resnet and simpleconv.
--task: Options are objectnav and pointnav.
The script generates plots and saves them in the results/<task>/<model>/plots directory.
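To make the trained-versus-untrained comparison concrete, here is a minimal sketch on synthetic data. It fits a linear probe (ridge regression in NumPy, a stand-in chosen for simplicity; the repository installs XGBoost, and the evaluation script may use a different model and metric) on the activations of a "trained" and an "untrained" agent, then reports held-out accuracy for a binary concept. `probe_accuracy` and all data here are hypothetical.

```python
"""Sketch: decode a binary concept from RNN activations with a linear probe."""
import numpy as np


def probe_accuracy(acts: np.ndarray, labels: np.ndarray,
                   train_frac: float = 0.8, ridge: float = 1e-2,
                   seed: int = 0) -> float:
    """Fit a ridge-regression probe on a random split; return held-out accuracy."""
    rng = np.random.default_rng(seed)
    n = len(labels)
    idx = rng.permutation(n)
    cut = int(train_frac * n)
    tr, te = idx[:cut], idx[cut:]
    # Add a bias column and map labels {0, 1} -> targets {-1, +1}.
    X = np.hstack([acts, np.ones((n, 1))])
    y = np.where(labels > 0, 1.0, -1.0)
    # Closed-form ridge solution: (X^T X + ridge * I) w = X^T y.
    Xtr = X[tr]
    w = np.linalg.solve(Xtr.T @ Xtr + ridge * np.eye(X.shape[1]), Xtr.T @ y[tr])
    pred = np.sign(X[te] @ w)
    return float(np.mean(pred == y[te]))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    n, d = 2000, 64  # synthetic sizes, not those of the real dataset
    concept = rng.integers(0, 2, size=n)
    # "Trained" agent: one hidden unit tracks the concept.
    trained = rng.normal(size=(n, d))
    trained[:, 0] += 3.0 * (2 * concept - 1)
    # "Untrained" baseline: activations independent of the concept.
    untrained = rng.normal(size=(n, d))
    print("trained  :", probe_accuracy(trained, concept))
    print("untrained:", probe_accuracy(untrained, concept))
```

The untrained baseline matters because a high-capacity probe can sometimes decode concepts from random features; the gap between the two scores, rather than the trained score alone, indicates what the agent learned.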
Run the following script to find which neurons of a trained agent (ResNet-ObjectNav) are most relevant for predicting a given concept (e.g., front reachability):
python get_topk_neurons.py --model resnet --task objectnav --concept reachable_R=2_theta=000
Arguments:
--model: We used two architectures, ResNet and SimpleConv. Options are resnet and simpleconv.
--task: Options are objectnav and pointnav.
--concept: The concepts used in the paper are reachable_R=2_theta=000 (reachability at 2 × gridSize in front of the agent) and target_visibility. For the full list of concepts in the dataset, refer to the columns of the data/trajectory_dataset/train/objectnav_ithor_default_resnet_pretrained/metadata.pkl file.
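One way to list the available concept names and split reachability concepts into their parts is sketched below. It assumes `metadata.pkl` stores a pandas DataFrame whose columns are the concept names (as the README's reference to its "columns" suggests). Per the README, `R` is a multiple of the grid size and `theta=000` means "front"; reading `theta` as an angle is our assumption. `list_concepts` and `parse_reachability` are hypothetical helpers.

```python
"""Sketch: list concept columns and parse reachability concept names."""
import re
from typing import Dict, List, Optional

import pandas as pd

# Path from this README. ASSUMPTION: the pickle holds a pandas DataFrame.
METADATA = ("data/trajectory_dataset/train/"
            "objectnav_ithor_default_resnet_pretrained/metadata.pkl")

# Matches names such as "reachable_R=2_theta=000".
_REACH = re.compile(r"^reachable_R=(?P<r>\d+)_theta=(?P<theta>\d+)$")


def list_concepts(path: str = METADATA) -> List[str]:
    """Return the column names (concepts) of the metadata DataFrame."""
    return [str(c) for c in pd.read_pickle(path).columns]


def parse_reachability(name: str) -> Optional[Dict[str, int]]:
    """Return {'R': ..., 'theta': ...} for reachability concepts, else None."""
    match = _REACH.match(name)
    if match is None:
        return None
    return {"R": int(match["r"]), "theta": int(match["theta"])}
```

For example, `parse_reachability("reachable_R=2_theta=000")` returns `{"R": 2, "theta": 0}`, while `parse_reachability("target_visibility")` returns `None`.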
The script generates a SHAP beeswarm plot for the concept and saves it in the results/<task>/<model>/shap_plots directory.
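The ranking behind "top-K relevant neurons" can be sketched as follows, assuming you already have a SHAP value matrix of shape (n_samples, n_neurons) for one concept (for example, computed with SHAP's `TreeExplainer` on a fitted tree model). Neurons are ranked by mean absolute SHAP value, the global importance SHAP's beeswarm plot uses by default to order features. `topk_neurons` is a hypothetical helper; `get_topk_neurons.py` may rank neurons differently.

```python
"""Sketch: rank neurons by mean |SHAP value| and keep the top K."""
import numpy as np


def topk_neurons(shap_values: np.ndarray, k: int = 10) -> np.ndarray:
    """Indices of the k neurons with the largest mean absolute SHAP value.

    shap_values: array of shape (n_samples, n_neurons) for a single concept.
    """
    importance = np.abs(shap_values).mean(axis=0)
    # Stable sort on negated importance: descending order, ties keep index order.
    order = np.argsort(-importance, kind="stable")
    return order[:k]
```

For example, for a 2 × 3 matrix whose per-column mean absolute SHAP values are [0.2, 1.5, 0.5], `topk_neurons(sv, k=2)` returns indices `[1, 2]`.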
We thank the SHAP authors for their easy-to-use code and the ManipulaThor authors for the README template.