
bknyaz / sgg


Train Scene Graph Generation for Visual Genome and GQA in PyTorch >= 1.2 with improved zero and few-shot generalization.

Home Page: https://arxiv.org/abs/2007.05756

License: Other

Python 1.56% Jupyter Notebook 98.43% Cython 0.01%
pytorch scene-graph scene-graph-generation visual-genome gqa paper-implementations computer-vision graph message-passing gan augmentations wandb generative-adversarial-network deep-learning

sgg's Introduction

Scene Graph Generation

(Figure panels: Object Detections | Ground-truth Scene Graph | Generated Scene Graph)

In this visualization, "woman sitting on rock" is a zero-shot triplet, which means that the combination of woman, sitting on and rock has never been observed during training. However, each of the objects and the predicate has been observed individually, in combination with other objects and predicates. For example, "woman sitting on chair" has been observed in training and is not a zero-shot triplet. Making correct predictions for zero-shot triplets is very challenging, so in our papers [1,2] we address this problem and improve zero-shot as well as few-shot results. See examples of zero-shot triplets in the Visual Genome (VG) dataset in Zero_Shot_VG.ipynb.
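The zero-shot definition above can be made concrete with a small sketch: count the (subject, predicate, object) triplets seen in training and flag test triplets whose count is zero. The helper names and the plain-string toy triplets are hypothetical and chosen for illustration; the repo itself works with class indices and its own data loaders.

```python
from collections import Counter
from typing import Iterable, List, Tuple

Triplet = Tuple[str, str, str]  # (subject, predicate, object) class names (toy example)


def count_train_triplets(train_triplets: Iterable[Triplet]) -> Counter:
    """Count how often each (subject, predicate, object) combination occurs in training."""
    return Counter(train_triplets)


def zero_shot_triplets(test_triplets: Iterable[Triplet], train_counts: Counter) -> List[Triplet]:
    """Keep test triplets whose exact combination never occurs in training."""
    # Counter returns 0 for missing keys without inserting them
    return [t for t in test_triplets if train_counts[t] == 0]


train = [("woman", "sitting on", "chair"),
         ("man", "sitting on", "rock"),
         ("woman", "standing on", "rock")]
test = [("woman", "sitting on", "rock"),   # every part seen in training, combination unseen
        ("woman", "sitting on", "chair")]  # seen in training, so not zero-shot
print(zero_shot_triplets(test, count_train_triplets(train)))
# [('woman', 'sitting on', 'rock')]
```

The same training counts could also be used to bucket test triplets into few-shot groups, under whatever definition of "k-shot" the evaluation uses.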

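This repository accompanies two papers:

  • [1] Graph Density-Aware Losses for Novel Compositions in Scene Graph Generation. Boris Knyazev, Harm de Vries, Cătălina Cangea, Graham W. Taylor, Aaron Courville, Eugene Belilovsky. BMVC 2020. http://arxiv.org/abs/2005.08230
  • [2] Generative Compositional Augmentations for Scene Graph Prediction. Boris Knyazev, Harm de Vries, Cătălina Cangea, Graham W. Taylor, Aaron Courville, Eugene Belilovsky. ICCV 2021. https://arxiv.org/abs/2007.05756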

See the code for my other ICCV 2021 paper, Context-aware Scene Graph Generation with Seq2Seq Transformers, at https://github.com/layer6ai-labs/SGG-Seq2Seq.

The code in this repo is based on the amazing code for Neural Motifs by Rowan Zellers. Our code uses torchvision.models.detection, so it can be run with PyTorch 1.2 or later.

Weights and Biases

Weights and Biases is a cool tool for tracking machine learning experiments, and I used it in this project. It is free (in most cases) and very user-friendly, which helps a lot in complex projects with many metrics (like SGG).

See our Weights and Biases (W & B) project for the results on different SGG metrics and training curves.

Requirements

  • Python >= 3.6
  • PyTorch >= 1.2
  • Other standard Python libraries

Installing these libraries (in addition to PyTorch) should be enough:

conda install -c anaconda h5py cython dill pandas
conda install -c conda-forge pycocotools tqdm

Results in our papers [1,2] were obtained on a single GPU (1080Ti, 2080Ti or RTX6000) with 11-24GB of GPU memory and 32GB of RAM. Multi-GPU training is unfortunately not supported in this repo.

To use the edge feature model from Rowan Zellers' implementation (the default argument -edge_model motifs in our code), you need to build the following Cython extension:

cd lib/draw_rectangles; python setup.py build_ext --inplace; cd ../..;

Data

Visual Genome or GQA data will be automatically downloaded after the first call of python main.py -data $data_path. After downloading, the script will generate the following directories (make sure you have at least 60GB of disk space in $data_path):

data_path
├───VG
│   │   VG.tar
│   │   VG_100K (this will appear after extracting VG.tar)
│   │   ...
│
└───GQA # optional
    │   GQA_scenegraphs.tar
    │   sceneGraphs (this will appear after extracting GQA_scenegraphs.tar)
    │   ...
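Before the first run, a quick check like the one below can help avoid half-finished downloads. It is a minimal sketch, not part of this repo: check_data_dir and the keys it returns are hypothetical, it only looks for the paths shown in the tree above, and it measures free space in decimal gigabytes with the standard library's shutil.disk_usage.

```python
import shutil
import tempfile
from pathlib import Path

MIN_FREE_GB = 60  # the README asks for at least 60GB of disk space in $data_path


def check_data_dir(data_path: str, min_free_gb: float = MIN_FREE_GB) -> dict:
    """Report free disk space and which of the expected VG/GQA files already exist."""
    root = Path(data_path)
    root.mkdir(parents=True, exist_ok=True)  # disk_usage needs an existing path
    free_gb = shutil.disk_usage(root).free / 1e9  # decimal gigabytes
    return {
        "free_gb": free_gb,
        "enough_space": free_gb >= min_free_gb,
        "vg_tar": (root / "VG" / "VG.tar").is_file(),
        "vg_extracted": (root / "VG" / "VG_100K").is_dir(),
        "gqa_tar": (root / "GQA" / "GQA_scenegraphs.tar").is_file(),
        "gqa_extracted": (root / "GQA" / "sceneGraphs").is_dir(),
    }


if __name__ == "__main__":
    # Demo on an empty temporary directory: nothing has been downloaded yet.
    with tempfile.TemporaryDirectory() as tmp:
        print(check_data_dir(tmp))
```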

If downloading fails, you can download the data manually using the links in lib/download.py. Alternatively, VG can be downloaded following Rowan Zellers' instructions, while GQA can be downloaded from the official GQA website.

To train SGG models on VG, download Rowan Zellers' VGG16 detector checkpoint and save it as ./data/VG/vg-faster-rcnn.tar.

To train our GAN models from [2], it is necessary to first extract and save real object features from the training set of VG by running:

python extract_features.py -data ./data/ -ckpt ./data/VG/vg-faster-rcnn.tar -save_dir ./data/VG/

The script will generate ./data/VG/features.hdf5 of around 30GB.

Example from [1]: Improved edge loss

Our improved edge loss from [1] can be added to any SGG model that predicts edge labels as rel_dists, a float-valued tensor of shape (M, R), where R is the total number of predicate classes (e.g. 51 in Visual Genome: 50 predicates plus the background class). M is the total number of edges in a batch of scene graphs, including background edges (edges without any semantic relationship).

The baseline loss used in most SGG works simply computes the cross-entropy between rel_dists and ground truth edge labels rel_labels (an integer tensor of length M):

baseline_edge_loss = torch.nn.functional.cross_entropy(rel_dists, rel_labels)

Our improved edge loss takes into account the extreme imbalance between the foreground and background edge terms. Foreground edges are those that have semantic ground truth annotations (e.g. on, has, wearing, etc.). In datasets like Visual Genome, scene graph annotations are extremely sparse, i.e. the number of foreground edges (M_FG) is significantly lower than the total number of edges M.

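import torch

baseline_edge_loss = torch.nn.functional.cross_entropy(rel_dists, rel_labels)
M = len(rel_labels)  # all edges in the batch, including background edges
M_FG = torch.sum(rel_labels > 0).clamp(min=1)  # foreground edges (label 0 = background); clamp avoids division by zero
our_edge_loss = baseline_edge_loss * M / M_FG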
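Written out, the rescaling by M / M_FG in the snippet above has a simple reading. This is a short algebraic restatement of that snippet, not the derivation from [1]. Let $\ell_i$ be the cross-entropy of edge $i$, $M_{BG} = M - M_{FG}$ the number of background edges, and $\bar{\ell}_{FG}$, $\bar{\ell}_{BG}$ the mean losses over foreground and background edges (assuming $M_{FG} > 0$):

$$
\mathcal{L}_{\text{baseline}} = \frac{1}{M}\sum_{i=1}^{M}\ell_i = \frac{M_{FG}}{M}\,\bar{\ell}_{FG} + \frac{M_{BG}}{M}\,\bar{\ell}_{BG}
$$

$$
\mathcal{L}_{\text{ours}} = \frac{M}{M_{FG}}\,\mathcal{L}_{\text{baseline}} = \frac{1}{M_{FG}}\sum_{i=1}^{M}\ell_i = \bar{\ell}_{FG} + \frac{M_{BG}}{M_{FG}}\,\bar{\ell}_{BG}
$$

In the baseline, the foreground term is scaled by the graph density $M_{FG}/M$, which is small when annotations are sparse, so foreground edges contribute little to the loss; after rescaling, the foreground term enters with weight 1.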

Our improved loss significantly improves all SGG metrics, in particular on zero-shot and few-shot triplets. See [1] for the results and a discussion of why our loss works well.
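To see the reweighting numerically without PyTorch, here is a minimal pure-Python sketch. The helper names cross_entropy, edge_losses and fg_bg_decomposition are hypothetical, not taken from lib/losses.py. Like the snippet above, it assumes label 0 marks background edges, and it checks that the rescaled loss equals the foreground mean plus the background mean weighted by M_BG / M_FG.

```python
import math
from typing import Sequence, Tuple


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Cross-entropy of one edge: -log softmax(logits)[label], computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def edge_losses(rel_dists: Sequence[Sequence[float]], rel_labels: Sequence[int]) -> Tuple[float, float]:
    """Return (baseline_edge_loss, our_edge_loss) for one batch; label 0 = background."""
    per_edge = [cross_entropy(z, y) for z, y in zip(rel_dists, rel_labels)]
    M = len(rel_labels)
    M_FG = max(sum(1 for y in rel_labels if y > 0), 1)  # guard against batches without foreground edges
    baseline = sum(per_edge) / M
    return baseline, baseline * M / M_FG


def fg_bg_decomposition(rel_dists: Sequence[Sequence[float]], rel_labels: Sequence[int]) -> float:
    """Compute mean_FG + (M_BG / M_FG) * mean_BG; assumes at least one foreground edge."""
    fg = [cross_entropy(z, y) for z, y in zip(rel_dists, rel_labels) if y > 0]
    bg = [cross_entropy(z, y) for z, y in zip(rel_dists, rel_labels) if y == 0]
    mean_fg = sum(fg) / len(fg)
    mean_bg = sum(bg) / len(bg) if bg else 0.0
    return mean_fg + len(bg) / len(fg) * mean_bg
```

With a typical sparse batch, M / M_FG is large, so the rescaled loss is much larger than the baseline value; in practice this interacts with the learning rate, which is one reason to compare both losses under the same training setup.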

See the full code of different losses in lib/losses.py.

Example from [2]: Generative Adversarial Networks (GANs)

In this example, I provide pseudocode for adding the GAN model to a given SGG model. See the full code in main.py.

import torch
from torch.nn.functional import cross_entropy as CE

# Assume the SGG model (sgg_model) returns features for 
# nodes (nodes_real) and edges (edges_real) as well as global features (fmap_real).

# 1. Main SGG model object and relationship classification losses (L_CLS)

obj_dists, rel_dists = sgg_model.predict(nodes_real, edges_real)  # predict node and edge labels
node_loss = CE(obj_dists, gt_objects)
M = len(rel_labels)
M_FG = torch.sum(rel_labels > 0)
our_edge_loss = CE(rel_dists, rel_labels) * M / M_FG  # use our improved edge loss from [1]

L_CLS = node_loss + our_edge_loss  # SGG total loss from [1]
F_optimizer.zero_grad()  # clear gradients left over from the previous iteration
L_CLS.backward()
F_optimizer.step()  # update the sgg_model (main SGG model F)

# 2. GAN-based updates

# Scene Graph perturbations (optional)
gt_objects_fake = sgp.perturb(gt_objects, gt_rels)  # we only perturb nodes (object labels)

# Generate global feature maps using our GAN conditioned on (perturbed) scene graphs
fmap_fake = gan(gt_objects_fake, gt_boxes, gt_rels)

# Extract node and edge features from fmap_fake
nodes_fake, edges_fake = sgg_model.node_edge_features(fmap_fake)

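# Make SGG predictions for the node and edge features.
# The features are detached so that the reconstruction losses below reach only
# the prediction layers of F and not G, avoiding undesirable collaboration of G and F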
obj_dists_fake, rel_dists_fake = sgg_model.predict(nodes_fake.detach(),
                                                   edges_fake.detach())

# 2.1. Generator (G) losses

# Adversarial losses
L_ADV_G_nodes = gan.loss(nodes_fake, labels_fake=gt_objects_fake)
L_ADV_G_edges = gan.loss(edges_fake, labels_fake=rel_labels)
L_ADV_G_global = gan.loss(fmap_fake)

# Reconstruction losses
L_REC_nodes = CE(obj_dists_fake, gt_objects_fake)
L_REC_edges = CE(rel_dists_fake, rel_labels) * M / M_FG  # use our improved edge loss from [1]

# Total G loss
loss_G_F = L_ADV_G_nodes + L_ADV_G_edges + L_ADV_G_global + L_REC_nodes + L_REC_edges
F_optimizer.zero_grad()  # drop the gradients already applied in step 1
G_optimizer.zero_grad()
loss_G_F.backward()
F_optimizer.step()  # update the sgg_model (main SGG model F)
G_optimizer.step()  # update the generator (G) of the GAN

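# 2.2. Discriminator (D) losses

# Adversarial losses (inputs are detached so that the D update changes only D's
# parameters and does not backpropagate through graphs already freed above)
L_ADV_D_nodes = gan.loss(nodes_real.detach(), nodes_fake.detach(), labels_fake=gt_objects_fake, labels_real=gt_objects)
L_ADV_D_edges = gan.loss(edges_real.detach(), edges_fake.detach(), labels_fake=rel_labels, labels_real=rel_labels)
L_ADV_D_global = gan.loss(fmap_real.detach(), fmap_fake.detach())

# Total D loss
loss_D = L_ADV_D_nodes + L_ADV_D_edges + L_ADV_D_global
D_optimizer.zero_grad()  # clear D gradients accumulated by loss_G_F.backward()
loss_D.backward()
D_optimizer.step()  # update the discriminator (D) of the GAN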

Adding our GAN also consistently improves all SGG metrics. See [2] for the results, model description and analysis.

Visual Genome (VG)

SGCls/PredCls

The R@100 results below were obtained using Faster R-CNN with a VGG16 backbone and evaluated without the graph constraint. For graph-constrained results and other details, see the W&B project.

| Model | Paper | Checkpoint | W & B | Zero-shots | 10-shots | 100-shots | All-shots |
|---|---|---|---|---|---|---|---|
| IMP+ (1) | IMP / Neural Motifs | link | link | 8.7 | 19.2 | 38.4 | 47.8 |
| IMP++ (2) | our BMVC 2020 | link | link | 8.8 | 21.6 | 40.6 | 48.7 |
| IMP++ with GAN (3) | our ICCV 2021 | link | link | 9.3 | 22.2 | 41.5 | 50.0 |
| IMP++ with GAN and GraphN scene graph perturbations (4) | our ICCV 2021 | link | link | 10.2 | 21.7 | 40.9 | 49.8 |
  • 1: python main.py -data ./data -ckpt ./data/vg-faster-rcnn.tar -save_dir ./results/IMP_baseline -loss baseline -b 24

  • 2: python main.py -data ./data -ckpt ./data/vg-faster-rcnn.tar -save_dir ./results/IMP_dnorm -loss dnorm -b 24

  • 3: python main.py -data ./data -ckpt ./data/vg-faster-rcnn.tar -save_dir ./results/IMP_GAN -loss dnorm -b 24 -gan -largeD -vis_cond ./data/VG/features.hdf5

  • 4: python main.py -data ./data -ckpt ./data/vg-faster-rcnn.tar -save_dir ./results/IMP_GAN_graphn -loss dnorm -b 24 -gan -largeD -vis_cond ./data/VG/features.hdf5 -perturb graphn -L 0.2 -topk 5 -graphn_a 2

Evaluation on the VG test set is run at the end of the training script. To re-run evaluation: python main.py -data ./data -ckpt ./results/IMP_GAN_graphn/vgrel.pth -pred_weight $x, where $x is the weight for rare predicate classes. It is 1 by default, but can be increased to improve certain metrics such as mean recall (see the Appendix of our paper [2] for more details).
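The results table reports R@100 without the graph constraint. As a rough illustration of what that means, the sketch below computes triplet recall@K in a simplified PredCls-like setting where subject and object boxes and labels are given, so a triplet reduces to (subject index, object index, predicate). Everything here (recall_at_k, the input format, ranking by predicate score alone, index 0 as background) is an assumption for illustration, not the evaluator used in this repo. Without the graph constraint, every foreground predicate of every pair competes for the top K; with it, each pair contributes only its highest-scoring predicate.

```python
from typing import Dict, Sequence, Set, Tuple

Pair = Tuple[int, int]          # (subject box index, object box index)
Triplet = Tuple[int, int, int]  # (subject box index, object box index, predicate label)


def recall_at_k(pair_scores: Dict[Pair, Sequence[float]],
                gt_triplets: Set[Triplet],
                k: int = 100,
                graph_constraint: bool = False) -> float:
    """Fraction of ground-truth triplets found among the top-k scored predicted triplets.

    pair_scores maps each candidate (subject, object) pair to predicate scores,
    where index 0 is assumed to be the background ("no relationship") class.
    """
    if not gt_triplets:
        return float("nan")  # recall is undefined without ground-truth triplets
    candidates = []
    for (s, o), scores in pair_scores.items():
        fg = [(pred, scores[pred]) for pred in range(1, len(scores))]  # skip background
        if graph_constraint and fg:
            fg = [max(fg, key=lambda p: p[1])]  # keep only one predicate per pair
        candidates.extend((score, (s, o, pred)) for pred, score in fg)
    candidates.sort(key=lambda c: c[0], reverse=True)
    top_k = {triplet for _, triplet in candidates[:k]}
    return len(top_k & gt_triplets) / len(gt_triplets)
```

Zero-shot recall can be computed the same way by passing only the ground-truth triplets whose (subject, predicate, object) classes never occur in training.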

Generated Feature Quality

To inspect the features generated with GANs, it is necessary to first extract and save node/edge/global features. This can be done similarly to the code in extract_features.py, but replacing the real features with the ones produced by the GAN.

See this Jupyter notebook to inspect the quality of the generated features.

Scene Graph Perturbations

See this Jupyter notebook to inspect scene graph perturbation methods.
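For intuition, the simplest kind of node perturbation replaces a fraction of object labels with other random classes while keeping boxes and relationships intact (in the GAN pseudocode, too, only nodes are perturbed). The sketch below is a hypothetical uniform-random variant written for illustration; it is not the GraphN method from [2], and the function name and arguments are not from this repo. It assumes label 0 is reserved for the background class, and the default fraction of 0.2 is only an illustrative choice.

```python
import random
from typing import List, Optional, Sequence


def perturb_labels(labels: Sequence[int], num_classes: int, frac: float = 0.2,
                   rng: Optional[random.Random] = None) -> List[int]:
    """Return a copy of `labels` where round(frac * len(labels)) entries get a different class.

    Replacement classes are drawn uniformly from 1..num_classes-1 (0 is assumed to be background).
    """
    rng = rng or random.Random()
    out = list(labels)
    k = round(frac * len(out))
    for i in rng.sample(range(len(out)), k):  # distinct node indices to perturb
        choices = [c for c in range(1, num_classes) if c != out[i]]
        out[i] = rng.choice(choices)
    return out
```

Passing a seeded random.Random makes the perturbation reproducible, which is handy when comparing perturbation strategies on the same scene graphs.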

SGGen (optional)

Please follow the details in our papers to obtain SGGen/SGDet results, which were obtained with the original Neural Motifs code.

Pull requests adding training and evaluation of SGGen/SGDet models with the VGG16 or another backbone are welcome.

GQA

Note: these instructions are for our BMVC 2020 paper [1] and have not been tested with the latest version of the repo.

SGCls/PredCls

To train an SGCls/PredCls model with our loss on GQA: python main.py -data ./data -loss dnorm -split gqa -lr 0.002 -save_dir ./results/GQA_sgcls # takes about 1 day. Or download our GQA-SGCls-1 checkpoint.

For the trained checkpoints of this repo, I used a slightly different edge model in UnionBoxesAndFeats (-edge_model raw_boxes). To use the Neural Motifs edge model, use the flag -edge_model motifs (the default in the current version of the repo).

SGGen (optional)

Follow these steps to train and evaluate an SGGen model on GQA:

  1. Fine-tune Mask R-CNN on GQA: python pretrain_detector.py gqa ./data ./results/pretrain_GQA # takes about 1 day. Or download our GQA-detector checkpoint

  2. Train SGCls: python main.py -data ./data -lr 0.002 -split gqa -nosave -loss dnorm -ckpt ./results/pretrain_GQA/gqa_maskrcnn_res50fpn.pth -save_dir ./results/GQA_sgdet # takes about 1 day. Or download our GQA-SGCls-2 checkpoint. This checkpoint is different from SGCls-1, because here the model is trained on the features of the GQA-pretrained detector. This checkpoint can be used in the next step.

  3. Evaluate SGGen: python main.py -data ./data -split gqa -ckpt ./results/GQA_sgdet/vgrel.pth -m sgdet -nosave -nepoch 0 # takes a couple hours

Visualizations

See an example of detecting objects and obtaining scene graphs for GQA test images at Scene_Graph_Predictions_GQA.ipynb.

Citation

Please use these references to cite our papers or code:

@inproceedings{knyazev2020graphdensity,
  title={Graph Density-Aware Losses for Novel Compositions in Scene Graph Generation},
  author={Knyazev, Boris and de Vries, Harm and Cangea, Cătălina and Taylor, Graham W and Courville, Aaron and Belilovsky, Eugene},
  booktitle={British Machine Vision Conference (BMVC)},
  pdf={http://arxiv.org/abs/2005.08230},
  year={2020}
}
@inproceedings{knyazev2020generative,
  title={Generative Compositional Augmentations for Scene Graph Prediction},
  author={Boris Knyazev and Harm de Vries and Cătălina Cangea and Graham W. Taylor and Aaron Courville and Eugene Belilovsky},
  booktitle={International Conference on Computer Vision (ICCV)},
  pdf={https://arxiv.org/abs/2007.05756},
  year={2021}
}

sgg's People

Contributors

bknyaz


sgg's Issues

NO OBJECTS OR RELATIONS FOUND list index out of range ['1001.jpg'] trying a smaller threshold

Your project is very helpful to me. Thank you as always.

Let me explain my issue.
I solved the 4th issue below and encountered a problem.
#4

When I run the last cell of the Jupyter notebook you provided, I get the following error.
I would like to solve this problem.

# Let's visualize ground truth and predictions for test scene graphs
val_epoch(test_loader, 'test_zs', n_batches=10, max_obj=30, max_rels=50)
  • ERROR
Evaluate TEST_ZS test triplets

Evaluating SGDET...
NO OBJECTS OR RELATIONS FOUND list index out of range ['1001.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1001.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1001.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1016.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1016.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1016.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1054.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1054.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1054.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1057.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1057.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1057.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1058.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1058.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1058.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1066.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1066.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1066.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1069.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1069.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1069.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1074.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1074.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['1074.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['107903.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['107903.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['107903.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['107905.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['107905.jpg'] trying a smaller threshold
NO OBJECTS OR RELATIONS FOUND list index out of range ['107905.jpg'] trying a smaller threshold

Cannot download data

Hey @bknyaz, thanks for your sharing!

I found that the GQA data cannot be downloaded via the link the script provides.
Could you check it or provide new links?

There is an official website for GQA, but its download link seems to have been broken recently...

Thanks in advance!

How to create a scenegraph from a custom image?

Your research is very good, and it is very helpful to me. Thank you as always.

I would like to proceed with creating a scenegraph for my custom image.

Could you explain how I can do it?

If possible, I would like to modify the Jupyter notebook you provided so that it takes a custom image as input.

  • I mean, I don't want to add training data; I just want to see the scene graph for the image I specified!

Predicting attributes

Hello! Thanks for the great work! I wanted to ask whether you have a version anywhere that also predicts attributes for GQA.

FileNotFoundError: [Errno 2] No such file or directory: 'data_path/GQA/GQA_scenegraphs.tar'

Thanks for the amazing work! I am trying to reproduce your work but encountered the following error. I guess it is probably because the data link is not accessible:

$ python main.py -data data_path -loss dnorm -save_dir VG_sgcls

GQA data_path/GQA
Downloading GQA (can take a few hours
downloading data_path/GQA/GQA_scenegraphs.tar ...
/data_path/GQA/GQA_scenegraphs.tar: No such file or directory
return code for data_path/GQA/GQA_scenegraphs.tar = 1
extracting data_path/GQA/GQA_scenegraphs.tar to data_path/GQA
Traceback (most recent call last):
  File "main.py", line 11, in <module>
    conf = ModelConfig()
  File "/userhome/34/h3509807/scene_graph/sgg/config.py", line 120, in __init__
    download_all_data(self.data)
  File "/userhome/34/h3509807/scene_graph/sgg/lib/download.py", line 48, in download_all_data
    download(link, data_dir)
  File "/userhome/34/h3509807/scene_graph/sgg/lib/download.py", line 28, in download
    tar = tarfile.open(filename)
  File "/userhome/31/h3509807/anaconda3/envs/pytorch/lib/python3.6/tarfile.py", line 1571, in open
    return func(name, "r", fileobj, **kwargs)
  File "/userhome/31/h3509807/anaconda3/envs/pytorch/lib/python3.6/tarfile.py", line 1636, in gzopen
    fileobj = gzip.GzipFile(name, mode + "b", compresslevel, fileobj)
  File "/userhome/31/h3509807/anaconda3/envs/pytorch/lib/python3.6/gzip.py", line 163, in __init__
    fileobj = self.myfileobj = builtins.open(filename, mode or 'rb')
FileNotFoundError: [Errno 2] No such file or directory: 'data_path/GQA/GQA_scenegraphs.tar'

Problems in pre-trained faster-rcnn detector

Hi, thank you for sharing these wonderful works!

I found a problem in loading the pre-trained file 'vg-faster-rcnn.tar'.
The anchor ratios and anchor scales in neural-motifs are inconsistent with those in torchvision.models.detection:
motifs
anchor ratios: (0.23232838, 0.63365731, 1.28478321, 3.15089189); scales: (2.22152954, 4.12315647, 7.21692515, 12.60263013, 22.7102731)
torchvision
anchor ratios: (0.5, 1.0, 2.0); scales: (32, 64, 128, 256, 512).
Thus the pre-trained weights in 'vg-faster-rcnn.tar' do not match torchvision in rpn.head.bbox_pred: (120, 512, 1, 1) vs (60, 512, 1, 1).

I don't know whether my analysis above is correct and whether this will affect the performance of the RPN.
