
swin-unet's Introduction

Swin-Unet

[ECCVW2022] Code for the work "Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation" (https://arxiv.org/abs/2105.05537), a validation of a U-shaped Swin Transformer. The paper has been accepted by the ECCV 2022 Medical Computer Vision Workshop (https://mcv-workshop.github.io/). We have updated the Reproducibility section below; we hope it helps you reproduce the results.

1. Download the pre-trained Swin Transformer model (Swin-T)

2. Prepare data

3. Environment

  • Please prepare an environment with Python 3.7, then run "pip install -r requirements.txt" to install the dependencies.

4. Train/Test

  • Run the training script on the Synapse dataset. We used a batch size of 24; if you do not have enough GPU memory, the batch size can be reduced to 12 or 6 to save memory.

  • Train

sh train.sh or python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path your DATA_DIR --max_epochs 150 --output_dir your OUT_DIR  --img_size 224 --base_lr 0.05 --batch_size 24
  • Test
sh test.sh or python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_saveni --volume_path your DATA_DIR --output_dir your OUT_DIR --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24

Reproducibility

  • Questions about Dataset

Many of you have asked me for datasets, and I personally would be very glad to share the preprocessed Synapse and ACDC datasets with you. However, I am not the owner of these two preprocessed datasets. Please email jienengchen01 AT gmail.com to get the processed datasets.

  • Codes

Our trained model is stored on the Huawei cloud, and interns are not allowed to send any files out of the internal system, so we cannot share the trained model weights. Regarding how to reproduce the segmentation results presented in the paper: we found that different GPU types can produce different results. In our code we carefully set the random seed, so the results should be consistent when training multiple times on the same type of GPU. If training does not give the same segmentation results as in the paper, it is recommended to adjust the learning rate. The GPU type we used in this work is the Tesla V100. Finally, pre-training is very important for pure transformer models: in our experiments, both the encoder and the decoder are initialized with pretrained weights, rather than initializing only the encoder.
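For reference, below is a minimal sketch of the kind of seed fixing described above (an illustration of the general recipe, not an excerpt from this repo); with identical seeds and deterministic cuDNN settings, repeated runs on the same GPU type should match.

import random
import numpy as np
import torch

def set_seed(seed: int = 1234) -> None:
    # Fix all RNGs and force deterministic cuDNN kernels so that repeated
    # training runs on the same GPU type are comparable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(1234)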

References

Citation

@InProceedings{swinunet,
author = {Hu Cao and Yueyue Wang and Joy Chen and Dongsheng Jiang and Xiaopeng Zhang and Qi Tian and Manning Wang},
title = {Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation},
booktitle = {Proceedings of the European Conference on Computer Vision Workshops (ECCVW)},
year = {2022}
}

@misc{cao2021swinunet,
      title={Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation}, 
      author={Hu Cao and Yueyue Wang and Joy Chen and Dongsheng Jiang and Xiaopeng Zhang and Qi Tian and Manning Wang},
      year={2021},
      eprint={2105.05537},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}

swin-unet's People

Contributors

hucaofighting, mgamz


swin-unet's Issues

Unable to adapt to different resolution

Hi, thanks for your contribution. The model cannot adapt to resolutions other than 224 for both height and width. The error typically appears as:

RuntimeError: shape '[xxx, xxx, xxx, xxx, xxx]' is invalid for input of size xxx
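As a rough illustration (my own assumption about the constraint, not something stated in the repo): with the default configuration (patch_size=4, window_size=7, four stages), every stage's token grid has to stay divisible by the window size, which is why 224 works and most other sizes fail inside a reshape.

def fits_default_swin_unet(img_size, patch_size=4, window_size=7, num_stages=4):
    # 224 -> token grids of 56, 28, 14, 7 per side: all divisible by the window size 7.
    tokens = img_size // patch_size
    for _ in range(num_stages):
        if tokens % window_size != 0:
            return False
        tokens //= 2
    return True

print(fits_default_swin_unet(224))  # True
print(fits_default_swin_unet(256))  # False: 64 tokens per side is not divisible by 7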

different data

hello @HuCaoFighting ,
I want to use swin_unet for a segmentation task, but my data files end in .png or .jpg rather than .nii.gz.
So I need to change dataset_synapse.py. I tried many times but failed. Could you give me some advice?
Thanks!
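Not an official answer, but a minimal sketch of what a 2D .png/.jpg replacement for dataset_synapse.py could look like (assumptions: images and masks share file names across two hypothetical folders, and the repo's transform such as RandomGenerator can be reused):

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class PngSegDataset(Dataset):
    # Hypothetical drop-in for dataset_synapse.py when the data are
    # .png/.jpg image-mask pairs rather than .npz slices or .nii.gz volumes.
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir, self.mask_dir = image_dir, mask_dir
        self.names = sorted(os.listdir(image_dir))
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        image = np.asarray(Image.open(os.path.join(self.image_dir, name)).convert('L'),
                           dtype=np.float32) / 255.0
        label = np.asarray(Image.open(os.path.join(self.mask_dir, name)), dtype=np.uint8)
        sample = {'image': image, 'label': label, 'case_name': os.path.splitext(name)[0]}
        if self.transform is not None:
            sample = self.transform(sample)  # e.g. the repo's RandomGenerator
        return sample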

Regarding Results

Hello authors,

You have done a fabulous job in implementing this approach. Thanks for making the repo publicly available.

I tried to train Swin-Unet with the Synapse data (obtained from the TransUNet authors), but the results are not close to those reported in the paper. I got a mean Dice of 0.780137 and a mean HD95 of 30.579858.

I used the following hyperparameters according to your README and config file.

--max_epochs 150 --img_size 224 --base_lr 0.05 --batch_size 24
EMBED_DIM: 96
DEPTHS: [ 2, 2, 2, 2 ] (This is somewhat contradictory; in the paper it is [2, 2, 2, 1].)
DECODER_DEPTHS: [ 2, 2, 2, 1]
NUM_HEADS: [ 3, 6, 12, 24 ]
WINDOW_SIZE: 7

I would appreciate it if you could provide the hyperparameters you used to obtain the results reported in the paper.

Thanks,
Himashi

Is there a problem with the code here?

def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0 and gt.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    # prediction has foreground pixels, ground truth has none
    elif pred.sum() > 0 and gt.sum() == 0:
        return 1, 0
    # prediction has no foreground pixels, ground truth has none
    else:
        return 0, 0

Revising the code to adapt to the 384*384 input image

Hi Hu Cao,

First, thanks for your wonderful work.

I'm wondering how to revise your code to adapt to 384x384 input images. The original Swin Transformer provides 384x384 Swin-B weights, and I would like to test your model with them.

Could you kindly provide any clue?

Thanks again.

ACDC dataset preprocessing

Hi! Thanks for your nice work!
I'd like to ask what the preprocessing method for the ACDC dataset is. I did not find the relevant information in TransUNet or Swin-Unet.

Your Patch Expand method is equivalent to transpose conv

A transpose conv with kernel size 2, stride 2, input dim 2C and output dim C
linearly maps every single patch with a 2C-dim feature into a non-overlapping 2x2 group of patches with C-dim features.
That is equivalent to your method: first linearly map 1x1@2C to 1x1@4C, then rearrange to 2x2@C.
Only the placement of the normalization can cause a difference. For example, if you insert BN between the linear layer and the rearrange, the BN acts independently on the different channels, i.e. on the different patches within each 2x2 group, whereas BN applied after a transpose conv uses the same scale and bias for all patches in the group.

The same misunderstanding appears in "Understanding Convolution for Semantic Segmentation", where the authors claim their "DUC" is different from a transpose conv.
For the same reason, the patch merging in Swin Transformer is equivalent to a kernel size 2, stride 2 conv, apart from the norm inserted between the concat and the reduction.
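A quick numerical check of the claim above (my own sketch; the tensor sizes and the kernel-2/stride-2 setup are assumptions matching the description, not code from the repo):

import torch
import torch.nn as nn
from einops import rearrange

B, H, W, C = 2, 7, 7, 8                  # C: output channels; the input carries 2C
x = torch.randn(B, H, W, 2 * C)

# Path 1: linear 2C -> 4C, then rearrange each token into a 2x2 group of C-dim patches
linear = nn.Linear(2 * C, 4 * C, bias=False)
y1 = rearrange(linear(x), 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C)

# Path 2: transpose conv with kernel 2, stride 2, using the same weights reshaped
tconv = nn.ConvTranspose2d(2 * C, C, kernel_size=2, stride=2, bias=False)
with torch.no_grad():
    tconv.weight.copy_(rearrange(linear.weight, '(p1 p2 c) i -> i c p1 p2', p1=2, p2=2, c=C))
y2 = tconv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

print(torch.allclose(y1, y2, atol=1e-5))  # True: the two operations coincide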

the pretrained model state_dict did not match

I downloaded the pretrained model and ran

python test.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --is_saveni --volume_path your DATA_DIR --output_dir your OUT_DIR --max_epoch 150 --base_lr 0.05 --img_size 224 --batch_size 24

and there are mismatched keys:

    msg = net.load_state_dict(torch.load(snapshot)["model"])
  File "/usr/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SwinUnet:
	Missing key(s) in state_dict: "swin_unet.patch_embed.proj.weight", "swin_unet.patch_embed.proj.bias", "swin_unet.patch_embed.norm.weight", "swin_unet.patch_embed.norm.bias", "swin_unet.layers.0.blocks.0.norm1.weight", "swin_unet.layers.0.blocks.0.norm1.bias", "swin_unet.layers.0.blocks.0.attn.relative_position_bias_table", "swin_unet.layers.0.blocks.0.attn.relative_position_index", "swin_unet.layers.0.blocks.0.attn.qkv.weight", "swin_unet.layers.0.blocks.0.attn.qkv.bias", "swin_unet.layers.0.blocks.0.attn.proj.weight", "swin_unet.layers.0.blocks.0.attn.proj.bias", "swin_unet.layers.0.blocks.0.norm2.weight", "swin_unet.layers.0.blocks.0.norm2.bias", "swin_unet.layers.0.blocks.0.mlp.fc1.weight", "swin_unet.layers.0.blocks.0.mlp.fc1.bias", "swin_unet.layers.0.blocks.0.mlp.fc2.weight", "swin_unet.layers.0.blocks.0.mlp.fc2.bias", "swin_unet.layers.0.blocks.1.attn_mask", "swin_unet.layers.0.blocks.1.norm1.weight", "swin_unet.layers.0.blocks.1.norm1.bias", "swin_unet.layers.0.blocks.1.attn.relative_position_bias_table", "swin_unet.layers.0.blocks.1.attn.relative_position_index", "swin_unet.layers.0.blocks.1.attn.qkv.weight", "swin_unet.layers.0.blocks.1.attn.qkv.bias", "swin_unet.layers.0.blocks.1.attn.proj.weight", "swin_unet.layers.0.blocks.1.attn.proj.bias", "swin_unet.layers.0.blocks.1.norm2.weight", "swin_unet.layers.0.blocks.1.norm2.bias", "swin_unet.layers.0.blocks.1.mlp.fc1.weight", "swin_unet.layers.0.blocks.1.mlp.fc1.bias", "swin_unet.layers.0.blocks.1.mlp.fc2.weight", "swin_unet.layers.0.blocks.1.mlp.fc2.bias", "swin_unet.layers.0.downsample.reduction.weight", "swin_unet.layers.0.downsample.norm.weight", "swin_unet.layers.0.downsample.norm.bias", "swin_unet.layers.1.blocks.0.norm1.weight", "swin_unet.layers.1.blocks.0.norm1.bias", "swin_unet.layers.1.blocks.0.attn.relative_position_bias_table", "swin_unet.layers.1.blocks.0.attn.relative_position_index", "swin_unet.layers.1.blocks.0.attn.qkv.weight", "swin_unet.layers.1.blocks.0.attn.qkv.bias", "swin_unet.layers.1.blocks.0.attn.proj.weight", "swin_unet.layers.1.blocks.0.attn.proj.bias", "swin_unet.layers.1.blocks.0.norm2.weight", "swin_unet.layers.1.blocks.0.norm2.bias", "swin_unet.layers.1.blocks.0.mlp.fc1.weight", "swin_unet.layers.1.blocks.0.mlp.fc1.bias", "swin_unet.layers.1.blocks.0.mlp.fc2.weight", "swin_unet.layers.1.blocks.0.mlp.fc2.bias", "swin_unet.layers.1.blocks.1.attn_mask", "swin_unet.layers.1.blocks.1.norm1.weight", "swin_unet.layers.1.blocks.1.norm1.bias", "swin_unet.layers.1.blocks.1.attn.relative_position_bias_table", "swin_unet.layers.1.blocks.1.attn.relative_position_index", "swin_unet.layers.1.blocks.1.attn.qkv.weight", "swin_unet.layers.1.blocks.1.attn.qkv.bias", "swin_unet.layers.1.blocks.1.attn.proj.weight", "swin_unet.layers.1.blocks.1.attn.proj.bias", "swin_unet.layers.1.blocks.1.norm2.weight", "swin_unet.layers.1.blocks.1.norm2.bias", "swin_unet.layers.1.blocks.1.mlp.fc1.weight", "swin_unet.layers.1.blocks.1.mlp.fc1.bias", "swin_unet.layers.1.blocks.1.mlp.fc2.weight", "swin_unet.layers.1.blocks.1.mlp.fc2.bias", "swin_unet.layers.1.downsample.reduction.weight", "swin_unet.layers.1.downsample.norm.weight", "swin_unet.layers.1.downsample.norm.bias", "swin_unet.layers.2.blocks.0.norm1.weight", "swin_unet.layers.2.blocks.0.norm1.bias", "swin_unet.layers.2.blocks.0.attn.relative_position_bias_table", "swin_unet.layers.2.blocks.0.attn.relative_position_index", "swin_unet.layers.2.blocks.0.attn.qkv.weight", "swin_unet.layers.2.blocks.0.attn.qkv.bias", "swin_unet.layers.2.blocks.0.attn.proj.weight", 
"swin_unet.layers.2.blocks.0.attn.proj.bias", "swin_unet.layers.2.blocks.0.norm2.weight", "swin_unet.layers.2.blocks.0.norm2.bias", "swin_unet.layers.2.blocks.0.mlp.fc1.weight", "swin_unet.layers.2.blocks.0.mlp.fc1.bias", "swin_unet.layers.2.blocks.0.mlp.fc2.weight", "swin_unet.layers.2.blocks.0.mlp.fc2.bias", "swin_unet.layers.2.blocks.1.attn_mask", "swin_unet.layers.2.blocks.1.norm1.weight", "swin_unet.layers.2.blocks.1.norm1.bias", "swin_unet.layers.2.blocks.1.attn.relative_position_bias_table", "swin_unet.layers.2.blocks.1.attn.relative_position_index", "swin_unet.layers.2.blocks.1.attn.qkv.weight", "swin_unet.layers.2.blocks.1.attn.qkv.bias", "swin_unet.layers.2.blocks.1.attn.proj.weight", "swin_unet.layers.2.blocks.1.attn.proj.bias", "swin_unet.layers.2.blocks.1.norm2.weight", "swin_unet.layers.2.blocks.1.norm2.bias", "swin_unet.layers.2.blocks.1.mlp.fc1.weight", "swin_unet.layers.2.blocks.1.mlp.fc1.bias", "swin_unet.layers.2.blocks.1.mlp.fc2.weight", "swin_unet.layers.2.blocks.1.mlp.fc2.bias", "swin_unet.layers.2.downsample.reduction.weight", "swin_unet.layers.2.downsample.norm.weight", "swin_unet.layers.2.downsample.norm.bias", "swin_unet.layers.3.blocks.0.norm1.weight", "swin_unet.layers.3.blocks.0.norm1.bias", "swin_unet.layers.3.blocks.0.attn.relative_position_bias_table", "swin_unet.layers.3.blocks.0.attn.relative_position_index", "swin_unet.layers.3.blocks.0.attn.qkv.weight", "swin_unet.layers.3.blocks.0.attn.qkv.bias", "swin_unet.layers.3.blocks.0.attn.proj.weight", "swin_unet.layers.3.blocks.0.attn.proj.bias", "swin_unet.layers.3.blocks.0.norm2.weight", "swin_unet.layers.3.blocks.0.norm2.bias", "swin_unet.layers.3.blocks.0.mlp.fc1.weight", "swin_unet.layers.3.blocks.0.mlp.fc1.bias", "swin_unet.layers.3.blocks.0.mlp.fc2.weight", "swin_unet.layers.3.blocks.0.mlp.fc2.bias", "swin_unet.layers.3.blocks.1.norm1.weight", "swin_unet.layers.3.blocks.1.norm1.bias", "swin_unet.layers.3.blocks.1.attn.relative_position_bias_table", "swin_unet.layers.3.blocks.1.attn.relative_position_index", "swin_unet.layers.3.blocks.1.attn.qkv.weight", "swin_unet.layers.3.blocks.1.attn.qkv.bias", "swin_unet.layers.3.blocks.1.attn.proj.weight", "swin_unet.layers.3.blocks.1.attn.proj.bias", "swin_unet.layers.3.blocks.1.norm2.weight", "swin_unet.layers.3.blocks.1.norm2.bias", "swin_unet.layers.3.blocks.1.mlp.fc1.weight", "swin_unet.layers.3.blocks.1.mlp.fc1.bias", "swin_unet.layers.3.blocks.1.mlp.fc2.weight", "swin_unet.layers.3.blocks.1.mlp.fc2.bias", "swin_unet.layers_up.0.expand.weight", "swin_unet.layers_up.0.norm.weight", "swin_unet.layers_up.0.norm.bias", "swin_unet.layers_up.1.blocks.0.norm1.weight", "swin_unet.layers_up.1.blocks.0.norm1.bias", "swin_unet.layers_up.1.blocks.0.attn.relative_position_bias_table", "swin_unet.layers_up.1.blocks.0.attn.relative_position_index", "swin_unet.layers_up.1.blocks.0.attn.qkv.weight", "swin_unet.layers_up.1.blocks.0.attn.qkv.bias", "swin_unet.layers_up.1.blocks.0.attn.proj.weight", "swin_unet.layers_up.1.blocks.0.attn.proj.bias", "swin_unet.layers_up.1.blocks.0.norm2.weight", "swin_unet.layers_up.1.blocks.0.norm2.bias", "swin_unet.layers_up.1.blocks.0.mlp.fc1.weight", "swin_unet.layers_up.1.blocks.0.mlp.fc1.bias", "swin_unet.layers_up.1.blocks.0.mlp.fc2.weight", "swin_unet.layers_up.1.blocks.0.mlp.fc2.bias", "swin_unet.layers_up.1.blocks.1.attn_mask", "swin_unet.layers_up.1.blocks.1.norm1.weight", "swin_unet.layers_up.1.blocks.1.norm1.bias", "swin_unet.layers_up.1.blocks.1.attn.relative_position_bias_table", 
"swin_unet.layers_up.1.blocks.1.attn.relative_position_index", "swin_unet.layers_up.1.blocks.1.attn.qkv.weight", "swin_unet.layers_up.1.blocks.1.attn.qkv.bias", "swin_unet.layers_up.1.blocks.1.attn.proj.weight", "swin_unet.layers_up.1.blocks.1.attn.proj.bias", "swin_unet.layers_up.1.blocks.1.norm2.weight", "swin_unet.layers_up.1.blocks.1.norm2.bias", "swin_unet.layers_up.1.blocks.1.mlp.fc1.weight", "swin_unet.layers_up.1.blocks.1.mlp.fc1.bias", "swin_unet.layers_up.1.blocks.1.mlp.fc2.weight", "swin_unet.layers_up.1.blocks.1.mlp.fc2.bias", "swin_unet.layers_up.1.upsample.expand.weight", "swin_unet.layers_up.1.upsample.norm.weight", "swin_unet.layers_up.1.upsample.norm.bias", "swin_unet.layers_up.2.blocks.0.norm1.weight", "swin_unet.layers_up.2.blocks.0.norm1.bias", "swin_unet.layers_up.2.blocks.0.attn.relative_position_bias_table", "swin_unet.layers_up.2.blocks.0.attn.relative_position_index", "swin_unet.layers_up.2.blocks.0.attn.qkv.weight", "swin_unet.layers_up.2.blocks.0.attn.qkv.bias", "swin_unet.layers_up.2.blocks.0.attn.proj.weight", "swin_unet.layers_up.2.blocks.0.attn.proj.bias", "swin_unet.layers_up.2.blocks.0.norm2.weight", "swin_unet.layers_up.2.blocks.0.norm2.bias", "swin_unet.layers_up.2.blocks.0.mlp.fc1.weight", "swin_unet.layers_up.2.blocks.0.mlp.fc1.bias", "swin_unet.layers_up.2.blocks.0.mlp.fc2.weight", "swin_unet.layers_up.2.blocks.0.mlp.fc2.bias", "swin_unet.layers_up.2.blocks.1.attn_mask", "swin_unet.layers_up.2.blocks.1.norm1.weight", "swin_unet.layers_up.2.blocks.1.norm1.bias", "swin_unet.layers_up.2.blocks.1.attn.relative_position_bias_table", "swin_unet.layers_up.2.blocks.1.attn.relative_position_index", "swin_unet.layers_up.2.blocks.1.attn.qkv.weight", "swin_unet.layers_up.2.blocks.1.attn.qkv.bias", "swin_unet.layers_up.2.blocks.1.attn.proj.weight", "swin_unet.layers_up.2.blocks.1.attn.proj.bias", "swin_unet.layers_up.2.blocks.1.norm2.weight", "swin_unet.layers_up.2.blocks.1.norm2.bias", "swin_unet.layers_up.2.blocks.1.mlp.fc1.weight", "swin_unet.layers_up.2.blocks.1.mlp.fc1.bias", "swin_unet.layers_up.2.blocks.1.mlp.fc2.weight", "swin_unet.layers_up.2.blocks.1.mlp.fc2.bias", "swin_unet.layers_up.2.upsample.expand.weight", "swin_unet.layers_up.2.upsample.norm.weight", "swin_unet.layers_up.2.upsample.norm.bias", "swin_unet.layers_up.3.blocks.0.norm1.weight", "swin_unet.layers_up.3.blocks.0.norm1.bias", "swin_unet.layers_up.3.blocks.0.attn.relative_position_bias_table", "swin_unet.layers_up.3.blocks.0.attn.relative_position_index", "swin_unet.layers_up.3.blocks.0.attn.qkv.weight", "swin_unet.layers_up.3.blocks.0.attn.qkv.bias", "swin_unet.layers_up.3.blocks.0.attn.proj.weight", "swin_unet.layers_up.3.blocks.0.attn.proj.bias", "swin_unet.layers_up.3.blocks.0.norm2.weight", "swin_unet.layers_up.3.blocks.0.norm2.bias", "swin_unet.layers_up.3.blocks.0.mlp.fc1.weight", "swin_unet.layers_up.3.blocks.0.mlp.fc1.bias", "swin_unet.layers_up.3.blocks.0.mlp.fc2.weight", "swin_unet.layers_up.3.blocks.0.mlp.fc2.bias", "swin_unet.layers_up.3.blocks.1.attn_mask", "swin_unet.layers_up.3.blocks.1.norm1.weight", "swin_unet.layers_up.3.blocks.1.norm1.bias", "swin_unet.layers_up.3.blocks.1.attn.relative_position_bias_table", "swin_unet.layers_up.3.blocks.1.attn.relative_position_index", "swin_unet.layers_up.3.blocks.1.attn.qkv.weight", "swin_unet.layers_up.3.blocks.1.attn.qkv.bias", "swin_unet.layers_up.3.blocks.1.attn.proj.weight", "swin_unet.layers_up.3.blocks.1.attn.proj.bias", "swin_unet.layers_up.3.blocks.1.norm2.weight", "swin_unet.layers_up.3.blocks.1.norm2.bias", 
"swin_unet.layers_up.3.blocks.1.mlp.fc1.weight", "swin_unet.layers_up.3.blocks.1.mlp.fc1.bias", "swin_unet.layers_up.3.blocks.1.mlp.fc2.weight", "swin_unet.layers_up.3.blocks.1.mlp.fc2.bias", "swin_unet.concat_back_dim.1.weight", "swin_unet.concat_back_dim.1.bias", "swin_unet.concat_back_dim.2.weight", "swin_unet.concat_back_dim.2.bias", "swin_unet.concat_back_dim.3.weight", "swin_unet.concat_back_dim.3.bias", "swin_unet.norm.weight", "swin_unet.norm.bias", "swin_unet.norm_up.weight", "swin_unet.norm_up.bias", "swin_unet.up.expand.weight", "swin_unet.up.norm.weight", "swin_unet.up.norm.bias", "swin_unet.output.weight". 
	Unexpected key(s) in state_dict: "patch_embed.proj.weight", "patch_embed.proj.bias", "patch_embed.norm.weight", "patch_embed.norm.bias", "layers.0.blocks.0.norm1.weight", "layers.0.blocks.0.norm1.bias", "layers.0.blocks.0.attn.qkv.weight", "layers.0.blocks.0.attn.qkv.bias", "layers.0.blocks.0.attn.proj.weight", "layers.0.blocks.0.attn.proj.bias", "layers.0.blocks.0.norm2.weight", "layers.0.blocks.0.norm2.bias", "layers.0.blocks.0.mlp.fc1.weight", "layers.0.blocks.0.mlp.fc1.bias", "layers.0.blocks.0.mlp.fc2.weight", "layers.0.blocks.0.mlp.fc2.bias", "layers.0.blocks.1.norm1.weight", "layers.0.blocks.1.norm1.bias", "layers.0.blocks.1.attn.qkv.weight", "layers.0.blocks.1.attn.qkv.bias", "layers.0.blocks.1.attn.proj.weight", "layers.0.blocks.1.attn.proj.bias", "layers.0.blocks.1.norm2.weight", "layers.0.blocks.1.norm2.bias", "layers.0.blocks.1.mlp.fc1.weight", "layers.0.blocks.1.mlp.fc1.bias", "layers.0.blocks.1.mlp.fc2.weight", "layers.0.blocks.1.mlp.fc2.bias", "layers.0.downsample.norm.weight", "layers.0.downsample.norm.bias", "layers.1.blocks.0.norm1.weight", "layers.1.blocks.0.norm1.bias", "layers.1.blocks.0.attn.qkv.weight", "layers.1.blocks.0.attn.qkv.bias", "layers.1.blocks.0.attn.proj.weight", "layers.1.blocks.0.attn.proj.bias", "layers.1.blocks.0.norm2.weight", "layers.1.blocks.0.norm2.bias", "layers.1.blocks.0.mlp.fc1.weight", "layers.1.blocks.0.mlp.fc1.bias", "layers.1.blocks.0.mlp.fc2.weight", "layers.1.blocks.0.mlp.fc2.bias", "layers.1.blocks.1.norm1.weight", "layers.1.blocks.1.norm1.bias", "layers.1.blocks.1.attn.qkv.weight", "layers.1.blocks.1.attn.qkv.bias", "layers.1.blocks.1.attn.proj.weight", "layers.1.blocks.1.attn.proj.bias", "layers.1.blocks.1.norm2.weight", "layers.1.blocks.1.norm2.bias", "layers.1.blocks.1.mlp.fc1.weight", "layers.1.blocks.1.mlp.fc1.bias", "layers.1.blocks.1.mlp.fc2.weight", "layers.1.blocks.1.mlp.fc2.bias", "layers.1.downsample.norm.weight", "layers.1.downsample.norm.bias", "layers.2.blocks.0.norm1.weight", "layers.2.blocks.0.norm1.bias", "layers.2.blocks.0.attn.qkv.weight", "layers.2.blocks.0.attn.qkv.bias", "layers.2.blocks.0.attn.proj.weight", "layers.2.blocks.0.attn.proj.bias", "layers.2.blocks.0.norm2.weight", "layers.2.blocks.0.norm2.bias", "layers.2.blocks.0.mlp.fc1.weight", "layers.2.blocks.0.mlp.fc1.bias", "layers.2.blocks.0.mlp.fc2.weight", "layers.2.blocks.0.mlp.fc2.bias", "layers.2.blocks.1.norm1.weight", "layers.2.blocks.1.norm1.bias", "layers.2.blocks.1.attn.qkv.weight", "layers.2.blocks.1.attn.qkv.bias", "layers.2.blocks.1.attn.proj.weight", "layers.2.blocks.1.attn.proj.bias", "layers.2.blocks.1.norm2.weight", "layers.2.blocks.1.norm2.bias", "layers.2.blocks.1.mlp.fc1.weight", "layers.2.blocks.1.mlp.fc1.bias", "layers.2.blocks.1.mlp.fc2.weight", "layers.2.blocks.1.mlp.fc2.bias", "layers.2.blocks.2.norm1.weight", "layers.2.blocks.2.norm1.bias", "layers.2.blocks.2.attn.qkv.weight", "layers.2.blocks.2.attn.qkv.bias", "layers.2.blocks.2.attn.proj.weight", "layers.2.blocks.2.attn.proj.bias", "layers.2.blocks.2.norm2.weight", "layers.2.blocks.2.norm2.bias", "layers.2.blocks.2.mlp.fc1.weight", "layers.2.blocks.2.mlp.fc1.bias", "layers.2.blocks.2.mlp.fc2.weight", "layers.2.blocks.2.mlp.fc2.bias", "layers.2.blocks.3.norm1.weight", "layers.2.blocks.3.norm1.bias", "layers.2.blocks.3.attn.qkv.weight", "layers.2.blocks.3.attn.qkv.bias", "layers.2.blocks.3.attn.proj.weight", "layers.2.blocks.3.attn.proj.bias", "layers.2.blocks.3.norm2.weight", "layers.2.blocks.3.norm2.bias", "layers.2.blocks.3.mlp.fc1.weight", "layers.2.blocks.3.mlp.fc1.bias", 
"layers.2.blocks.3.mlp.fc2.weight", "layers.2.blocks.3.mlp.fc2.bias", "layers.2.blocks.4.norm1.weight", "layers.2.blocks.4.norm1.bias", "layers.2.blocks.4.attn.qkv.weight", "layers.2.blocks.4.attn.qkv.bias", "layers.2.blocks.4.attn.proj.weight", "layers.2.blocks.4.attn.proj.bias", "layers.2.blocks.4.norm2.weight", "layers.2.blocks.4.norm2.bias", "layers.2.blocks.4.mlp.fc1.weight", "layers.2.blocks.4.mlp.fc1.bias", "layers.2.blocks.4.mlp.fc2.weight", "layers.2.blocks.4.mlp.fc2.bias", "layers.2.blocks.5.norm1.weight", "layers.2.blocks.5.norm1.bias", "layers.2.blocks.5.attn.qkv.weight", "layers.2.blocks.5.attn.qkv.bias", "layers.2.blocks.5.attn.proj.weight", "layers.2.blocks.5.attn.proj.bias", "layers.2.blocks.5.norm2.weight", "layers.2.blocks.5.norm2.bias", "layers.2.blocks.5.mlp.fc1.weight", "layers.2.blocks.5.mlp.fc1.bias", "layers.2.blocks.5.mlp.fc2.weight", "layers.2.blocks.5.mlp.fc2.bias", "layers.2.downsample.norm.weight", "layers.2.downsample.norm.bias", "layers.3.blocks.0.norm1.weight", "layers.3.blocks.0.norm1.bias", "layers.3.blocks.0.attn.qkv.weight", "layers.3.blocks.0.attn.qkv.bias", "layers.3.blocks.0.attn.proj.weight", "layers.3.blocks.0.attn.proj.bias", "layers.3.blocks.0.norm2.weight", "layers.3.blocks.0.norm2.bias", "layers.3.blocks.0.mlp.fc1.weight", "layers.3.blocks.0.mlp.fc1.bias", "layers.3.blocks.0.mlp.fc2.weight", "layers.3.blocks.0.mlp.fc2.bias", "layers.3.blocks.1.norm1.weight", "layers.3.blocks.1.norm1.bias", "layers.3.blocks.1.attn.qkv.weight", "layers.3.blocks.1.attn.qkv.bias", "layers.3.blocks.1.attn.proj.weight", "layers.3.blocks.1.attn.proj.bias", "layers.3.blocks.1.norm2.weight", "layers.3.blocks.1.norm2.bias", "layers.3.blocks.1.mlp.fc1.weight", "layers.3.blocks.1.mlp.fc1.bias", "layers.3.blocks.1.mlp.fc2.weight", "layers.3.blocks.1.mlp.fc2.bias", "norm.weight", "norm.bias", "head.weight", "head.bias", "layers.0.blocks.0.attn.relative_position_index", "layers.0.blocks.1.attn.relative_position_index", "layers.1.blocks.0.attn.relative_position_index", "layers.1.blocks.1.attn.relative_position_index", "layers.2.blocks.0.attn.relative_position_index", "layers.2.blocks.1.attn.relative_position_index", "layers.2.blocks.2.attn.relative_position_index", "layers.2.blocks.3.attn.relative_position_index", "layers.2.blocks.4.attn.relative_position_index", "layers.2.blocks.5.attn.relative_position_index", "layers.3.blocks.0.attn.relative_position_index", "layers.3.blocks.1.attn.relative_position_index", "layers.0.blocks.1.attn_mask", "layers.1.blocks.1.attn_mask", "layers.2.blocks.1.attn_mask", "layers.2.blocks.3.attn_mask", "layers.2.blocks.5.attn_mask", "layers.0.blocks.0.attn.relative_position_bias_table", "layers.0.blocks.1.attn.relative_position_bias_table", "layers.1.blocks.0.attn.relative_position_bias_table", "layers.1.blocks.1.attn.relative_position_bias_table", "layers.2.blocks.0.attn.relative_position_bias_table", "layers.2.blocks.1.attn.relative_position_bias_table", "layers.2.blocks.2.attn.relative_position_bias_table", "layers.2.blocks.3.attn.relative_position_bias_table", "layers.2.blocks.4.attn.relative_position_bias_table", "layers.2.blocks.5.attn.relative_position_bias_table", "layers.3.blocks.0.attn.relative_position_bias_table", "layers.3.blocks.1.attn.relative_position_bias_table", "layers.0.downsample.reduction.weight", "layers.1.downsample.reduction.weight", "layers.2.downsample.reduction.weight". 

Where and how to put the labeled data?

Hi! I hope you're great. First of all thanks for your great job!
My question is the following:
I downloaded the Synapse dataset, but I don't know where and in which format I have to put the training data and the corresponding labels.
I have two folders: averaged-training-images containing files like DET0001101_avg.nii.gz and another one averaged-training-labels with the format DET0001101_avg_seg.nii.gz.
If I understand right, I'm using --root_path averaged-training-images, but where are the labels expected to be?
Thanks for any help!

Where is the source code?

When will you upload the source code? I would like to learn from it.

torch

The torch version is too high and raises an error. What is your CUDA version?

code

where is the code?

Great project! But I have some problems

Would you please tell me the best performance you obtained in training?
I have trained the model with the default parameters on the Synapse dataset, but the performance is not good, as shown below:

Testing performance in best val model: mean_dice : 0.760341 mean_hd95 : 29.199245

It is really different from the paper. What should I do?
Thank you!

How is Unet trained?

Hi there!
Thank you for your innovative work.
I want to ask how your UNet and Attention-UNet baselines were trained. I tried many times but could not achieve the same effect.

Looking forward to your reply.
Best wishes!

Advice on further experiments

Nice work!

The listed experiments show that the U-shaped Swin Transformer can capture global and local information. However, its ability to capture small targets has not yet been demonstrated; e.g., a further experiment on a lung nodule segmentation task using the LIDC dataset would be helpful.

Can you share the trained model?

Thanks for sharing your work. Could you also make the model you trained available? I would like to reproduce and cite your work in a paper.

Random split of the ACDC dataset

Hello, I have already downloaded the data. In the paper, the split of the ACDC dataset is done randomly. To compare results more fairly, would you mind sharing your dataset split and preprocessing? Training and test code would be even better. [email protected], many thanks.

Bug!!!

The default value of root_path is default='../data/Synapse/train_npz', but

if args.dataset == "Synapse":
    args.root_path = os.path.join(args.root_path, "train_npz")
config = get_config(args)

appends train_npz to the path again, producing a duplicated path segment.
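One possible guard, mirroring the snippet quoted above and sketched as an assumption rather than an official fix: only append the suffix when it is not already present.

import os

if args.dataset == "Synapse" and not args.root_path.rstrip("/\\").endswith("train_npz"):
    args.root_path = os.path.join(args.root_path, "train_npz")
config = get_config(args)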

Some required packages are missing from requirements.txt

model layers missing

The file swin_transformer_unet_skip_expand_decoder_sys.py contains

from timm.models.layers import DropPath, to_2tuple, trunc_normal_

However, the timm folder is not available.

License file is missing

Hello,

Your code looks very nice.

Is it possible to release your code under an open-source license?

Cheers,

ACDC dataset

The preprocessed ACDC dataset is not publicly available, but could you release the PyTorch dataset code for ACDC, similar to dataset_synapse.py?

config.py _C.MODEL.SWIN.QK_SCALE = None

Traceback (most recent call last):
  File "train.py", line 10, in <module>
    from config import get_config
  File "C:\Users\Downloads\Swin-Unet-main\config.py", line 72, in <module>
    _C.MODEL.SWIN.QK_SCALE = None
  File "C:\anaconda\envs\swin\lib\site-packages\yacs\config.py", line 158, in __setattr__
    type(value), name, _VALID_TYPES
  File "C:\anaconda\envs\swin\lib\site-packages\yacs\config.py", line 521, in _assert_with_logging
    assert cond, msg
AssertionError: Invalid type <class 'NoneType'> for key QK_SCALE; valid types = {<class 'float'>, <class 'tuple'>, <class 'str'>, <class 'list'>, <class 'bool'>, <class 'int'>}

Some questions on training own data

Hi Hu Cao:
Thanks for your great work; I am very interested in Swin-Unet. I used this git repo to train on my own data, but the predicted segmentation mask seems to be composed of image patches and the borders are blurry. Can you provide some suggestions? Thank you. My config is as follows:
img_size:224
depth: [2,2,2,2]
window_size:7
patch_size:4

Great job! But pretrained_ckpt load error

When I train with my own dataset, the following error occurs:

=> merge config from configs/swin_tiny_patch4_window7_224_lite.yaml
SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:9
---final upsample expand_first---
pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth
Traceback (most recent call last):
  File "train.py", line 96, in <module>
    net.load_from(config)
  File "/mnt/e/projects/Sementic_Segmentation/Swin-Unet-PyTorch/networks/vision_transformer.py", line 58, in load_from
    pretrained_dict = torch.load(pretrained_path, map_location=device)
  File "/home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/serialization.py", line 527, in load
    with _open_zipfile_reader(f) as opened_zipfile:
  File "/home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/serialization.py", line 224, in __init__
    super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /pytorch/caffe2/serialize/inline_container.cc:132, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 2. Your PyTorch installation may be too old. (init at /pytorch/caffe2/serialize/inline_container.cc:132)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x33 (0x7fd2bbfab193 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1f5b (0x7fd2bf1339eb in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::string const&) + 0x64 (0x7fd2bf134c04 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x6c6536 (0x7fd307246536 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x295a74 (0x7fd306e15a74 in /home/byronnar/anaconda3/envs/swin_unet_torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #33: __libc_start_main + 0xf3 (0x7fd319cf10b3 in /lib/x86_64-linux-gnu/libc.so.6)

What should I do? Thank you!

Test gets stuck

I'm trying to reproduce the experiment. The training stage looks good. However, the test stage has been stuck here for around half an hour. Is this normal? I also wonder how long the evaluation will take to finish. Thanks!

How to change to a larger model?

Hello, your code looks very nice. Is it possible to switch to a larger pre-trained model and the related configuration in your code, for example Swin-S/B/L?

About pre-training weights

Your model structure is not consistent with Swin's. How do you use its pre-trained weights? Or do you not use pre-trained weights at all?

Attention Mask

Hi! Great work! I was just wondering, what is the purpose of the attention mask that is added to the attention weights in the SW-MSA block? Thanks in advance.

Reading different sized images & also issues with upsampling

Hi, thanks for your great work. I have been trying to train the Swin-Unet; however, I'm running into a few issues:

  • If I specify an image size (e.g. 1080), the network can't take anything smaller than this. A solution would be to use bilinear interpolation to resize, but this would be very memory-inefficient and not good for training. For example, if I had a 200x200 image, I'd have to resize it for the network to process it. So is there a way to train on smaller images (e.g. 200x200) and then test on larger ones (e.g. 1080)?

  • Also, is there a way for the input image to not be square?

  • When training the network with an input of (4,200,200) and 4 classes, I keep getting (4,32,32) instead of (4,200,200) as the output of the U-Net, so the upsampling is not happening properly. I looked at the Swin-Unet code to try to figure out the problem, but to no avail.

Thanks

Question

Hello author, thanks for your code. When I run it on private data, an error like "RuntimeError: CUDA error: device-side assert triggered" occurs. What should I do to deal with this issue?

Is it convenient to provide a preprocessed data file?

Hello, I sent an email to the address in the README hoping to get a copy of the preprocessed data, but unfortunately I did not receive a reply. So I am asking again here: would it be convenient to provide the preprocessed data files? Thank you very much if you can reply to the email.

A few questions

Good job!

  1. pixelshuffle
    I have also made a similar Unet-like modification based on Swin, and my decoder used pixelshuffle at the time. I'd like to know how your patch expanding layer differs from pixelshuffle (see the sketch after this list)?
  2. Experimental results
    I see that most of your experimental results are consistent with the TransUNet paper. But their train/test split of the dataset was random. Is your random split the same as theirs? If not, how can the results be compared side by side? Wouldn't you have to re-run the other algorithms for comparison?
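Regarding question 1, here is a small sketch of my own (the tensor sizes are arbitrary assumptions): the patch expanding layer's linear-plus-rearrange and a linear layer followed by PixelShuffle compute the same thing up to a fixed permutation of the linear layer's output channels, so the two decoders differ mainly in channel ordering and in where the normalization sits.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

B, H, W, C = 2, 7, 7, 8
x = torch.randn(B, H, W, 2 * C)

linear = nn.Linear(2 * C, 4 * C, bias=False)
# Patch-expanding style: the 4C channels are read as (p1, p2, c)
y1 = rearrange(linear(x), 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C)

# PixelShuffle reads channels as (c, p1, p2), so permute the linear weight rows first
w_perm = rearrange(linear.weight, '(p1 p2 c) i -> (c p1 p2) i', p1=2, p2=2, c=C)
y2 = F.pixel_shuffle(F.linear(x, w_perm).permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)

print(torch.allclose(y1, y2, atol=1e-5))  # True: identical up to the channel permutation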

About training on one's own data

Hello, I am very interested in your team's work and have recently been trying to train Swin-Unet on my own data. However, the instructions in your README seem to borrow from TransUNet's work. On top of the current work, could you add an introduction or example code for preparing one's own data? Thank you very much.

please help me

=> merge config from configs/swin_tiny_patch4_window7_224_lite.yaml
SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.2;num_classes:9
---final upsample expand_first---
pretrained_path:./pretrained_ckpt/swin_tiny_patch4_window7_224.pth
---start load pretrained modle of swin encoder---
Namespace(accumulation_steps=None, amp_opt_level='O1', base_lr=0.01, batch_size=1, cache_mode='part', cfg='configs/swin_tiny_patch4_window7_224_lite.yaml', dataset='Synapse', deterministic=1, eval=False, img_size=12, list_dir='./lists/lists_Synapse', max_epochs=150, max_iterations=30000, n_gpu=1, num_classes=9, opts=None, output_dir='.\model_out\', resume=None, root_path='.\datasets\Synapse\train_npz', seed=1234, tag=None, throughput=False, use_checkpoint=False, zip=False)
The length of train set is: 2211
2211 iterations per epoch. 331650 max iterations
0%| | 0/150 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 99, in <module>
    trainer[dataset_name](args, net, args.output_dir)
  File "C:\Users\Downloads\Swin-Unet-main\trainer.py", line 53, in trainer_synapse
    for i_batch, sampled_batch in enumerate(trainloader):
  File "C:\anaconda\envs\swin\lib\site-packages\torch\utils\data\dataloader.py", line 291, in __iter__
    return _MultiProcessingDataLoaderIter(self)
  File "C:\anaconda\envs\swin\lib\site-packages\torch\utils\data\dataloader.py", line 737, in __init__
    w.start()
  File "C:\anaconda\envs\swin\lib\multiprocessing\process.py", line 112, in start
    self._popen = self._Popen(self)
  File "C:\anaconda\envs\swin\lib\multiprocessing\context.py", line 223, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\anaconda\envs\swin\lib\multiprocessing\context.py", line 322, in _Popen
    return Popen(process_obj)
  File "C:\anaconda\envs\swin\lib\multiprocessing\popen_spawn_win32.py", line 89, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\anaconda\envs\swin\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'trainer_synapse.<locals>.worker_init_fn'

Regarding a license to use the code

Hello, your team's work is polished and productive.
I am a student at ZJU, and I would like to take part in an academic bridge semantic segmentation competition based on your team's research results, so I am asking for permission.
At the same time, my goal is only to obtain semantic segmentation at the 640x360 level (i.e., 1/9 of the pixels) from 1920x1080 images. Do you have any suggestions for this kind of requirement?

Resume from saved weights

I am using Google Colab to train the model, but because of the time restriction I can only train up to epoch 99. Is there any way to resume from epoch 99?

Note:
I used --resume epoch_99.pth, but it isn't loading the previous weights. Am I using it wrong? Is there any other way to do it?
Thank you.
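Not a documented feature of this repo, but a minimal sketch under the assumption that the snapshots written by trainer.py are plain state_dicts: load the saved weights back into the freshly built model before starting training again, and reduce --max_epochs to the remaining number of epochs (note that the learning-rate schedule will restart from its initial value).

import torch

def resume_weights(net, snapshot_path='epoch_99.pth'):
    # Assumption: the snapshot is a bare state_dict of the SwinUnet model
    # built in train.py, so it can be loaded directly before training resumes.
    net.load_state_dict(torch.load(snapshot_path, map_location='cpu'))
    return net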

patch expanding layer

Could you please provide the code for the resolution-expanding (patch expanding) part? I would like to draw on it, since your paper shows it works better than transposed convolution.

Patch expanding cannot reverse patch merging

The PatchMerging, I believe, comes from the Swin code, as below:

x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C

And the PatchExpand is below:

x = x.view(B, H, W, C)
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)

I think that if I pass a tensor through PatchMerging and then PatchExpand, I cannot get the original tensor back.
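A small check of this point (my own sketch, with the linear layers omitted so that only the index bookkeeping is compared): PatchMerging concatenates the four offsets in the order (0,0), (1,0), (0,1), (1,1), while PatchExpand's rearrange reads the channel blocks back in the order (0,0), (0,1), (1,0), (1,1), so even before any learned projection the round trip swaps the two off-diagonal offsets.

import torch
from einops import rearrange

B, H, W, C = 1, 4, 4, 1
x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)

# PatchMerging-style slicing and concat: channel blocks (0,0), (1,0), (0,1), (1,1)
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], -1)              # B, H/2, W/2, 4C

# PatchExpand-style rearrange: channel blocks read as (0,0), (0,1), (1,0), (1,1)
restored = rearrange(merged, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=C)

print(torch.equal(x, restored))                       # False: off-diagonal offsets swapped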
