shinkyo0513 / towards-visually-explaining-video-understanding-networks-with-perturbation

Attribution (or visual explanation) methods for understanding video classification networks. Demo code for the WACV 2021 paper: Towards Visually Explaining Video Understanding Networks with Perturbation.
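
To make "perturbation-based attribution" concrete for a video classifier, here is a minimal temporal-occlusion sketch. It is a generic baseline, not the method proposed in the paper: each frame is replaced by a constant and the drop in the target-class score is taken as that frame's importance. The function `temporal_occlusion` and the toy `FrameSelector` model are hypothetical names; inputs are assumed to be laid out as 1 x C x T x H x W, as in the `grad_cam` code further down this page.

```python
import torch
import torch.nn as nn

def temporal_occlusion(model, clip, target, fill=0.0):
    """Per-frame importance: score drop for `target` when one frame is set to `fill`.

    clip: 1 x C x T x H x W tensor; target: class index (int).
    """
    model.eval()
    with torch.no_grad():
        base = model(clip)[0, target]
        scores = torch.zeros(clip.shape[2])
        for t in range(clip.shape[2]):
            occluded = clip.clone()
            occluded[:, :, t] = fill          # perturb a single frame
            scores[t] = base - model(occluded)[0, target]
    return scores  # larger value = frame matters more for `target`

class FrameSelector(nn.Module):
    # Hypothetical toy classifier: the class-0 score depends only on frame 2
    def forward(self, x):
        s = x[:, :, 2].mean(dim=(1, 2, 3))
        return torch.stack([s, -s], dim=1)

clip = torch.ones(1, 3, 4, 8, 8)
print(temporal_occlusion(FrameSelector(), clip, target=0))  # → tensor([0., 0., 1., 0.])
```

Occlusion needs one forward pass per frame and perturbs frames independently; optimization-based perturbation methods instead learn a smooth spatio-temporal mask, which is the direction the paper takes.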

Python 100.00%
action-recognition attribution-methods deep-learning-visualization epic-kitchens interpretable-deep-learning video-classification

towards-visually-explaining-video-understanding-networks-with-perturbation's People

Contributors

dependabot[bot]

Forkers

clic-ethiopia

towards-visually-explaining-video-understanding-networks-with-perturbation's Issues

Is this the codebase for the WACV 2021 paper?

Hi,

Thanks a lot for this fantastic code repo.

Could you confirm whether this code repo can reproduce the results in the paper "Towards Visually Explaining Video Understanding Networks with Perturbation"?

Thank you very much!

Missing UCF101-24 in the codebase

Hi,

You included results on UCF101-24 in your paper "Towards Visually Explaining Video Understanding Networks with Perturbation". However, the codebase has no option for this dataset; for example, the pretrain_dataset argument does not accept UCF101-24.

It would be really helpful if you could add the corresponding code or briefly describe how to use the existing codebase on UCF101-24.

Thanks!
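
A minimal sketch of what exposing a UCF101-24 option could look like, assuming `pretrain_dataset` is an `argparse` argument restricted by `choices`. The choice list and the `build_parser` helper are hypothetical and not taken from the repository; a matching data loader and model weights for UCF101-24 would still be needed.

```python
import argparse

# Hypothetical dataset identifiers; the repository's real choices may differ.
DATASETS = ['kinetics', 'epic-kitchens', 'ucf101-24']

def build_parser():
    parser = argparse.ArgumentParser(description='Video attribution demo (sketch)')
    parser.add_argument('--pretrain_dataset', choices=DATASETS, default='kinetics',
                        help='dataset the classifier was trained on (hypothetical list)')
    return parser

args = build_parser().parse_args(['--pretrain_dataset', 'ucf101-24'])
print(args.pretrain_dataset)  # → ucf101-24
```

Using `choices` makes argparse reject unknown dataset names with an error message instead of failing later inside the data-loading code.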

UCF101 kernel size issue for STEP algorithm

Hi,

I added UCF-101 to the codebase and tested STEP, 2D_EP, and 3D_EP. The latter two run fine on the whole UCF-101 validation set, while STEP fails for some videos with the error shown in the screenshot:

[Screenshot from 2021-01-29 21-08-45 showing the error]

If you could help with this, I would really appreciate it.

Thank you!
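
The exact error is only available in the screenshot. If, as the title suggests, the failure comes from videos that have fewer frames than some temporal kernel used during the computation, one hypothetical workaround is to pad short clips along the time axis before running the method. The sketch below repeats the last frame; `pad_clip_frames` and the value of `min_frames` are assumptions, and padding does change what the explanation is computed on.

```python
import torch

def pad_clip_frames(clip, min_frames):
    """Repeat the last frame until the clip has at least `min_frames` frames.

    clip: bs x C x T x H x W tensor (the layout used by grad_cam on this page).
    """
    nt = clip.shape[2]
    if nt >= min_frames:
        return clip
    # View of the last frame broadcast to the number of missing frames
    last = clip[:, :, -1:].expand(-1, -1, min_frames - nt, -1, -1)
    return torch.cat([clip, last], dim=2)

short = torch.randn(1, 3, 5, 112, 112)
print(pad_clip_frames(short, 16).shape)  # → torch.Size([1, 3, 16, 112, 112])
```

Repeating the last frame keeps the padded content in-distribution; zero padding is simpler but adds artificial black frames that an attribution method may highlight.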

Problems with grad_cam

Hello! Thanks for your excellent and clean code for video model visualization.
However, there seem to be some problems in visual_meth/grad_cam. Specifically, they occur when grad_cam is run multiple times in main.py:

  1. backward_hook and forward_hook are registered again on every call, so they accumulate.
  2. observ_actv always uses the first recorded activation instead of the current one.
  3. The output used to calculate the gradients is selected with label instead of pred.
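
The first two problems can be reproduced on a toy layer. A PyTorch hook stays attached until its handle's `remove()` is called, so registering on every run makes each forward pass fire all previously registered copies; if the hooks append to a list (as the original code appears to, judging by the `len(observ_actv_)` debug prints), index `[0]` then keeps pointing at a stale entry. The sketch below only counts invocations on an `nn.Linear`; it is an illustration, not code from the repository.

```python
import torch
import torch.nn as nn

calls = []

def forward_hook(module, inputs, output):
    calls.append(tuple(output.shape))  # one entry per hook invocation

x = torch.randn(1, 4)

# Pattern described in the issue: a hook is added on every run and never removed
leaky = nn.Linear(4, 2)
for _ in range(3):
    leaky.register_forward_hook(forward_hook)
    leaky(x)
leaky_count = len(calls)   # 1 + 2 + 3 = 6 invocations

# Fixed pattern: keep the handle and remove the hook after each run
calls.clear()
clean = nn.Linear(4, 2)
for _ in range(3):
    handle = clean.register_forward_hook(forward_hook)
    try:
        clean(x)
    finally:
        handle.remove()
clean_count = len(calls)   # 3 invocations, one per run

print(leaky_count, clean_count)  # → 6 3
```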

The code after my correction is as follows:

```python
import torch
import torch.nn.functional as F

# Buffers filled by the hooks; overwritten on every forward/backward pass
observ_grad_ = None
observ_actv_ = None

# Backward hook: store the gradient w.r.t. the observed layer's output
def backward_hook(m, i_grad, o_grad):
    global observ_grad_
    observ_grad_ = o_grad[0].detach()

# Forward hook: store the observed layer's activation
def forward_hook(m, i, o):
    global observ_actv_
    observ_actv_ = o.detach()

def grad_cam(inputs, labels, model, device, layer_name, norm_vis=True):
    model.eval()   # Set model to evaluate mode

    bs, ch, nt, h, w = inputs.shape
    assert ch == 3
    assert labels.shape[0] == bs

    # Walk the module tree along layer_name (a sequence of child-module names)
    observ_layer = model
    for name in layer_name:
        observ_layer = dict(observ_layer.named_children())[name]

    # Register fresh hooks on every call and always remove them afterwards.
    # Note: register_backward_hook is deprecated in recent PyTorch releases in
    # favor of register_full_backward_hook; it is kept here as in the original.
    backward_handle = observ_layer.register_backward_hook(backward_hook)
    forward_handle = observ_layer.register_forward_hook(forward_hook)

    try:
        inputs = inputs.to(device)

        # Forward pass
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        observ_actv = observ_actv_   # bs x C x T' x H' x W' (current call)
        observ_actv = torch.repeat_interleave(observ_actv, nt // observ_actv.shape[2], dim=2)

        # Backward pass: one-hot signal on the predicted class of each clip
        backward_signals = torch.zeros_like(outputs)
        for bidx in range(bs):
            backward_signals[bidx, preds[bidx].item()] = 1.0
        outputs.backward(backward_signals)

        observ_grad = observ_grad_   # bs x C x T' x H' x W' (current call)
        observ_grad = torch.repeat_interleave(observ_grad, nt // observ_grad.shape[2], dim=2)
    finally:
        # Remove the hooks even if the forward or backward pass raises
        backward_handle.remove()
        forward_handle.remove()

    # Channel weights: gradients averaged over the spatial dimensions, per frame
    observ_grad_w = observ_grad.mean(dim=4, keepdim=True).mean(dim=3, keepdim=True)  # bs x C x nt x 1 x 1
    out_masks = F.relu((observ_grad_w * observ_actv).sum(dim=1, keepdim=True))       # bs x 1 x nt x H' x W'

    if norm_vis:
        # Min-max normalize over the whole batch; the epsilon avoids division by zero
        out_masks = (out_masks - out_masks.min()) / (out_masks.max() - out_masks.min() + 1e-8)

    return out_masks
```
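
For reference, the last lines of `grad_cam` compute the standard Grad-CAM weighting separately for each frame. Writing $A^{k}_{t,i,j}$ for the (temporally repeated) activation of channel $k$ at frame $t$ and spatial position $(i,j)$, $H' \times W'$ for the feature-map size, and $y^{c}$ for the score of the predicted class $c$, the code corresponds to:

$$
\alpha^{c}_{k,t} = \frac{1}{H'W'} \sum_{i=1}^{H'} \sum_{j=1}^{W'} \frac{\partial y^{c}}{\partial A^{k}_{t,i,j}},
\qquad
M^{c}_{t,i,j} = \operatorname{ReLU}\Big( \sum_{k} \alpha^{c}_{k,t} \, A^{k}_{t,i,j} \Big)
$$

Because the gradient average runs only over $(i,j)$, each frame gets its own channel weights $\alpha^{c}_{k,t}$, rather than one weight per channel for the whole clip.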

Could you confirm whether this modification is correct? Thank you!
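
The per-clip loop that builds `backward_signals` can also be written with `torch.nn.functional.one_hot`. The sketch below checks on a small hand-made `outputs` tensor that both versions give the same signal; to backpropagate from the ground-truth class instead of the prediction, `labels` would be passed in place of `preds`.

```python
import torch
import torch.nn.functional as F

# Toy logits for a batch of 2 clips and 3 classes
outputs = torch.tensor([[0.1, 2.0, -1.0],
                        [3.0, 0.5, 0.2]])
_, preds = torch.max(outputs, 1)

# Loop version, as in grad_cam
loop_signals = torch.zeros_like(outputs)
for bidx in range(outputs.shape[0]):
    loop_signals[bidx, preds[bidx].item()] = 1.0

# Vectorized version: one-hot rows on the predicted classes, cast to the logits' dtype
onehot_signals = F.one_hot(preds, num_classes=outputs.shape[1]).to(outputs.dtype)

print(torch.equal(loop_signals, onehot_signals))  # → True
```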
