
s2-transformer's Introduction

S2 Transformer for Image Captioning [IJCAI 2022]

Official code implementation for the paper S2 Transformer for Image Captioning
Pengpeng Zeng, Haonan Zhang, Jingkuan Song, and Lianli Gao

(Figure: Relationship-Sensitive Transformer architecture)

Environment setup

Clone this repository and create the m2release conda environment using the environment.yaml file:

conda env create -f environment.yaml
conda activate m2release

Then download spacy data by executing the following command:

python -m spacy download en_core_web_md

Note

Python 3 is required to run our code. If you encounter network problems, download the en_core_web_md package from here, unzip it, and place it under /your/anaconda/path/envs/m2release/lib/python*/site-packages/
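After installing (by either route), a quick sanity check confirms that spaCy can find the model. This check is not part of the original instructions, just a minimal sketch:

```python
# Verify that en_core_web_md is installed and loadable.
import spacy

nlp = spacy.load("en_core_web_md")       # raises OSError if the model is missing
doc = nlp("a man riding a horse on a beach")
print(len(doc), doc[1].pos_)             # token count and a sample POS tag
```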

Data Preparation

  • Annotation. Download the annotation file m2_annotations [1]. Extract it and put it in the project root directory.
  • Feature. Download the processed image features ResNeXt-101 and ResNeXt-152 [2] (code 9vtB) and put them in the project root directory.

Update: the image features are also available on OneDrive.
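Before training, it can help to confirm the feature file is readable. Below is a minimal inspection sketch using h5py; the file name is taken from the issues further down, and the per-image key layout ('%d_grids', as in RSTNet-style grid features) is an assumption, so verify against your actual file:

```python
# Inspect the downloaded grid-feature file (path and key layout are assumptions).
import h5py

with h5py.File("X101_grid_feats_coco_trainval.hdf5", "r") as f:
    keys = list(f.keys())[:5]
    print(keys)                          # e.g. ['100022_grids', ...] -- check your file
    first = f[keys[0]]
    print(first.shape, first.dtype)      # expect roughly (49, 2048) float32 for 7x7 grids
```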

Training

Run python train_transformer.py using the following arguments:

| Argument | Description |
| --- | --- |
| `--exp_name` | Experiment name |
| `--batch_size` | Batch size (default: 50) |
| `--workers` | Number of data-loading workers; speeds up training in the XE stage |
| `--head` | Number of attention heads (default: 8) |
| `--resume_last` | If set, resume training from the last checkpoint |
| `--resume_best` | If set, resume training from the best checkpoint |
| `--features_path` | Path to the visual features file (h5py) |
| `--annotation_folder` | Path to the annotations |
| `--num_clusters` | Number of pseudo regions |
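For intuition, `--num_clusters` sets how many pseudo regions the model aggregates grid features into. The sketch below is a generic soft-clustering illustration of that idea, not the repository's actual module; the function, tensor shapes, and learnable centers are all assumptions:

```python
# Illustrative soft clustering of grid features into K pseudo regions.
# This is a generic sketch, NOT the repository's implementation.
import torch
import torch.nn.functional as F

def pseudo_regions(grids: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    """grids: (B, N, D) grid features; centers: (K, D) learnable cluster centers.
    Returns (B, K, D) pseudo-region features."""
    logits = grids @ centers.t()              # (B, N, K) grid-to-center similarities
    assign = F.softmax(logits, dim=-1)        # soft assignment of each grid cell
    regions = assign.transpose(1, 2) @ grids  # (B, K, D) weighted aggregation
    return regions / assign.sum(dim=1).unsqueeze(-1).clamp(min=1e-6)

B, N, D, K = 2, 49, 2048, 5                   # K matches --num_clusters 5
out = pseudo_regions(torch.randn(B, N, D), torch.randn(K, D))
print(out.shape)                              # torch.Size([2, 5, 2048])
```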

For example, to train the model, run the following command:

python train_transformer.py --exp_name S2 --batch_size 50 --m 40 --head 8 --features_path /path/to/features --num_clusters 5

or just run:

bash train.sh

Note

We use torch.distributed to train our model; set worldSize in train_transformer.py to the number of GPUs you want to use.
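For reference, a worldSize-driven multi-GPU launch with torch.distributed typically looks like the sketch below; the function names and init address are illustrative, not the script's actual internals:

```python
# Illustrative multi-process launch; worldSize mirrors the variable in train_transformer.py.
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # NCCL backend for GPU training; each spawned process drives one GPU.
    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:23456",
                            rank=rank, world_size=world_size)
    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...
    dist.destroy_process_group()

if __name__ == "__main__":
    worldSize = 2                             # number of GPUs to use
    mp.spawn(worker, args=(worldSize,), nprocs=worldSize)
```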

Evaluation

Offline Evaluation

Run python test_transformer.py to evaluate the model using the following arguments:

python test_transformer.py --batch_size 10 --features_path /path/to/features --model_path /path/to/saved_transformer_models/ckpt --num_clusters 5
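The saved checkpoint is expected to be a dict holding the model weights, following the M2-style convention; the 'state_dict' key below is an assumption, so inspect the file before relying on it:

```python
# Load a saved checkpoint for evaluation (the 'state_dict' key is an assumption).
import torch

data = torch.load("saved_transformer_models/ckpt", map_location="cpu")
print(data.keys())                            # inspect what the checkpoint actually stores
# model.load_state_dict(data["state_dict"])   # typical M2-style layout
```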

Tip

We removed the SPICE metric from evaluation during training because it is time-consuming. You can add it back when evaluating the model: download this file, put it in /path/to/evaluation/, and uncomment the corresponding code in __init__.py.
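For reference, M2-style evaluation packages expose a compute_scores helper over ground-truth and generated captions; treat the exact imports below as assumptions and check evaluation/__init__.py in this repo:

```python
# Sketch of metric computation with an M2-style evaluation package (imports assumed).
from evaluation import PTBTokenizer, compute_scores

gts = {"img1": ["a man riding a horse", "a person on a horse"]}
gen = {"img1": ["a man rides a horse"]}
gts, gen = PTBTokenizer.tokenize(gts), PTBTokenizer.tokenize(gen)
scores, _ = compute_scores(gts, gen)   # dict with BLEU, METEOR, ROUGE, CIDEr (SPICE if enabled)
print(scores)
```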

We provide a checkpoint here; with it, you should get the following results (second row):

| Model | B@1 | B@4 | M | R | C | S |
| --- | --- | --- | --- | --- | --- | --- |
| Our Paper (ResNeXt-101) | 81.1 | 39.6 | 29.6 | 59.1 | 133.5 | 23.2 |
| Reproduced Model (ResNeXt-101) | 81.2 | 39.9 | 29.6 | 59.1 | 133.7 | 23.3 |

Online Evaluation

We also report the performance of our model on the online COCO test server, using an ensemble of four S2 models. The detailed online test code is available in this repo.

Reference and Citation

Reference

[1] Marcella Cornia, Matteo Stefanini, Lorenzo Baraldi, and Rita Cucchiara. Meshed-memory transformer for image captioning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020.
[2] Xuying Zhang, Xiaoshuai Sun, Yunpeng Luo, Jiayi Ji, Yiyi Zhou, Yongjian Wu, Feiyue Huang, and Rongrong Ji. RSTNet: Captioning with adaptive attention on visual and non-visual words. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15465–15474, 2021.

Citation

@inproceedings{S2,
  author    = {Pengpeng Zeng* and
               Haonan Zhang* and
               Jingkuan Song and
               Lianli Gao},
  title     = {S2 Transformer for Image Captioning},
  booktitle = {IJCAI},
  pages     = {1608--1614},
  year      = {2022}
}

Acknowledgements

Thanks to Zhang et al. for releasing the visual features (ResNeXt-101 and ResNeXt-152); our implementation is also based on their repo.
Thanks also to M2 Transformer for the original annotations and to grid-feats-vqa for the effective visual representations.


s2-transformer's Issues

how to get the attention mask?

Hello, I would like to generate heat maps similar to those in the paper, but I am not sure how to obtain the attention mask. Could you explain how you implemented this?
Thanks for your reply!

About the online test code

Hello, could you please release the online test code? When I modify and run RSTNet's online test, I get some errors.

Code implementation

Hello, your paper states that the experiments are conducted on a plain Transformer without any modifications, but I see M2-related prior knowledge in the code. Is it included in your reported results?

Abnormal reproduction results

Hi, while reproducing the experiments, all metrics increase normally during XE training. During RL training, BLEU-4 is normal for the first few epochs, then the test BLEU-4 drops sharply (from 40.2 to 39.1) and stays around 39.2 afterwards. I am using the unmodified source code and evaluating with the X101_grid_feats_coco_trainval.hdf5 file. Is there any detail I need to change? @zchoi

About abnormal result metrics

Hello, I am reproducing your results with your code, but my results consistently fall short of yours. Does the code on GitHub differ from the code you used, or are there parameters that need tuning?


Garbled log file

Hi, the log file you provided shows up as garbled text when I open it. How can I fix this?

About single-GPU training

Hello, if I want to train on a single GPU, can I simply set worldSize to 1?
Also, if I train with the train_transformer file from the RSTNet model, can I reach the same evaluation metrics?
