
diffusionret's Introduction

【ICCV'2023 🔥】DiffusionRet: Generative Text-Video Retrieval with Diffusion Model

Conference Paper

The implementation of the paper DiffusionRet: Generative Text-Video Retrieval with Diffusion Model.

In this paper, we propose a novel diffusion-based text-video retrieval framework, called DiffusionRet, which addresses the limitations of current discriminative solutions from a generative perspective.

📌 Citation

If you find this paper useful, please consider starring 🌟 this repo and citing 📑 our paper:

@inproceedings{jin2023diffusionret,
  title={DiffusionRet: Generative Text-Video Retrieval with Diffusion Model},
  author={Jin, Peng and Li, Hao and Cheng, Zesen and Li, Kehan and Ji, Xiangyang and Liu, Chang and Yuan, Li and Chen, Jie},
  booktitle={ICCV},
  pages={2470-2481},
  year={2023}
}
💡 I also have other text-video retrieval projects that may interest you ✨.

Video-Text as Game Players: Hierarchical Banzhaf Interaction for Cross-Modal Representation Learning
Accepted by CVPR 2023 (Highlight) | [HBI Code]
Peng Jin, Jinfa Huang, Pengfei Xiong, Shangxuan Tian, Chang Liu, Xiangyang Ji, Li Yuan, Jie Chen

Expectation-Maximization Contrastive Learning for Compact Video-and-Language Representations
Accepted by NeurIPS 2022 | [EMCL Code]
Peng Jin, Jinfa Huang, Fenglin Liu, Xian Wu, Shen Ge, Guoli Song, David Clifton, Jie Chen

Text-Video Retrieval with Disentangled Conceptualization and Set-to-Set Alignment
Accepted by IJCAI 2023 | [DiCoSA Code]
Peng Jin, Hao Li, Zesen Cheng, Jinfa Huang, Zhennan Wang, Li Yuan, Chang Liu, Jie Chen

📣 Updates

  • [2023/08/27]: We release the training code.
  • [2023/07/14]: Our DiffusionRet has been accepted by ICCV 2023! We will release the training code asap.
  • [2023/06/28]: We release the inference code.
  • [2023/03/31]: Our paper is under review. After our paper is published, we will release the code as soon as possible.

📕 Overview

Existing text-video retrieval solutions are, in essence, discriminative models that maximize the conditional likelihood, i.e., p(candidates|query). While straightforward, this de facto paradigm overlooks the underlying data distribution p(query), which makes it challenging to identify out-of-distribution data. To address this limitation, we tackle the task from a generative viewpoint and model the correlation between the text and the video as their joint probability p(candidates,query). This is accomplished through a diffusion-based text-video retrieval framework (DiffusionRet), which models the retrieval task as a process of gradually generating the joint distribution from noise.
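The relation between the two quantities follows from the product rule of probability. Writing $q$ for the query and $c$ for the candidates (shorthand introduced here for brevity, not necessarily the paper's notation):

```latex
p(c, q) = p(c \mid q)\, p(q)
```

A model fitted only to $p(c \mid q)$ therefore carries no term for $p(q)$, while the joint distribution includes it.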

🚀 Quick Start

Setup

Setup code environment

conda create -n DiffusionRet python=3.9
conda activate DiffusionRet
pip install -r requirements.txt
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html

Download CLIP Model

cd DiffusionRet/models
wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
# wget https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
# wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt

Download Datasets

Datasets      Google Cloud   Baidu Yun   Peking University Yun
MSR-VTT       Download       Download    Download
MSVD          Download       Download    Download
ActivityNet   TODO           Download    Download
DiDeMo        TODO           Download    Download

Model Zoo

Checkpoint    Google Cloud   Baidu Yun   Peking University Yun
MSR-VTT       Download       Download    Download
ActivityNet   Download       Download    Download

Evaluate

Eval on MSR-VTT

CUDA_VISIBLE_DEVICES=0 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=1 \
eval.py \
--workers 8 \
--batch_size_val 128 \
--anno_path data/MSR-VTT/anns \
--video_path ${DATA_PATH}/MSRVTT_Videos \
--datatype msrvtt \
--max_words 32 \
--max_frames 12 \
--video_framerate 1 \
--diffusion_steps 50 \
--noise_schedule cosine \
--init_model ${CHECKPOINT_PATH} \
--output_dir ${OUTPUT_PATH}
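For readers unfamiliar with the `--diffusion_steps 50 --noise_schedule cosine` flags, here is a minimal sketch of a cosine noise schedule. It assumes the flag refers to the widely used cosine schedule of Nichol and Dhariwal; the function name `cosine_betas` and the constants `s` and `max_beta` are illustrative choices, and the repository's own implementation may differ.

```python
import math


def cosine_betas(num_steps, s=0.008, max_beta=0.999):
    """Per-step noise variances (betas) for a cosine schedule.

    Assumption: alpha_bar(t) = cos^2(((t + s) / (1 + s)) * pi / 2) for t in [0, 1],
    and beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta.
    """

    def alpha_bar(t):
        return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for i in range(num_steps):
        t1 = i / num_steps
        t2 = (i + 1) / num_steps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


# 50 steps, matching --diffusion_steps 50 in the command above.
betas = cosine_betas(50)
print(len(betas), betas[0], betas[-1])
```

Under this schedule the noise added per step starts small and grows toward the final step, where the clipping value takes over.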

Eval on ActivityNet Captions

CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=2 \
eval.py \
--workers 8 \
--batch_size_val 128 \
--anno_path ${DATA_PATH}/ActivityNet \
--video_path ${DATA_PATH}/ActivityNet/Activity_Videos \
--datatype activity \
--max_words 64 \
--max_frames 64 \
--video_framerate 1 \
--diffusion_steps 50 \
--noise_schedule cosine \
--init_model ${CHECKPOINT_PATH} \
--output_dir ${OUTPUT_PATH}
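The `--video_framerate` and `--max_frames` flags interact: a longer clip yields more decoded frames than the model accepts. The sketch below shows one common way to handle that, uniform subsampling. It is an assumption about the general approach, not a copy of this repository's data loader, and `sample_frame_indices` is a hypothetical helper name.

```python
def sample_frame_indices(num_decoded, max_frames):
    """Pick at most max_frames indices, evenly spread over the decoded frames.

    Assumption: frames were already decoded at --video_framerate frames per
    second, so num_decoded is roughly the clip length in seconds when the
    rate is 1.
    """
    if num_decoded <= max_frames:
        # Short clip: keep every decoded frame.
        return list(range(num_decoded))
    if max_frames == 1:
        return [0]
    # Integer arithmetic keeps the first and last frame exactly.
    return [i * (num_decoded - 1) // (max_frames - 1) for i in range(max_frames)]


# A 10 s clip fits into 12 slots; a 100 s clip is thinned out to 12 frames.
print(sample_frame_indices(10, 12))
print(sample_frame_indices(100, 12))
```

This also suggests why the ActivityNet command raises `--max_frames` to 64: its videos are much longer than MSR-VTT clips, so a budget of 12 frames would sample them very sparsely.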

Train

Discrimination Pretrain

Train the feature extractor from the discrimination perspective.

CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=4 \
main_retrieval.py \
--do_train 1 \
--workers 8 \
--n_display 50 \
--epochs 5 \
--lr 1e-4 \
--coef_lr 1e-3 \
--batch_size 128 \
--batch_size_val 128 \
--anno_path data/MSR-VTT/anns \
--video_path ${DATA_PATH}/MSRVTT_Videos \
--datatype msrvtt \
--max_words 32 \
--max_frames 12 \
--video_framerate 1 \
--stage discrimination \
--output_dir ${OUTPUT_PATH}

Generation Finetune

Optimize the generator from the generation perspective.

CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.launch \
--master_port 2502 \
--nproc_per_node=4 \
main_retrieval.py \
--do_train 1 \
--workers 8 \
--n_display 50 \
--epochs 5 \
--lr 1e-4 \
--coef_lr 1e-3 \
--batch_size 128 \
--batch_size_val 128 \
--anno_path data/MSR-VTT/anns \
--video_path ${DATA_PATH}/MSRVTT_Videos \
--datatype msrvtt \
--max_words 32 \
--max_frames 12 \
--video_framerate 1 \
--stage generation \
--diffusion_steps 50 \
--noise_schedule cosine \
--init_model ${CHECKPOINT_PATH} \
--output_dir ${OUTPUT_PATH}

🎗️ Acknowledgments

Our code is based on EMCL, CLIP, CLIP4Clip and DRL. We sincerely appreciate their contributions.

diffusionret's People

Contributors

jpthu17


diffusionret's Issues

Question about diffusion model

Thanks for your excellent work, but I have a question.

p = self.decoder(emb).squeeze(2)  
p += weight

Why did you add the weight to this distribution? p should already contain this information, so I don't understand the purpose of this step.

Question about the two stage training

Thanks for your excellent work, but I have some questions about the two-stage training.
(1) In the generation training stage, gradients are disabled for every parameter outside the diffusion model. Does that mean the discrimination loss has no effect on training in the second stage?
(2) Have you tried training without the two-stage approach? How did it perform?

Code about Text-Frame Attention Encoder

Hello,

Thank you for your hard work and excellent contribution. Your work has been truly inspiring.
However, I have a question regarding section 3.2.1 of the paper, specifically the "Text-Frame Attention Encoder" mentioned in the discussion about aggregating frame representation. The final video representation is defined as Formula 4. However, I couldn't find any implementation of the text-frame attention encoder in the code. It seems that the get_video_feat function in the code uses the self.agg_video_feat function to aggregate video frames, which relies on seqTransf without any input for text features.
Could you please verify this discrepancy and provide an explanation?

Thank you!

Error during Evaluation

Hi, I managed to run both training phases without any problems, but evaluation does not work. Please check the following log. Thank you.

(DiffusionRet) hai@user:~/sang$ CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 2502 --nnodes=1 --nproc_per_node=1 eval.py --workers 8 --batch_size_val 128 --anno_path data/MSR-VTT/anns --video_path data/MSR-VTT/MSRVTT_Videos --datatype msrvtt --max_words 32 --max_frames 12 --video_framerate 1 --diffusion_steps 50 --noise_schedule cosine --init_model best.pth --output_dir output_eval
/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/launch.py:181: FutureWarning: The module torch.distributed.launch is deprecated                                                                                       
and will be removed in future. Use torchrun.                                                                                                                
Note that --use-env is set by default in torchrun.                                                                                                          
If your script expects `--local-rank` argument to be set, please                                                                                            
change it to read from `os.environ['LOCAL_RANK']` instead. See                                                                                              
https://pytorch.org/docs/stable/distributed.html#launch-utility for                                                                                         
further instructions                                                          
                                                                              
  warnings.warn(                                                              
[2023-10-03 11:06:03,359 tvr 110 INFO]: local_rank: 0 world_size: 1
[2023-10-03 11:06:03,359 tvr 117 INFO]: Effective parameters:                                                                                               
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< agg_module: seqTransf
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< anno_path: data/MSR-VTT/anns
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< base_encoder: ViT-B/32
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< batch_size: 128                            
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< batch_size_val: 128                        
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< d_temp: 100            
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< datatype: msrvtt
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< device: cuda:0   
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< diffusion_steps: 50
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< distributed: 0    
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< epochs: 5                                                                                                                           
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< init_model: best.pth        
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< interaction: wti                                                                                              
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< local_rank: 0                
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< max_frames: 12                  
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< max_words: 32                   
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< neg: 0                                                                                                                              
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< noise_schedule: cosine
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< num: 127     
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< num_hidden_layers: 4                       
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< output_dir: output_eval         
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< seed: 42                                   
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< sigma_small: True           
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< t2v_alpha: 1                                                                                                                        
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< t2v_num: 32                  
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< t2v_temp: 1                                                                                                   
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< temp: 1                   
[2023-10-03 11:06:03,359 tvr 119 INFO]:   <<< v2t_alpha: 1                               
[2023-10-03 11:06:03,360 tvr 119 INFO]:   <<< v2t_num: 32                                
[2023-10-03 11:06:03,360 tvr 119 INFO]:   <<< v2t_temp: 1                                                                                                                         
[2023-10-03 11:06:03,360 tvr 119 INFO]:   <<< video_framerate: 1                         
[2023-10-03 11:06:03,360 tvr 119 INFO]:   <<< video_path: data/MSR-VTT/MSRVTT_Videos                                                                                                                                                                      
[2023-10-03 11:06:03,360 tvr 119 INFO]:   <<< workers: 8                                 
[2023-10-03 11:06:03,360 tvr 119 INFO]:   <<< world_size: 1                              
[val] Unique sentence is 995 , all num is 1000                                           
Video number: 1000                                                                       
Total Pairs: 1000                                                                                                                                                                 
[2023-10-03 11:06:10,770 tvr 159 INFO]: ***** Running test *****                         
[2023-10-03 11:06:10,770 tvr 160 INFO]:   Num examples = 1000                                                                
[2023-10-03 11:06:10,770 tvr 161 INFO]:   Batch size = 128                                                                                                                        
[2023-10-03 11:06:10,770 tvr 162 INFO]:   Num steps = 8                                  
[2023-10-03 11:06:10,770 tvr 163 INFO]: ***** Running val *****                          
[2023-10-03 11:06:10,770 tvr 164 INFO]:   Num examples = 1000                                                                                                                     
[2023-10-03 11:06:10,773 tvr 375 INFO]: [start] extract text+video feature               
100%|████████████████████████████████████████| 8/8 [01:10<00:00,  8.86s/it]
[2023-10-03 11:07:21,813 tvr 403 INFO]: [finish] extract text+video feature                                                                                                       
[2023-10-03 11:07:21,813 tvr 407 INFO]: 1000 1000 1000 1000                                                                                                                       
[2023-10-03 11:07:21,813 tvr 411 INFO]: [start] calculate the similarity                 
[2023-10-03 11:07:21,813 tvr 205 INFO]: [finish] map to main gpu                                                                                                                  
[2023-10-03 11:07:21,814 tvr 214 INFO]: [finish] map to main gpu                                                             
[2023-10-03 11:07:22,397 tvr 227 INFO]: diffusion                                                                                                                                 
Traceback (most recent call last):                                                       
  File "/home/hai/sang/eval.py", line 493, in <module>                                                                                                                            
    main()                                  
  File "/home/hai/sang/eval.py", line 490, in main                                                                                                                                
    eval_epoch(args, model, test_dataloader, args.device, diffusion)                                                         
  File "/home/hai/sang/eval.py", line 413, in eval_epoch                                                                                                                          
    new_t2v_matrix, new_v2t_matrix = _run_on_single_gpu(args, model, batch_mask_t,                                                                                                                                                                        
  File "/home/hai/sang/eval.py", line 255, in _run_on_single_gpu                                                             
    model.diffusion_model,                                    
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__                                                                                                                          
    raise AttributeError("'{}' object has no attribute '{}'".format(                                                         
AttributeError: 'DiffusionRet' object has no attribute 'diffusion_model'                                                     
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 259728) of binary: /home/hai/anaconda3/envs/DiffusionRet/bin/python                                                                                          
Traceback (most recent call last):                            
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/runpy.py", line 197, in _run_module_as_main                                                                                                                                                   
    return _run_code(code, main_globals, None,                                                                               
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/runpy.py", line 87, in _run_code                                                                                                                                                              
    exec(code, run_globals)                                   
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/launch.py", line 196, in <module>                                                                                                                             
    main()                                                    
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/launch.py", line 192, in main                                                                                                                                 
    launch(args)                                              
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/launch.py", line 177, in launch                                                                                                                               
    run(args)                                                 
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/run.py", line 785, in run                                                                                                                                     
    elastic_launch(                                           
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__                                                                                                                       
    return launch_agent(self._config, self._entrypoint, list(args))                                                          
  File "/home/hai/anaconda3/envs/DiffusionRet/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent                                                                                                                   
    raise ChildFailedError(                                   
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:                                                           
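The traceback says only that the lookup of `diffusion_model` failed, not why. A generic first diagnostic step is to list which attributes the object does expose, to see whether the module exists under a different name or was never created. The sketch below uses a stand-in class; `StandInModel`, `diffusion_net`, and `attributes_matching` are hypothetical names for illustration and say nothing about the real `DiffusionRet` class.

```python
def attributes_matching(obj, needle):
    """Return the attribute names of obj that contain needle (case-insensitive)."""
    needle = needle.lower()
    return sorted(name for name in dir(obj) if needle in name.lower())


class StandInModel:
    """Hypothetical stand-in; the real DiffusionRet class is not reproduced here."""

    def __init__(self):
        self.clip = object()
        self.diffusion_net = object()  # hypothetical attribute name


model = StandInModel()

# Check for the attribute the failing code expects, then look for near matches.
print(hasattr(model, "diffusion_model"))        # False
print(attributes_matching(model, "diffusion"))  # ['diffusion_net']
```

Running the same two calls on the loaded model, just before the failing line in `eval.py`, would show whether any diffusion-related attribute is present at all.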

diffusion training loss

In diffusion_models/gaussian_diffusion.py, lines 803-804:
terms["kl_loss"] = F.cross_entropy(model_output * temp,
                                   th.zeros(model_output.size(0), dtype=th.long).to(model_output.device))
I wonder why this uses th.zeros(model_output.size(0), dtype=th.long) as the target instead of x_start. Is that an error?
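For reference, `F.cross_entropy` with class-index targets averages `-log softmax(row)[target]` over the batch, so an all-zero target treats column 0 of every row as the correct class. The pure-Python sketch below reproduces that arithmetic on made-up logits. Whether the repository arranges each row so that the positive candidate sits in column 0 is an assumption, not something the quoted snippet confirms.

```python
import math


def cross_entropy_row(logits, target):
    """-log softmax(logits)[target], computed with a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, targets):
    """Batch mean, matching the default 'mean' reduction."""
    losses = [cross_entropy_row(r, t) for r, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Made-up logits: row 0 favors column 0, row 1 favors column 1.
batch = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
targets = [0] * len(batch)  # same role as th.zeros(..., dtype=th.long)

row_losses = [cross_entropy_row(r, t) for r, t in zip(batch, targets)]
loss = mean_cross_entropy(batch, targets)
print(row_losses, loss)
```

The first row gets a small loss because its largest logit is in column 0, and the second row gets a large one. An all-zero target is therefore only meaningful if column 0 always holds the true match.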

About the difference in MSVD Mean R across your papers

Hello, I would like to ask why the T2V Mean R in DiCoSA is 9.1, which is close to the corresponding values of around 10 in X-CLIP and CLIP4Clip, whereas the T2V Mean R on MSVD in both your DiffusionRet and EMCL-Net is around 16. That is very different from X-CLIP, CLIP4Clip, and your own DiCoSA, and I am curious why. I assume your MSVD setup uses 1200 videos for training and 670 for both val and test?
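As background for the metric discussed in this issue, Mean R is the average rank of the ground-truth item over all queries, so lower is better. The sketch below computes it from a small made-up similarity matrix. It assumes that query i matches video i (the diagonal) and that ties are resolved in the query's favor. Protocols with several captions per video need extra bookkeeping that is omitted here, and the repository's evaluation code may differ.

```python
def mean_rank(sim):
    """Mean rank of the diagonal entry in each row of a similarity matrix.

    sim[i][j] is the similarity between text query i and video j.
    Rank 1 means the matching video scored highest.
    """
    ranks = []
    for i, row in enumerate(sim):
        # Count candidates that score strictly higher than the true match.
        rank = 1 + sum(1 for s in row if s > row[i])
        ranks.append(rank)
    return sum(ranks) / len(ranks)


# Made-up scores: queries 0 and 2 rank their video first, query 1 ranks it third.
sim = [
    [0.9, 0.1, 0.3],
    [0.8, 0.2, 0.5],
    [0.1, 0.2, 0.7],
]
print(mean_rank(sim))  # (1 + 3 + 1) / 3
```

Because it is an average of ranks, Mean R is sensitive to a few badly ranked queries, and values are only comparable when the test-set construction is the same.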
