
vrp-sam's Introduction

VRP-SAM: SAM with Visual Reference Prompt

Update:

  1. The paper has been accepted to CVPR 2024.
  2. The core code has been updated.

This is the official PyTorch implementation of the paper VRP-SAM: SAM with Visual Reference Prompt.

Authors: Yanpeng Sun, Jiahui Chen, Shan Zhang, Xinyu Zhang, Qiang Chen, Gang Zhang, Errui Ding, Jingdong Wang, Zechao Li

Requirements

  • Python 3.10
  • PyTorch 1.12
  • CUDA 11.6

Conda environment settings:

conda create -n vrpsam python=3.10
conda activate vrpsam

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
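
To confirm the environment matches the versions above, an optional sanity check (not part of the repository) can be run in Python:

import torch

# Optional sanity check: the expected values assume the versions listed above.
print(torch.__version__)          # expected: 1.12.1
print(torch.version.cuda)         # expected: 11.6
print(torch.cuda.is_available())  # True if a CUDA-capable GPU is visible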

Segment-Anything-Model setting:

cd ./segment-anything
pip install -v -e .
cd ..
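
To verify the editable install, SAM can be instantiated directly. The following is a minimal sketch that assumes an official SAM checkpoint (e.g. sam_vit_h_4b8939.pth from the segment-anything repository) has already been downloaded; the path is an assumption, so adjust it to your setup:

from segment_anything import sam_model_registry

# Minimal sketch: build SAM from a locally downloaded official checkpoint.
# The checkpoint path is an assumption; point it at your local copy.
sam = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth")
print(sum(p.numel() for p in sam.parameters()))  # rough parameter-count check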

Preparing Few-Shot Segmentation Datasets

Download the following datasets:

1. PASCAL-5i

Download PASCAL VOC2012 devkit (train/val data):

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

Download PASCAL VOC2012 SDS extended mask annotations from our [Google Drive].

2. COCO-20i

Download COCO2014 train/val images and annotations:

wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip

Download COCO2014 train/val annotations from our Google Drive: [train2014.zip], [val2014.zip], and place both train2014/ and val2014/ under the annotations/ directory.

Create a directory '../Datasets_HSN' for the above few-shot segmentation datasets and place each dataset so as to obtain the following directory structure:

../                         # parent directory
├── ./                      # current (project) directory
│   ├── common/             # (dir.) helper functions
│   ├── data/               # (dir.) dataloaders and splits for each FSSS dataset
│   ├── model/              # (dir.) implementation of VRP-SAM 
│   ├── segment-anything/   # code for SAM
│   ├── README.md           # instructions for reproduction
│   ├── train.py            # code for training VRP-SAM
│   └── SAM2Pred.py         # code for prediction module
│    
└── Datasets_HSN/
    ├── VOC2012/            # PASCAL VOC2012 devkit
    │   ├── Annotations/
    │   ├── ImageSets/
    │   ├── ...
    │   └── SegmentationClassAug/
    └── COCO2014/           
        ├── annotations/
        │   ├── train2014/  # (dir.) training masks (from Google Drive) 
        │   ├── val2014/    # (dir.) validation masks (from Google Drive)
        │   └── ..some json files..
        ├── train2014/
        └── val2014/
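
The snippet below is an illustrative check (not part of the repository) to confirm the expected layout before training:

from pathlib import Path

# Illustrative layout check; adjust the root if your datasets live elsewhere.
root = Path('../Datasets_HSN')
for sub in ['VOC2012/SegmentationClassAug',
            'COCO2014/annotations/train2014',
            'COCO2014/annotations/val2014',
            'COCO2014/train2014',
            'COCO2014/val2014']:
    path = root / sub
    print(path, 'OK' if path.is_dir() else 'MISSING')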

Training

We provide an example training script, train.sh. Detailed training arguments are as follows:

python3 -m torch.distributed.launch --nproc_per_node=$GPUs$ train.py \
        --datapath $PATH_TO_YOUR_DATA$ \
        --logpath $PATH_TO_YOUR_LOG$ \
        --benchmark {coco, pascal} \
        --backbone {vgg16, resnet50, resnet101} \
        --fold {0, 1, 2, 3} \
        --condition {point, scribble, box, mask} \
        --num_query 50 \
        --epochs 50 \
        --lr 1e-4 \
        --bsz 2     
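
The --condition flag selects how the visual reference annotation is supplied (point, scribble, box, or mask). Purely as an illustration, and not the repository's actual prompt-generation code, a point-type reference could be sampled from a binary annotation mask as follows:

import torch

# Illustrative sketch only -- not the repository's actual prompt generation.
def sample_point_prompt(mask: torch.Tensor, num_points: int = 1) -> torch.Tensor:
    """Sample (x, y) foreground coordinates from a binary reference mask."""
    ys, xs = torch.nonzero(mask, as_tuple=True)    # indices of foreground pixels
    idx = torch.randint(len(ys), (num_points,))    # pick random foreground pixels
    return torch.stack([xs[idx], ys[idx]], dim=1)  # shape: (num_points, 2)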

Example qualitative results (1-shot): [figure]

BibTeX

If you use this code for your research, please consider citing:

@inproceedings{sun2024vrp,
    title={VRP-SAM: SAM with Visual Reference Prompt},
    author={Sun, Yanpeng and Chen, Jiahui and Zhang, Shan and Zhang, Xinyu and Chen, Qiang and Zhang, Gang and Ding, Errui and Wang, Jingdong and Li, Zechao},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year={2024}
}


vrp-sam's Issues

Poor algorithm performance

Hello, the model trained on the COCO2014 dataset with the script you provided performs rather poorly; ResNet-50 was used for feature extraction.

Inference code

Thank you for completing such excellent work. May I ask whether the model's inference code has been written yet?

Model weights and Inference Code

Thank you for open sourcing the code!

I can't seem to find the trained model weights of the VRP encoder. Are you planning to open-source them, or do we have to train them ourselves?

Also, would you provide sample inference code showing how to get prediction.jpg?

Thanks in advance.

More details about DINOv2 as backbone

Thanks for your work.
Could you provide more details on the implementation of DINOv2 as a backbone, specifically regarding the layers used and the concatenation process?

SegmentationClassAug

I'd like to know how SegmentationClassAug was generated, and what its role is. I'm asking because I want to train on my own datasets.
Thanks, I'm looking forward to your reply!

Reproducing Paper's Result

Thanks for proposing this useful and innovative method! This is a very inspiring and impressive work.

I'm trying to reproduce the results presented in the paper. However, I'm only able to train VRP-SAM with vgg16, because the resnet50 checkpoint I downloaded from 'https://download.pytorch.org/models/resnet50-19c8e357.pth' produces a size-mismatch error with the code. I saw that in your code you load:

model_path = '/root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth'

However, it seems to have a different name than the path at line 16 of resnet.py
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',

Could you please kindly provide the download link of resnet101_v2.pth? Thank you so much!

With vgg16 on COCO-20i fold 0, the one-shot results I got all seem worse than the paper's reported mIoU of 43.6, no matter which prompt type I used. Here is what I got:
Point, 28.09
Scribble, 33.43
Box, 35.84
Mask, 38.71

The parameters I used are as follows:
bsz: 100
lr: 0.0001
weight_decay: 1e-06
epochs: 50
nworker: 8
seed: 321
fold: 0
use_ignore: True
num_query: 50
backbone: vgg16

Do you have any idea about what happened here? Thank you so much for answering!

Loading ResNet

I have a question regarding the code. In resnet.py starting around line 318, the code says:

model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
    #model.load_state_dict(torch.utils.model_zoo.load_url(model_urls['resnet50']))
    model_path = '/root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth'
    model.load_state_dict(torch.load(model_path), strict=False)

This fails because nothing exists at /root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth.

So I tried changing it to:

model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
    model.load_state_dict(torch.utils.model_zoo.load_url(model_urls['resnet50']))
    #model_path = '/root/paddlejob/workspace/env_run/vrp_sam/resnet50_v2.pth'
    #model.load_state_dict(torch.load(model_path), strict=False)

But now I get the following error:

RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv2.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "conv3.weight", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var".
size mismatch for conv1.weight: copying a param with shape torch.Size([64, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([64, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 128, 1, 1]).
size mismatch for layer1.0.downsample.0.weight: copying a param with shape torch.Size([256, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).

Could you provide help on how to get a ResNet model working here?
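
For what it's worth, the missing conv2/conv3 keys and the 3x3 conv1 shape in that error suggest the code expects a deep-stem ResNet variant rather than the standard torchvision ResNet-50, so the standard weights cannot be loaded strictly. One possible, untested workaround sketch is to copy only the parameters whose names and shapes match, leaving the stem layers randomly initialized (results may still differ from the authors' resnet50_v2.pth):

import torch
from torch.hub import load_state_dict_from_url

# Untested workaround sketch: copy only torchvision parameters whose names
# and shapes match the modified ResNet, leaving the mismatched stem layers
# (conv1/conv2/conv3, bn2/bn3) randomly initialized.
def load_compatible_weights(model: torch.nn.Module, url: str) -> None:
    pretrained = load_state_dict_from_url(url)
    state = model.state_dict()
    matched = {k: v for k, v in pretrained.items()
               if k in state and v.shape == state[k].shape}
    state.update(matched)
    model.load_state_dict(state)

# Usage (assuming `model` is the repo's modified ResNet built as in resnet.py):
# load_compatible_weights(model, 'https://download.pytorch.org/models/resnet50-19c8e357.pth')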

Appendix & code

Hello, thank you for your work! I have some questions:

  1. Where can I download the appendix?
  2. When will the code be released?

Thank you in advance for your answers!
