csinva / transformation-importance

Using / reproducing TRIM from the paper "Transformation Importance with Applications to Cosmology" 🌌 (ICLR Workshop 2020)

Home Page: https://arxiv.org/abs/2003.01926

License: MIT License

Python 1.13% Jupyter Notebook 98.87% Shell 0.01%
pytorch interpretability neural-network machine-learning interpretation explainability transformation deep-learning explainable-ai artificial-intelligence ml ai data-science deep-neural-networks feature-importance feature-engineering transform frequency-domain wavelet-analysis attribution

transformation-importance's Introduction

Official code for using / reproducing TRIM from the paper Transformation Importance with Applications to Cosmology (ICLR 2020 Workshop). This code shows examples and provides useful wrappers for calculating importance in a transformed feature space.

This repo is actively maintained. For any questions please file an issue.

trim

examples/documentation

  • dependencies: depends on the pip-installable acd package
  • examples: different folders (e.g. ex_cosmology, ex_fake_news, ex_mnist, ex_urban_sound) contain examples for using TRIM in different settings
  • src: the core code is in the trim folder, containing wrappers and code for different transformations
  • requirements: tested with python 3.7 and pytorch > 1.0

example figures (images not shown):

  • Attribution to different scales in cosmological images
  • Fake news attribution to different topics
  • Attribution to different NMF components in MNIST classification
  • Attribution to different frequencies in audio classification

sample usage

import torch
import torch.nn as nn
from trim import TrimModel
from functools import partial

# setup a trim model
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1)) # orig model
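# note: torch.rfft / torch.irfft were removed in PyTorch 1.8, so this snippet needs an older release
transform = partial(torch.rfft, signal_ndim=1, onesided=False) # fft
inv_transform = partial(torch.irfft, signal_ndim=1, onesided=False) # inverse fft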
model_trim = TrimModel(model=model, inv_transform=inv_transform) # trim model

# get a data point
x = torch.randn(1, 10)
s = transform(x)

# can now use any attribution method on the trim model
# get (input_x_gradient) attribution in the fft space
s.requires_grad = True
model_trim(s).backward()
input_x_gradient = s.grad * s
  • see notebooks for more detailed usage
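
On newer PyTorch releases, where torch.rfft and torch.irfft are no longer available, the sample above might be adapted roughly as follows. This is a minimal sketch under stated assumptions, not the repo's code: InverseThenModel is a hypothetical stand-in for TrimModel (assumed here to apply inv_transform and then the model), and the transform / inv_transform helpers are written with torch.fft and torch.view_as_real to mimic the (..., n, 2) real-valued layout of the old two-sided FFT. The per-frequency sum at the end is one illustrative way to collapse real and imaginary parts, not something the original sample does.

```python
import torch
import torch.nn as nn


class InverseThenModel(nn.Module):
    """Hypothetical stand-in for trim's TrimModel: undo the transform, then run the model."""

    def __init__(self, model, inv_transform):
        super().__init__()
        self.model = model
        self.inv_transform = inv_transform

    def forward(self, s):
        return self.model(self.inv_transform(s))


def transform(x):
    # two-sided FFT stored as a real tensor of shape (..., n, 2)
    return torch.view_as_real(torch.fft.fft(x, dim=-1))


def inv_transform(s):
    # inverse FFT; for a real signal the imaginary part is ~0, so keep the real part
    return torch.fft.ifft(torch.view_as_complex(s), dim=-1).real


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))  # orig model
model_trim = InverseThenModel(model, inv_transform)  # model acting on fft coefficients

# get a data point and its representation in the fft space
x = torch.randn(1, 10)
s = transform(x).detach().requires_grad_(True)

# input-x-gradient attribution in the fft space
model_trim(s).sum().backward()
input_x_gradient = s.grad * s.detach()

# illustrative choice: one score per frequency, summing real and imaginary parts
per_frequency = input_x_gradient.sum(dim=-1)
```

Because the wrapper only composes the inverse transform with the model, model_trim(s) should match model(x) up to floating-point error, which is a quick sanity check before trusting any attribution computed on s.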

related work

  • ACD (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • CDEP (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • DAC (arXiv 2019 pdf, github) - finds disentangled interpretations for random forests
  • PDR framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

reference

  • feel free to use/share this code openly
  • if you find this code useful for your research, please cite the following:
@article{singh2020transformation,
    title={Transformation Importance with Applications to Cosmology},
    author={Singh, Chandan and Ha, Wooseok and Lanusse, Francois and Boehm, Vanessa and Liu, Jia and Yu, Bin},
    journal={arXiv preprint arXiv:2003.01926},
    year={2020},
    url={https://arxiv.org/abs/2003.01926},
}

transformation-importance's People

Contributors

csinva, haywse

Forkers

haywse

transformation-importance's Issues

Missing transforms Folder

I can't find the transforms folder inside the trim folder; it may have been deleted. This transforms folder contains the class transform_wrappers. How can I get this folder, or the wrappers for calculating importance in a transformed feature space?

When I run this file (transformation-importance/ex_fake_news/01_topic_model.ipynb) I get an error.

Thank you for any help.
