
sam-med3d's People

Contributors

blueyo0, k-chrispens


sam-med3d's Issues

requirements?

Hi, thanks for sharing this work!
Could you provide a requirements.txt file for building the environment?

Request for validation set list

I hope you're doing well. I recently read your paper and found it to be very informative and valuable.

I noticed that the validation set list was not provided in the paper or the supplementary materials. Would it be possible for you to share the list of samples used in the validation set? Knowing this would help me conduct a comparative experiment.

Thank you for your time.

Fix for dice always being 0, and a newly found bug

I had been troubled by the bug where dice is always 0, so I decided to fix it. While fixing it I also noticed that, when printing the average loss of each epoch, the step count is one too small, so the loss printed for every epoch is slightly too high. I am not sure whether the bug I found is real, so I would like the author's team to take a look.
I am not yet very good with git, so I have pasted the modified code below.

    def train_epoch(self, epoch, num_clicks):
        epoch_loss = 0
        epoch_iou = 0
        epoch_dice = 0
        self.model.train()
        if self.args.multi_gpu:
            sam_model = self.model.module
        else:
            sam_model = self.model
            self.args.rank = -1
        
        if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
            tbar = tqdm(self.dataloaders)
        else:
            tbar = self.dataloaders

        self.optimizer.zero_grad()
        step_loss = 0
        epoch_dice = 0      #change
        for step, (image3D, gt3D) in enumerate(tbar):

            my_context = self.model.no_sync if self.args.rank != -1 and step % self.args.accumulation_steps != 0 else nullcontext

            with my_context():

                image3D = self.norm_transform(image3D.squeeze(dim=1)) # (N, C, W, H, D)
                image3D = image3D.unsqueeze(dim=1)
                
                image3D = image3D.to(device)
                gt3D = gt3D.to(device).type(torch.long)
                with amp.autocast():
                    image_embedding = sam_model.image_encoder(image3D)

                    self.click_points = []
                    self.click_labels = []

                    pred_list = []

                    prev_masks, loss = self.interaction(sam_model, image_embedding, gt3D, num_clicks=11)                

                epoch_loss += loss.item()

                epoch_dice += self.get_dice_score(prev_masks, gt3D)      # change: accumulate dice every step

                cur_loss = loss.item()

                loss /= self.args.accumulation_steps
                
                self.scaler.scale(loss).backward()    

            if step % self.args.accumulation_steps == 0 and step != 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                print_loss = step_loss / self.args.accumulation_steps
                step_loss = 0
                print_dice = self.get_dice_score(prev_masks, gt3D)
                
            else:
                step_loss += cur_loss

            if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
                if step % self.args.accumulation_steps == 0 and step != 0:
                    print(f'Epoch: {epoch}, Step: {step}, Loss: {print_loss}, Dice: {print_dice}')
                    if print_dice > self.step_best_dice:
                        self.step_best_dice = print_dice
                        if print_dice > 0.9:
                            self.save_checkpoint(
                                epoch,
                                sam_model.state_dict(),
                                describe=f'{epoch}_step_dice:{print_dice}_best'
                            )
                    if print_loss < self.step_best_loss:
                        self.step_best_loss = print_loss
            
        epoch_loss /= step+1        # change: average over the number of steps (step+1), not step
        epoch_dice /= step+1        # change: same fix for the dice accumulator

        return epoch_loss, epoch_iou, epoch_dice, pred_list
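
To illustrate the off-by-one the poster describes (a minimal, self-contained sketch, not code from the repo): enumerate() counts steps from 0, so after the loop the variable step equals N-1 while N batches contributed to the accumulated loss, which is why dividing by step + 1 gives the correct epoch average.

    losses = [0.5, 0.4, 0.3]            # hypothetical per-step losses for one epoch
    epoch_loss = sum(losses)
    for step, _ in enumerate(losses):
        pass                            # training work would happen here
    print(epoch_loss / step)            # 0.6 -> too high, divides by N-1 = 2
    print(epoch_loss / (step + 1))      # 0.4 -> correct mean over N = 3 steps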

Problem with data preparation, could the author please take a look?

Traceback (most recent call last):
File "C:\Users\pl\Desktop\FILE\SAM-Med3D-main\utils\prepare_data_from_nnUNet.py", line 70, in
for idx, cls_name in meta_info[""].items():
KeyError: ''
HepaticVessel {'0': 'CT'}
num_classes: 2 {'0': 'background', '1': 'Vessel', '2': 'Tumour'}
I tried several datasets and none of them work; there is no empty-string key in dataset.json either, which is strange.
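
As a debugging sketch (not part of the repo; the key names are assumptions), inspecting which keys dataset.json actually contains before indexing meta_info may help locate why an empty-string key ends up being looked up:

    import json

    # Hypothetical debugging snippet: print the keys actually present in
    # dataset.json instead of indexing meta_info with a hard-coded key.
    with open("dataset.json") as f:
        meta_info = json.load(f)
    print(meta_info.keys())   # e.g. dict_keys(['name', 'modality', 'labels', ...])

    # nnU-Net v1-style class maps usually live under 'labels'; guard against a
    # missing or renamed key instead of letting a KeyError propagate.
    for idx, cls_name in meta_info.get("labels", {}).items():
        print(idx, cls_name)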

BraTS2021 dataset preprocessing

First, thank you very much for sharing this amazing project and source code. It is really helpful for my study!

Sorry, but I have a question about BraTS2021 dataset preprocessing.
Can you share the details of how BraTS_Val* was prepared?

When I run prepare_data_from_nnUNet.py on my BraTS2021 dataset, every 'background'-class case (whose label volume is all zeros) is skipped and therefore not saved in the dataset folder.

So I removed this code in prepare_data_from_nnUNet.py:

    if(volume<10):
        print("skip", target_img_path)
        continue
and ran train.py. However, it has some problems (the loss suddenly becomes NaN, although the Dice score is not 0).

(I checked that gt3D consists of all-zero values; I'm not sure, but I guess this is caused by the 'background' data.)

So I want to know how to construct the BraTS 2021 dataset.

Thank you again, and I hope to get an answer!
(If this is just my own mistake, sorry in advance for asking a bad question.)
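
For what it's worth, a minimal sketch (my own, assuming NIfTI labels read with nibabel) of how one might filter out cases whose ground truth is empty before training, mirroring the volume < 10 check that was removed above, so that prompt sampling and the loss always see foreground voxels:

    import numpy as np
    import nibabel as nib

    # Keep only cases whose ground-truth mask has some foreground voxels.
    def has_foreground(label_path, min_voxels=10):
        gt = np.asarray(nib.load(label_path).dataobj)
        return (gt > 0).sum() >= min_voxels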

Multi-GPU is not working

Hello,

I tried to use multiple GPUs by setting --multi_gpu; however, according to the output of the nvidia-smi command, only GPU 0 was used during training.

CUDA out of memory

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.68 GiB (GPU 1; 31.75 GiB total capacity; 27.51 GiB already allocated; 1.65 GiB free; 28.90 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

How can I solve this problem if I only have a V100 GPU? Thanks!
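
One mitigation the error message itself points at is capping the allocator's split size before any CUDA memory is allocated (sketched below; the environment variable is standard PyTorch, while the exact place to set it depends on the training script):

    import os

    # Must be set before the first CUDA allocation (e.g. at the very top of
    # train.py, or exported in the shell). 128 MB is an illustrative value.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

Beyond that, the usual ways to reduce peak memory on a 32 GB V100 are a smaller per-GPU batch size or more gradient-accumulation steps.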

Regarding the data processing details of the infer_transform operation in validation.py.

Hi,
In the validation.py file, I noticed that the data processing includes the infer_transform operation, which sets mask_name='label'. I would like to ask how an image without a mask label is handled and input into the model for prediction, and whether this handling affects the performance of the segmentation results.
I'm not sure if my understanding of this part of the code is correct. I would appreciate your response and clarification. Thank you.

SAM-Med 2D Inference

I'm trying to use the given code for 3D inference with SAM-Med 2D, but I'm having trouble. Could it be that there's an architecture mismatch?

I'm running

python validation.py --seed 2023\
	-vp ./results/vis_sam_med2d \
	-cp ./ckpt/sam_med2d.pth \
	-tdp ./data/validation_test1 -nc 10 \
	--image_size 256 -mt vit_b --dim 2 --save_name ./results/sam_med2d.py --ft2d

as in the infer_med2d.sh file, just with:

  • the ckpt folder having been created by me, containing sam-med2d_b.pth as downloaded from the link given in the readme,
  • -cp changed to ./ckpt/sam-med2d_b.pth to reflect the name from the download above,
  • -tdp changed to ./data/validation to reflect the directory name in the current GitHub repo.

These changes result in the command
python validation.py --seed 2023 \
	-vp ./results/vis_sam_med2d \
	-cp ./ckpt/sam-med2d_b.pth \
	-tdp ./data/validation -nc 10 \
	--image_size 256 -mt vit_b --dim 2 --save_name ./results/sam_med2d.py --ft2d

Then, running, I get
set seed as 2023
get 1 datasets
device: cuda
*******interpolate
Traceback (most recent call last):
File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 111, in _build_sam
sam.load_state_dict(state_dict['model'])
File "/home/[user]/anaconda3/envs/sam-med/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Sam:
Unexpected key(s) in state_dict: "image_encoder.blocks.0.Adapter.norm.weight", [... lots of layers ...] "image_encoder.blocks.11.Adapter.spatial.2.weight".

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "[my path]//SAM-Med3D/validation.py", line 327, in
sam_model_tune = sam_model_registry[args.model_type](...).to(device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 38, in build_sam_vit_b
return _build_sam(
^^^^^^^^^^^
File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 116, in _build_sam
new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 128, in load_from
pos_embed = new_state_dict['image_encoder.pos_embed']
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'image_encoder.pos_embed'

with text in brackets being my modifications to the message. The error arises in the try/except block in the _build_sam function of build_sam.py:

    with open(checkpoint, "rb") as f:
        state_dict = torch.load(f)
    try:
        if 'model' in state_dict.keys():
            sam.load_state_dict(state_dict['model'])
        else:
            sam.load_state_dict(state_dict)
    except:
        print('*******interpolate')
        new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size)
        sam.load_state_dict(new_state_dict)
    print(f"*******load {checkpoint}")

So the loading attempt in the try section didn't work, and the except branch couldn't fix it. As I understand it, the weights from sam-med2d_b.pth are being loaded into the model sam defined earlier in the _build_sam function. But that model must have the base SAM architecture, so it cannot include the adapter layers of SAM-Med2D, which I expect inevitably leads to the architecture mismatch I see in the terminal output. Furthermore, the model-type flag must be the name of an original SAM model, since it is looked up in the sam_model_registry dictionary. Am I understanding the error correctly? If so, could you possibly share the code you used to test SAM-Med2D on the 3D data?

Why am I getting an empty-slice warning?

set seed as 2023
get 2 datasets
device: cuda
0it [00:00, ?it/s]
Save to union_out_dice.py
/usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py:3432: RuntimeWarning: Mean of empty slice.
return _methods._mean(a, axis=axis, dtype=dtype,
/usr/local/lib/python3.10/dist-packages/numpy/core/_methods.py:190: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
Done

blurry segmentation and low res masks

Hi, very exciting work! I tried to use this model on my 3D microscopy image dataset, with shape (128,128,128). However, when I visualized the results I saw that the predictions were often very blurry. I realized that in the code you downsample the mask from (128, 128, 128) to (32, 32, 32) and then upsample the predictions via interpolation. Am I missing something, or is this essentially processing the image at resolution (32, 32, 32)? And if so, are there ways I can keep the prediction at the resolution of my input (128, 128, 128)? Thank you!
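
If it helps to see the behaviour being described, here is a minimal sketch (shapes are illustrative, not the repo's actual code) of low-resolution mask logits being upsampled back to the input size with trilinear interpolation, which is what produces the blurred boundaries:

    import torch
    import torch.nn.functional as F

    low_res_logits = torch.randn(1, 1, 32, 32, 32)          # decoder output at 32^3
    full_res = F.interpolate(low_res_logits, size=(128, 128, 128),
                             mode="trilinear", align_corners=False)
    masks = (full_res > 0).float()                           # binarise after upsampling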

For reference, these are some of the images I had, where left is the input and right is the prediction.

About Datasets

Have you used the AISD dataset and the ISLES 2022 dataset for training? In other words, which datasets did you use for ischemic stroke lesions? Thank you very much for a quick reply.

Restore to original size

Hi again,

Here's a question: you resampled the images and masks to conduct the experiments.
I've trained the model with the 3D pipeline, but I want to resample the predictions back to the original size and evaluate them at that resolution.
Currently I'm using SimpleITK's resample function, keeping the same origin and direction, but the resampled segmentation mask becomes all zeros, possibly due to dilution. Is there a better way to do this?
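
A minimal sketch of one way to do this with SimpleITK (my own suggestion, not the authors' pipeline): resample the predicted mask onto the original image grid using nearest-neighbour interpolation, which avoids binary masks being "diluted" to all zeros by linear interpolation.

    import SimpleITK as sitk

    def resample_mask_to_reference(mask: sitk.Image, reference: sitk.Image) -> sitk.Image:
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(reference)               # copies size, spacing, origin, direction
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)  # preserves label values
        resampler.SetDefaultPixelValue(0)
        return resampler.Execute(mask)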

window_unpartition3D output shape goes wrong when window_size is a tuple

I have edited the code locally so that window_size is a tuple instead of an int. Since the (padded) input has shape (Dp, Hp, Wp), the output should have the same shape. The following line

    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Hp, Wp, Dp, -1)

should be rewritten to

    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Dp, Hp, Wp, -1)
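
A quick shape check of the proposed change (a standalone sketch with hypothetical sizes): after merging the window grid back, the tensor should have layout (B, Dp, Hp, Wp, C), matching the padded input.

    import torch

    B, C = 2, 8
    Dp, Hp, Wp = 16, 32, 32
    wd, wh, ww = 8, 8, 8                                   # tuple window_size
    x = torch.randn(B, Dp // wd, Hp // wh, Wp // ww, wd, wh, ww, C)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Dp, Hp, Wp, -1)
    print(x.shape)                                         # torch.Size([2, 16, 32, 32, 8])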

Passing Multiple Points

Hi there!

I tried implementing inference using multiple point prompts, but passing them all at once instead of iteratively passing one point and the previous mask. However, the performance degraded with an increasing number of points. I was wondering whether the model actually supports this behaviour? I assumed so only because the shape of the points input is [1,1,3], where I took the first dimension to be the batch size and the second dimension to be the number of points; is that perhaps wrong? Thanks!
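
For reference, this is the layout assumed when packing several points at once (a sketch with made-up coordinates; the dimension meanings are one reading of a SAM-style prompt encoder, not confirmed by the authors): coordinates of shape (batch, num_points, 3) with matching labels of shape (batch, num_points).

    import torch

    single_point_coords = torch.tensor([[[64, 64, 64]]])               # (1, 1, 3)
    single_point_labels = torch.tensor([[1]])                          # (1, 1)

    multi_point_coords = torch.tensor([[[64, 64, 64],
                                        [40, 80, 60],
                                        [90, 50, 70]]])                # (1, 3, 3)
    multi_point_labels = torch.tensor([[1, 1, 0]])                     # (1, 3): 1 = foreground, 0 = background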

Is SAM-Med3D still being updated?

Thanks for your great work! Will the image encoder of SAM-Med3D be replaced by ViT-Huge or ViT-Large for training? Will more medical data be added to the training?

Handling Images without Tumors for Model Training and Testing

Hi,

Wonderful work! But while running your project, a question came up: my data includes some images that do not contain tumors, so their masks are entirely zero. For these images, is the code unable to generate prompt points during training and testing?

Poor results obtained for evaluation on Medical Segmentation Decathlon Task07_Pancreas dataset

Hello,

Thank you for your interesting work.
I'm following the instructions you provided for evaluating your checkpoint on the Medical Segmentation Decathlon Task07_Pancreas dataset.
However, the mean Dice scores I obtained for pancreas and tumor segmentation are 35.8 and 35.6, respectively, which is quite low.
Are these results within the expected range? Do you have any suggestions on how to use your provided checkpoint and code to get better results?

Thank you.

Inquiry Regarding Usage of SAM-Med3D Pretrained 3D Image Encoder for CT/MRI Datasets

Dear Author,

I hope this message finds you well. I am writing to you as a researcher with a keen interest in your groundbreaking work on the SAM-Med3D model. Your research has garnered significant attention in our community, and we are particularly impressed with the capabilities of your pretrained 3D image encoder.

My team and I are currently working on a project that involves analyzing CT and MRI datasets. We believe that the SAM-Med3D model could greatly enhance our research, specifically in terms of feature extraction from these 3D medical images. However, we are uncertain about the best approach to utilize your pretrained model for our specific datasets.

Could you kindly provide some guidance or instructions on how to effectively use the SAM-Med3D pretrained 3D image encoder with our own CT/MRI datasets? Any insights into preprocessing steps, model integration, or optimal usage practices would be immensely helpful.

We are eager to explore the potential of SAM-Med3D in our research and look forward to potentially contributing to the field with the insights gained from its application.

Thank you very much for your time and consideration. We understand the demands on your schedule and appreciate any assistance you can provide.

[Feature Release] inference for different resolutions

Hi, all

Inference on full-volume image is now supported in infer.sh and inference.py.

It means that SAM-Med3D now supports different resolutions (e.g. 512 x 512 x 200, 256 x 256 x 500).

First, we highly recommend resampling the image to 1.5 mm spacing for the best segmentation experience.
For targets larger than 128 x 128 x 128, you can enable sliding-window inference in infer.sh.

We hope this new feature is helpful to you.
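
As one way to follow that recommendation (a sketch using SimpleITK, not code shipped in the repo), an image can be resampled to isotropic 1.5 mm spacing like this; label maps should use nearest-neighbour interpolation instead of linear:

    import SimpleITK as sitk

    def resample_to_spacing(img: sitk.Image, spacing=(1.5, 1.5, 1.5),
                            interp=sitk.sitkLinear) -> sitk.Image:
        orig_size, orig_spacing = img.GetSize(), img.GetSpacing()
        new_size = [int(round(sz * osp / nsp))
                    for sz, osp, nsp in zip(orig_size, orig_spacing, spacing)]
        return sitk.Resample(img, new_size, sitk.Transform(), interp,
                             img.GetOrigin(), spacing, img.GetDirection(),
                             0, img.GetPixelID())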

How to add prompts

Hello, following your steps I have been able to fine-tune the pre-trained model with my own data. But I have a question: if I want to add prompts, how should I do that? According to your README, there are no instructions about training with added prompts.

Fine-tuning the pre-trained ViT encoder

How did you conduct the transfer experiments on UNETR? Did you replace all the encoders (encoder1-encoder4)? When fine-tuning the pre-trained ViT encoder of SAM-Med3D, is the size of the input image still 128 x 128 x 128?
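
A hypothetical sketch of the weight-transfer step being asked about (not the authors' code; the checkpoint path, key names, and the existence of a compatible target encoder are assumptions):

    import torch
    import torch.nn as nn

    def load_pretrained_encoder(target_encoder: nn.Module, ckpt_path: str) -> nn.Module:
        ckpt = torch.load(ckpt_path, map_location="cpu")
        state = ckpt.get("model_state_dict", ckpt)          # tolerate either checkpoint layout
        encoder_state = {k[len("image_encoder."):]: v
                         for k, v in state.items()
                         if k.startswith("image_encoder.")}
        # target_encoder must share the ViT architecture and patch size
        target_encoder.load_state_dict(encoder_state, strict=False)
        return target_encoder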

Which datasets were used for the pre-trained models?

Dear author:
Could you tell me which datasets were used for each of the four pre-trained models provided in the repository? Especially for the last three: I would like to run five-fold experiments on some public datasets to check the performance, but I am worried that doing so may involve data leakage and make the results unreliable, so I would appreciate your reply.

Best wishes!

Loss value does not decrease

When I am training, the loss value fluctuates within a range and does not decrease even after many epochs. The following warning was also printed during training:
UserWarning: Detected call of lr_scheduler.step() before optimizer.step(). In PyTorch 1.1.0 and later, you should call them in the opposite order: optimizer.step() before lr_scheduler.step(). Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

Fix: a changed checkpoint path is not loaded

I found that after changing the checkpoint path, the specified .pth file is not loaded. In the source code,

self.init_checkpoint(join(self.args.work_dir, self.args.task_name, 'sam_model_latest.pth'))

loads the latest .pth from the work directory instead, so I made the following change:

self.init_checkpoint(self.args.checkpoint)

How to prepare SAM-Med3D label data?

How do you provide labeled 3D data (3D points)? Could you please provide more examples? When I opened the example on the GitHub page, the labeled data appeared black. Also, how can I join the discussion group? The QR code in the README has expired.

Error when trying to train on a liver tumor dataset

When I tried to train on the liver tumor data, I added it to the data path.

But when I run python train.py, the error below is displayed. Why does this happen? The NIfTI file in the reported path does not exist, and why is nii.gz appended to the end of the file name?

RuntimeError: Exception thrown in SimpleITK ImageFileReader_Execute: D:\a\1\sitk\Code\IO\src\sitkImageReaderBase.cxx:97:
sitk::ERROR: The file "data/train/Liver_Tumor\imagesTr\segmentation-3.nii.zip.nii.gz" does not exist.

Model testing scheme

Thanks to the author for sharing. I have a question about image testing:
Now the input of the image is (128,128,128). In the actual test process of the image, for example, to test the lung CT image of (512,512,368) and segment the lung parenchyma part, I use two schemes:

Scheme 1: take the first positive sample point as the center, extract a (128,128,128) block for testing, and finally backfill the result. With this scheme the predicted mask is reasonable, but because the lung parenchyma covers a large area, obvious block edges appear after backfilling, as shown in the attached figure.

Scheme 2: resize the image to (128,128,128); unfortunately, the segmentation quality is poor after resizing.

I would like to ask two questions:
First, why does the segmentation quality become much worse after resizing the image?
Second, in the situation above, which scheme would you suggest in practice to get the best out of the algorithm?
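
For context, a self-contained sketch of scheme 1 (hypothetical helper names; assumes every side of the volume is at least 128 voxels): crop a (128,128,128) block centred on the prompt point, run the model on it, and write the result back into a full-size canvas. Overlapping patches with blending, or the sliding-window inference mentioned elsewhere in these issues, would reduce the visible block edges.

    import numpy as np

    def crop_and_backfill(volume, centre, predict_fn, patch=128):
        out = np.zeros(volume.shape, dtype=np.uint8)
        half = patch // 2
        lo = [max(0, c - half) for c in centre]
        hi = [min(s, l + patch) for s, l in zip(volume.shape, lo)]
        lo = [h - patch for h in hi]                     # keep the crop fully inside the volume
        sl = tuple(slice(l, h) for l, h in zip(lo, hi))
        out[sl] = predict_fn(volume[sl])                 # model sees only the 128^3 crop
        return out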

[Hint] Inference without mask

Hi 👋, many people have asked how to perform inference on unlabeled images (without mask), including how to generate prompts and more. We're here to provide a unified response 📢:

This repo provides experimental code, so for now, it only offers validation functionality (testing on datasets with annotations) 🧪.

However, we do plan to provide inference code for unlabeled images and a 3D version demo 🌐. But due to the workload and other time conflicts, the official release will have to wait a bit longer ⏳. Thank you for your patience and understanding! 🙏

about the image size

Hi, thanks again!

I'd like to clarify the image_size argument of the dataset. If I'm using a 512x512xN 3D image and set image_size to 128, will it resize the image to 128x128xN or just crop a 128x128 region from the image?

Why am I getting a zero-division error?

Traceback (most recent call last):
File "/content/SAM-Med3D/train.py", line 507, in
main()
File "/content/SAM-Med3D/train.py", line 466, in main
trainer.train()
File "/content/SAM-Med3D/train.py", line 361, in train
epoch_loss, epoch_iou, epoch_dice, pred_list = self.train_epoch(epoch, num_clicks)
File "/content/SAM-Med3D/train.py", line 335, in train_epoch
epoch_loss /= step
ZeroDivisionError: float division by zero
Could you tell me why this happens?
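
A minimal reproduction of what triggers this (my own sketch): with exactly one batch, enumerate() leaves step == 0 after the loop, so epoch_loss /= step divides by zero; dividing by step + 1 (the number of batches) avoids both this and the inflated epoch loss discussed in the dice=0 issue above.

    losses = [0.7]                      # a dataloader that yields a single batch
    epoch_loss = sum(losses)
    for step, _ in enumerate(losses):
        pass
    epoch_loss /= (step + 1)            # 0.7, the correct mean; dividing by step alone would be 1/0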

patch-based inference approach for 3D volumes

Hi,

Thanks for sharing your work with the community. Excellent work!

I have successfully trained SAM-Med3D on the brain tumor dataset, and now I want to evaluate the test set. In the paper, it is mentioned that SAM-Med3D operates using a patch-based inference approach. However, in the repository I could not find the script or a suitable function for 3D patch-based inference. I would appreciate it if you could share the script for this or let me know if I missed something here. :)

Kind regards,
Himashi

[Hint] Data Pre-processing Details

Dear team,
I hope this message finds you well. I recently read the SAM-Med3D paper and came across an intriguing detail in Section 3.2. It was mentioned that cases with a physical size below 1 square cm or with any single dimension shorter than 1.5 cm were removed to enhance the visibility of target masks. I found this approach quite interesting and was wondering if there are any plans to release the code for processing the data?
Thank you for your time and consideration.
Best regards,
You Li
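
A rough reading of the rule described above as code (an illustrative sketch, not the official preprocessing; how "physical size" maps to area versus volume and the spacing convention are assumptions):

    import numpy as np

    def keep_target(mask: np.ndarray, spacing_mm=(1.5, 1.5, 1.5)) -> bool:
        coords = np.argwhere(mask > 0)
        if coords.size == 0:
            return False
        # bounding-box extent of the target in millimetres, per axis
        extent_mm = (coords.max(0) - coords.min(0) + 1) * np.asarray(spacing_mm)
        # largest in-plane foreground area across slices, in mm^2
        slice_area_mm2 = (mask > 0).sum(axis=(1, 2)) * spacing_mm[1] * spacing_mm[2]
        return slice_area_mm2.max() >= 100.0 and extent_mm.min() >= 15.0   # 1 cm^2, 1.5 cm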

Prompt

Hello, I learned about this project from the paper, and it is truly impressive.
However, I would like to ask how to implement the prompt functionality described in the paper. I found that the current project generates prompts from random points and keeps refining the result through its own iterations.
I also read issue #31, but did not find a solution there. Does the author have a tutorial on this, or is there somewhere I can discuss the prompt module with others?
Thanks!

UnboundLocalError: local variable 'step' referenced before assignment

File "D:\project\SAM-Med3D-main\SAM-Med3D-main\train.py", line 338, in train_epoch error message is displayed
epoch_loss /= step
UnboundLocalError: local variable 'step' referenced before assignment
Step is defined in the for loop, but epoch_loss/=step is on the outside of the for loop
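
A minimal reproduction (my own sketch): if the dataloader yields no batches at all, the loop body never runs, the loop variable is never assigned, and the division after the loop fails; the usual root cause is an empty or misconfigured dataset path.

    def train_epoch(dataloader):
        epoch_loss = 0.0
        for step, batch in enumerate(dataloader):
            epoch_loss += 0.0            # stand-in for the real loss
        return epoch_loss / step         # 'step' does not exist if the loop never ran

    train_epoch([])                      # UnboundLocalError: local variable 'step' referenced before assignment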
