
Text Is MASS: Modeling as Stochastic Embedding for Text-Video Retrieval (CVPR 2024 Highlight)

Jiamian Wang, Guohao Sun, Pichao Wang, Dongfang Liu, Sohail Dianat, Majid Rabbani, Raghuveer Rao, Zhiqiang Tao, "Text Is MASS: Modeling as Stochastic Embedding for Text-Video Retrieval".

[Paper] [Pretrained Models]


Abstract: The increasing prevalence of video clips has sparked growing interest in text-video retrieval. Recent advances focus on establishing a joint embedding space for text and video, relying on consistent embedding representations to compute similarity. However, the text content in existing datasets is generally short and concise, making it hard to fully describe the redundant semantics of a video. Correspondingly, a single text embedding may be less expressive to capture the video embedding and empower the retrieval. In this study, we propose a new stochastic text modeling method T-MASS, i.e., text is modeled as a stochastic embedding, to enrich text embedding with a flexible and resilient semantic range, yielding a text mass. To be specific, we introduce a similarity-aware radius module to adapt the scale of the text mass upon the given text-video pairs. Plus, we design and develop a support text regularization to further control the text mass during the training. The inference pipeline is also tailored to fully exploit the text mass for accurate retrieval. Empirical evidence suggests that T-MASS not only effectively attracts relevant text-video pairs while distancing irrelevant ones, but also enables the determination of precise text embeddings for relevant pairs. Our experimental results show a substantial improvement of T-MASS over baseline (3%~6.3% by R@1). Also, T-MASS achieves state-of-the-art performance on five benchmark datasets, including MSRVTT, LSMDC, DiDeMo, VATEX, and Charades.
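For intuition, the idea can be pictured as sampling a stochastic text embedding from a "mass" centered at the original CLIP text embedding, with the radius predicted from the given text-video pair. The snippet below is only a minimal conceptual sketch: the module and function names (SimilarityAwareRadius, sample_stochastic_text) and the exact layer shapes are illustrative assumptions, not the released implementation.

```python
import torch
import torch.nn as nn

class SimilarityAwareRadius(nn.Module):
    """Conceptual sketch (not the authors' code): predict a positive radius R
    for the text mass from a text embedding and its paired video embedding."""
    def __init__(self, embed_dim=512):
        super().__init__()
        self.proj = nn.Linear(2 * embed_dim, embed_dim)

    def forward(self, text_embed, video_embed):
        # text_embed, video_embed: [batch, embed_dim]; exp keeps the radius positive
        return torch.exp(self.proj(torch.cat([text_embed, video_embed], dim=-1)))

def sample_stochastic_text(text_embed, video_embed, radius_module):
    """Draw one stochastic text embedding t_s = t + R * eps with eps ~ N(0, I)."""
    radius = radius_module(text_embed, video_embed)  # [batch, embed_dim]
    eps = torch.randn_like(text_embed)               # random perturbation per dimension
    return text_embed + radius * eps                 # one sample from the text mass
```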


Content

  1. Dependencies
  2. Dataset
  3. Evaluation
  4. Training
  5. Citation
  6. Acknowledgement
  7. Contact

Dependencies

  • PyTorch 1.12.1
  • OpenCV 4.7.0
  • transformers 4.30.2

Dataset

To download MSRVTT, LSMDC, and DiDeMo, please follow CLIP4Clip.

You will need to request permission from MPII to download and use the Standard LSMDC data.

For LSMDC, download the data split CSV files into ./data/LSMDC/.

For DiDeMo, we recommend using gdrive to download the video data. One may:

  • Set up gdrive by following its "Getting started" guide.
  • Download the video data with gdrive files download --recursive FOLDER_ID_FROM_URL

| Dataset | Video Clips | Text-Video Pairs | Scale | Link |
|---------|-------------|------------------|-------|------|
| MSR-VTT | 10K | one-to-twenty | 6.7 GB | link |
| LSMDC | 118,081 | one-to-one | 1.3 TB | link |
| DiDeMo | 10,464 | one-to-many | 581 GB | link |

Evaluation

Download the checkpoints into ./outputs/{Dataset}/{FOLDER_NAME_UNDER_Dataset}.

Repeating the testing process over --stochasic_trials trials incurs either extra time or extra memory. The provided sequential strategy is more memory-friendly. We adopt --seed=24 and --stochasic_trials=20 for all methods. One may specify --save_memory_mode for larger datasets or compute-constrained platforms at evaluation. As in XPool, evaluation defaults to text-to-video retrieval (i.e., --metric=t2v); for video-to-text retrieval, specify --metric=v2t. To apply DSL post-processing to the evaluation results, specify --DSL.
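At inference, --stochasic_trials controls how many stochastic text embeddings are drawn per query, and the best text-video similarity over the trials is kept. The sketch below only illustrates this idea and is not the repository's sim_matrix_inference_stochastic; sample_fn is a hypothetical helper that draws one stochastic text embedding (in the actual pipeline the sampling depends on the text-video pair).

```python
import torch
import torch.nn.functional as F

def retrieve_with_trials(text_embed, video_embeds, sample_fn, num_trials=20):
    """text_embed: [embed_dim]; video_embeds: [num_vids, embed_dim].
    Returns the per-video similarity maximized over the stochastic trials."""
    videos = F.normalize(video_embeds, dim=-1)
    sims = []
    for _ in range(num_trials):
        t_s = F.normalize(sample_fn(text_embed), dim=-1)  # one sample from the text mass
        sims.append(videos @ t_s)                         # cosine similarity per video
    return torch.stack(sims, dim=0).max(dim=0).values     # keep the best trial per video
```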

Replace {videos_dir} with the path to the dataset.

| Dataset | Command | Checkpoint File | t2v R@1 Result |
|---------|---------|-----------------|----------------|
| MSR-VTT-9k | `python test.py --datetime={FOLDER_NAME_UNDER_MSR-VTT-9k} --arch=clip_stochastic --videos_dir={VIDEO_DIR} --batch_size=32 --noclip_lr=3e-5 --transformer_dropout=0.3 --dataset_name=MSRVTT --msrvtt_train_file=9k --stochasic_trials=20 --gpu='0' --load_epoch=0 --exp_name=MSR-VTT-9k` | Link | 50.2 |
| LSMDC | `python test.py --arch=clip_stochastic --exp_name=LSMDC --videos_dir={VIDEO_DIR} --batch_size=32 --noclip_lr=1e-5 --transformer_dropout=0.3 --dataset_name=LSMDC --stochasic_trials=20 --gpu='0' --num_epochs=5 --stochastic_prior=normal --stochastic_prior_std=3e-3 --load_epoch=0 --datetime={FOLDER_NAME_UNDER_LSMDC}` | Link | 28.9 |
| DiDeMo | `python test.py --num_frame=12 --raw_video --arch=clip_stochastic --exp_name=DiDeMo --videos_dir={VIDEO_DIR} --batch_size=32 --noclip_lr=1e-5 --transformer_dropout=0.4 --dataset_name=DiDeMo --stochasic_trials=20 --gpu='0' --num_epochs=5 --load_epoch=0 --datetime={FOLDER_NAME_UNDER_DiDeMo}` | Link | 50.9 |

Training

Run the following training commands to reproduce the above results. Taking MSRVTT as an example, one may enable support text regularization by specifying --support_loss_weight. --evals_per_epoch can be increased to select a better checkpoint. The CLIP backbone defaults to --clip_arch=ViT-B/32; to train on a larger CLIP backbone, specify --clip_arch=ViT-B/16. When the downloaded dataset is incomplete, increasing the number of training epochs --num_epochs by one or two may yield better performance.

| Dataset | Command |
|---------|---------|
| MSR-VTT-9k | `python train.py --arch=clip_stochastic --exp_name=MSR-VTT-9k --videos_dir={VIDEO_DIR} --batch_size=32 --noclip_lr=3e-5 --transformer_dropout=0.3 --dataset_name=MSRVTT --msrvtt_train_file=9k --stochasic_trials=20 --gpu='0' --num_epochs=5 --support_loss_weight=0.8` |
| LSMDC | `python train.py --arch=clip_stochastic --exp_name=LSMDC --videos_dir={VIDEO_DIR} --batch_size=32 --noclip_lr=1e-5 --transformer_dropout=0.3 --dataset_name=LSMDC --stochasic_trials=20 --gpu='0' --num_epochs=5 --stochastic_prior=normal --stochastic_prior_std=3e-3` |
| DiDeMo | `python train.py --num_frame=12 --raw_video --arch=clip_stochastic --exp_name=DiDeMo --videos_dir={VIDEO_DIR} --batch_size=32 --noclip_lr=1e-5 --transformer_dropout=0.4 --dataset_name=DiDeMo --stochasic_trials=20 --gpu='0' --num_epochs=5` |
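For reference, --support_loss_weight scales an auxiliary loss on a "support text" embedding, a point of the text mass shifted toward the paired video embedding. The sketch below is an assumed, simplified form for illustration only (the loss shape, temperature, and how radius is broadcast are assumptions), not the released training code.

```python
import torch
import torch.nn.functional as F

def support_text_loss(text_embed, video_embed, radius, temperature=0.05):
    """Conceptual sketch: move from the text embedding toward the paired video
    by `radius` to get a support text, then apply a symmetric InfoNCE-style
    loss over the batch. Shapes: [batch, embed_dim]; radius broadcastable."""
    direction = F.normalize(video_embed - text_embed, dim=-1)
    support = F.normalize(text_embed + radius * direction, dim=-1)  # point on the text mass
    videos = F.normalize(video_embed, dim=-1)
    logits = support @ videos.t() / temperature           # [batch, batch] similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

# total objective (assumed form): retrieval_loss + support_loss_weight * support_text_loss(t, v, R)
```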

Citation

If you find this work valuable for your research, we kindly request that you cite the following paper:

@inproceedings{wang2024text,
  title={Text Is MASS: Modeling as Stochastic Embedding for Text-Video Retrieval}, 
  author={Wang, Jiamian and Sun, Guohao and Wang, Pichao and Liu, Dongfang and Dianat, Sohail and Rabbani, Majid and Rao, Raghuveer and Tao, Zhiqiang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2024}
}

Acknowledgement

This code is built on XPool. Many thanks to the authors!

Contact

For discussions, please feel free to submit an issue or contact me via email at [email protected].


t-mass-text-video-retrieval's Issues

Figure 3. Dynamics of R

Hello author, thank you for your excellent work. I would like to ask how Figure 3 (Dynamics of R) in the paper is drawn. Could you provide the code?

About VATEX

Hi, I noticed that there is no code for the VATEX dataset in the current codebase (e.g., in the data and datasets directories). Could you please provide the VATEX-related code and the pretrained weights?

About inference similarity computation in 'sim_matrix_inference_stochastic'

Hello author, thank you for sharing your excellent work.

While trying to understand how you compute the similarity between the fused video embeddings and the num_txts stochastic text embeddings generated from a pooled video embedding, I had a doubt about the permute operation
vid_embeds_pooled_per_video_id = vid_embeds_pooled_per_video_id.permute(1, 2, 3, 0)

The pooled video embeddings have shape (b, a, 1, 512), which means every video has been fused with all text embeddings in the whole validation set, so each video has a fused video embeddings, where b is num_vids, a is num_txts, and in validation they are the same.

After the permute operation, the pooled video tensor has shape (a, 1, 512, b), which means every text has been fused with all video embeddings and has obtained num_vids fused video embeddings.

My question is about the bmm operation:
sims = torch.bmm(text_embeds_per_video_id, vid_embeds_pooled_per_video_id)
The shapes are (b, a, 512) for text and (a, 512, b) for video, and the batch matrix multiplication means that, for a positive pair in the validation set (e.g., the i-th video and the i-th text), the j-th stochastic text embedding of the i-th video must compute similarities with all fused video embeddings that were fused with the i-th text embedding.

This combination really confused me for a while. Why does the pooled video tensor need to be permuted to shape (a, 1, 512, b), and why not directly perform the bmm operation (b, a, 512) x (b, 512, a) -> (b, a, a), which would mean that for the i-th video and i-th text pair, each stochastic text embedding (generated from the i-th video embedding and its corresponding text embedding) computes similarities with the fused video embeddings generated by the i-th text embedding and all video embeddings?

Sincerely waiting for your reply.

Below is the part of sim_matrix_inference_stochastic in question:

```python
num_txts, num_vids, max_text_per_vid, embed_dim = text_embeds_per_video_id.shape  # (b, a=b, 1, 512)

vid_embeds_pooled_per_video_id = vid_embeds_pooled_per_video_id.permute(1, 2, 3, 0)  # (a, 1, 512, b)
vid_embeds_pooled_per_video_id = vid_embeds_pooled_per_video_id.reshape(
    num_vids * max_text_per_vid, embed_dim, num_vids)  # (a, 512, b)

text_embeds_per_video_id = text_embeds_per_video_id.permute(0, 2, 1, 3)  # (b, 1, a, 512)
text_embeds_per_video_id = text_embeds_per_video_id.reshape(
    num_vids * max_text_per_vid, num_txts, embed_dim)  # (b, a, 512)

sims = torch.bmm(text_embeds_per_video_id, vid_embeds_pooled_per_video_id)
# (b, a, 512) x (a, 512, b) -> (b=a, a, b); b=a since num_vids == num_txts in the validation set

sims = sims.view(num_vids, max_text_per_vid, num_txts, num_vids)  # (b=a, 1, a, b)
sims_diag = torch.stack([sims[i, :, :, i] for i in range(sims.shape[0])], dim=-1)
print(f'>>>check sims_diag={sims_diag.shape}')
sims_diag = sims_diag.permute(1, 0, 2)
```

About the training scripts.

Thank you to the author for sharing the open-source code. I noticed that the official training scripts have slightly different settings for different datasets. For example, the MSRVTT dataset uses support_loss_weight, but the other two datasets do not. For the LSMDC dataset, the stochastic prior is set to normal, and std is set to 3e-3, but these settings are not applied to the other two datasets. For the DiDeMo dataset, there are no settings for support_loss_weight, stochastic prior, and std. I would like to know if it is indeed necessary to slightly modify the training parameters when training on different datasets, or if there are errors in the current training scripts?
