uni-medical / sam-med3d
SAM-Med3D: An Efficient General-purpose Promptable Segmentation Model for 3D Volumetric Medical Image
License: Apache License 2.0
I'm working on a project where I don't have ground truth during the inference phase. How can I make predictions when ground truth (GT) is unavailable during inference?
Thanks for your help!
After training, how do I run predictions? Is validation.py the prediction script?
Hi, thanks for sharing this work!
Could you provide a requirements.txt file for setting up the environment?
I hope you're doing well. I recently read your paper and found it to be very informative and valuable.
I noticed that the validation set list was not provided in the paper or the supplementary materials. Would it be possible for you to share the list of samples used in the validation set? This information would help me conduct a comparative experiment.
Thank you for your time.
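I had long been troubled by the bug where Dice is always 0, so I decided to fix it. While fixing it, I also found that when the average loss of each epoch is printed, the step count is one too small, so the printed loss for every epoch is too high. I'm not sure whether the bug I found is real; could the authors' team take a look?
I'm not very good with git yet, so I've pasted the modified code below.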
# Method of the trainer class in train.py (indentation restored). It assumes the
# module-level names train.py provides: torch, tqdm, nullcontext, amp, device.
def train_epoch(self, epoch, num_clicks):
    epoch_loss = 0
    epoch_iou = 0
    epoch_dice = 0
    self.model.train()
    if self.args.multi_gpu:
        sam_model = self.model.module
    else:
        sam_model = self.model
        self.args.rank = -1
    # Only the main process (or the single process) shows a progress bar
    if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
        tbar = tqdm(self.dataloaders)
    else:
        tbar = self.dataloaders
    self.optimizer.zero_grad()
    step_loss = 0
    epoch_dice = 0  # change
    for step, (image3D, gt3D) in enumerate(tbar):
        # Under DDP, skip gradient synchronization on steps that only accumulate gradients
        my_context = self.model.no_sync if self.args.rank != -1 and step % self.args.accumulation_steps != 0 else nullcontext
        with my_context():
            image3D = self.norm_transform(image3D.squeeze(dim=1))  # (N, C, W, H, D)
            image3D = image3D.unsqueeze(dim=1)
            image3D = image3D.to(device)
            gt3D = gt3D.to(device).type(torch.long)
            with amp.autocast():
                image_embedding = sam_model.image_encoder(image3D)
                self.click_points = []
                self.click_labels = []
                pred_list = []
                prev_masks, loss = self.interaction(sam_model, image_embedding, gt3D, num_clicks=11)
            epoch_loss += loss.item()
            epoch_dice += self.get_dice_score(prev_masks, gt3D)  # change: accumulate Dice on every step
            cur_loss = loss.item()
            loss /= self.args.accumulation_steps
            self.scaler.scale(loss).backward()
        # Optimizer update once every accumulation_steps steps
        if step % self.args.accumulation_steps == 0 and step != 0:
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
            print_loss = step_loss / self.args.accumulation_steps
            step_loss = 0
            print_dice = self.get_dice_score(prev_masks, gt3D)
        else:
            step_loss += cur_loss
        if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
            if step % self.args.accumulation_steps == 0 and step != 0:
                print(f'Epoch: {epoch}, Step: {step}, Loss: {print_loss}, Dice: {print_dice}')
                if print_dice > self.step_best_dice:
                    self.step_best_dice = print_dice
                    if print_dice > 0.9:
                        self.save_checkpoint(
                            epoch,
                            sam_model.state_dict(),
                            describe=f'{epoch}_step_dice:{print_dice}_best'
                        )
                if print_loss < self.step_best_loss:
                    self.step_best_loss = print_loss
    epoch_loss /= step + 1  # change: enumerate starts at 0, so the batch count is step + 1
    epoch_dice /= step + 1  # change
    return epoch_loss, epoch_iou, epoch_dice, pred_list
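Why `step + 1` is the right divisor: `enumerate` starts counting at 0, so after a loop over n batches the loop variable holds n - 1. A minimal, framework-free illustration with made-up loss values standing in for `loss.item()`:

```python
# Hypothetical per-batch losses standing in for loss.item() values
batch_losses = [1.0, 2.0, 3.0, 4.0]

epoch_loss = 0.0
for step, loss in enumerate(batch_losses):
    epoch_loss += loss

# enumerate starts at 0, so after 4 batches the loop variable is 3
assert step == len(batch_losses) - 1

mean_with_step = epoch_loss / step                # divides by 3: overestimates the mean
mean_with_step_plus_1 = epoch_loss / (step + 1)   # divides by the real batch count, 4
print(mean_with_step, mean_with_step_plus_1)
```

With a single batch, `step` stays 0, so the unpatched `epoch_loss /= step` raises ZeroDivisionError, whereas `step + 1` divides by 1.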
Traceback (most recent call last):
  File "C:\Users\pl\Desktop\FILE\SAM-Med3D-main\utils\prepare_data_from_nnUNet.py", line 70, in <module>
    for idx, cls_name in meta_info[""].items():
KeyError: ''
HepaticVessel {'0': 'CT'}
num_classes: 2 {'0': 'background', '1': 'Vessel', '2': 'Tumour'}
I've tried several datasets and none of them work. There is no empty string in dataset.json either, which is strange.
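The printed name and modality show that dataset.json itself is parsed; the failure is the lookup of an empty-string key. In Medical Segmentation Decathlon style dataset.json files (which the printed HepaticVessel entries match), class names are stored under "labels". A minimal sketch that reads them by that key, assuming the script only needs an index-to-name mapping (`read_label_map` is a hypothetical helper, not a function in the repo):

```python
import json


def read_label_map(meta_info):
    """Return {label_index: class_name} from a parsed dataset.json.

    Assumes class names live under the "labels" key, either as
    {"0": "background", ...} (Medical Segmentation Decathlon / nnU-Net v1)
    or as {"background": 0, ...} (nnU-Net v2, plain integer labels only).
    """
    label_map = {}
    for key, value in meta_info["labels"].items():
        if str(key).isdigit():      # index -> name layout
            label_map[int(key)] = value
        else:                       # name -> index layout
            label_map[int(value)] = key
    return label_map


meta_info = json.loads(
    '{"name": "HepaticVessel", "modality": {"0": "CT"}, '
    '"labels": {"0": "background", "1": "Vessel", "2": "Tumour"}}'
)
for idx, cls_name in sorted(read_label_map(meta_info).items()):
    if idx == 0:
        continue  # background is not a segmentation target
    print(idx, cls_name)
```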
First, thank you very much for sharing this amazing project and its source code. It's really helpful for my study!
I have a question about BraTS2021 dataset preprocessing.
Could you share the details of how BraTS_Val* was prepared?
When I run prepare_data_from_nnUNet.py on my BraTS2021 dataset, every case containing only the 'background' class (whose label volume is all 0) is skipped, so it is not saved in the dataset folder.
So I removed this code from prepare_data_from_nnUNet.py:
    if(volume<10):
        print("skip", target_img_path)
        continue
and ran train.py. However, there are problems: the loss suddenly becomes NaN, although the Dice score is not 0.
(I checked and found that gt3D consisted entirely of zeros; I'm not sure, but I guess the 'background' data is the cause.)
So I want to know how to construct the BraTS 21 dataset.
Thank you again, and I hope to get an answer!
(If this is my own mistake, sorry in advance for the naive question ;) )
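The skip rule seems to exist because an all-zero target leaves nothing for the training loop to click on or overlap with, so removing it may be what leads to the NaN loss; that is a guess, not confirmed by the authors. A small sketch of the same filter, which could also be used to count how many BraTS cases are affected before deciding whether to drop them (helper names are hypothetical):

```python
import numpy as np


def foreground_voxels(mask, label):
    """Number of voxels equal to `label` in an integer label volume."""
    return int(np.count_nonzero(mask == label))


def keep_case(mask, label, min_voxels=10):
    """Keep a (case, class) pair only if the class has at least `min_voxels` voxels.

    Mirrors the `volume < 10` skip rule quoted above: an all-zero target offers
    no foreground voxel on which to place a positive click.
    """
    return foreground_voxels(mask, label) >= min_voxels


mask = np.zeros((8, 8, 8), dtype=np.uint8)
print(keep_case(mask, label=1))   # empty target is skipped
mask[2:5, 2:5, 2:5] = 1           # 27 foreground voxels
print(keep_case(mask, label=1))
```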
I can't use 'pip install -r requirements.txt'.
The reason is that lines 170 and 171 contain an invalid '-e' argument.
Please delete them!
Thank you again :))
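Editable (-e) lines typically point at a local checkout or VCS URL that only exists on the machine where the file was generated, which would explain why pip rejects them. As a sketch, a cleaned copy can be produced without editing the original by hand (the helper and output file names are arbitrary; any package those lines referred to would have to be installed separately if it is actually needed):

```python
def strip_editable_lines(requirements_text):
    """Return requirements text without `-e` (editable install) lines."""
    kept = [line for line in requirements_text.splitlines()
            if not line.strip().startswith("-e")]
    return "\n".join(kept) + "\n"

# Usage, from the repository root (writes a cleaned copy, leaves the original untouched):
#   from pathlib import Path
#   Path("requirements_clean.txt").write_text(
#       strip_editable_lines(Path("requirements.txt").read_text()))
#   then: pip install -r requirements_clean.txt
```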
Hello,
I tried to use multiple GPUs by setting --multi_gpu; however, according to the output of nvidia-smi, only GPU 0 was used during training.
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.68 GiB (GPU 1; 31.75 GiB total capacity; 27.51 GiB already allocated; 1.65 GiB free; 28.90 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
How can I solve this problem if I only have V100 GPUs? Thanks!
Hi,
In validation.py, I noticed that the data processing includes the infer_transform operation, which sets mask_name='label'. I would like to ask how an image without a mask label is handled and fed into the model for prediction, and whether this handling affects the segmentation performance.
I'm not sure if my understanding of this part of the code is correct. I would appreciate your response and clarification. Thank you.
Hi, thanks for sharing!
I'm very interested in your work. The Google Drive link for the SAM-Med3D checkpoint redirects to the project homepage; could you provide a new link? Thanks a lot.
I'm trying to use the provided code for 3D inference with SAM-Med2D, but I'm having trouble. Could there be an architecture mismatch?
I'm running
python validation.py --seed 2023 \
    -vp ./results/vis_sam_med2d \
    -cp ./ckpt/sam_med2d.pth \
    -tdp ./data/validation_test1 -nc 10 \
    --image_size 256 -mt vit_b --dim 2 --save_name ./results/sam_med2d.py --ft2d
as in the infer_med2d.sh file, except that -tdp is changed to ./data/validation to reflect the directory name in the current GitHub repo. These changes result in the command
python validation.py --seed 2023 \
    -vp ./results/vis_sam_med2d \
    -cp ./ckpt/sam-med2d_b.pth \
    -tdp ./data/validation -nc 10 \
    --image_size 256 -mt vit_b --dim 2 --save_name ./results/sam_med2d.py --ft2d
Then, when I run it, I get:
set seed as 2023
get 1 datasets
device: cuda
*******interpolate
Traceback (most recent call last):
  File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 111, in _build_sam
    sam.load_state_dict(state_dict['model'])
  File "/home/[user]/anaconda3/envs/sam-med/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Sam:
    Unexpected key(s) in state_dict: "image_encoder.blocks.0.Adapter.norm.weight", [... lots of layers ...] "image_encoder.blocks.11.Adapter.spatial.2.weight".
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "[my path]//SAM-Med3D/validation.py", line 327, in <module>
    sam_model_tune = sam_model_registryargs.model_type.to(device)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 38, in build_sam_vit_b
    return _build_sam(
           ^^^^^^^^^^^
  File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 116, in _build_sam
    new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[my path]/SAM-Med3D/segment_anything/build_sam.py", line 128, in load_from
    pos_embed = new_state_dict['image_encoder.pos_embed']
                ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'image_encoder.pos_embed'
(Text in square brackets marks my edits to the output.) The error arises in the try/except block of the _build_sam function in build_sam.py:
with open(checkpoint, "rb") as f:
    state_dict = torch.load(f)
try:
    if 'model' in state_dict.keys():
        sam.load_state_dict(state_dict['model'])
    else:
        sam.load_state_dict(state_dict)
except:
    print('*******interpolate')
    new_state_dict = load_from(sam, state_dict, image_size, vit_patch_size)
    sam.load_state_dict(new_state_dict)
print(f"*******load {checkpoint}")
Thus the loading attempt from the try section didn't work, and the except couldn't fix it. As I understand it, the weights from sam-med2d_b.pth are trying to be loaded into the model sam defined early in the _build_sam function. But this must have the base SAM architecture, and thus couldn't include the adapter layers in SAM-Med2D, which I expect would inevitably lead to the mismatch in architectures I read from the terminal output. Furthermore, the model-type flag must be the name of an original SAM model since it gets passed through the dictionary sam_model_registry. Am I understanding the error correctly? If so, could you possibly share the code you used to test SAM-Med2D on the 3D data?
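That reading matches the traceback: every unexpected key contains `Adapter`, which suggests the model built through sam_model_registry has no adapter modules. A quick way to confirm this is to diff the two key sets before loading. A framework-free sketch (the helper name is hypothetical); with PyTorch, the inputs would be `sam.state_dict().keys()` and the keys of the checkpoint's weight dict, and `load_state_dict(..., strict=False)` also reports `missing_keys` and `unexpected_keys`:

```python
def diff_state_dict_keys(model_keys, ckpt_keys):
    """Compare the parameter names a model defines with those a checkpoint holds.

    Returns (missing, unexpected): names the model expects but the checkpoint
    lacks, and names the checkpoint has that the model does not define.
    """
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)


# Illustrative key names modelled on the traceback above
model_keys = ["image_encoder.pos_embed", "image_encoder.blocks.0.norm1.weight"]
ckpt_keys = model_keys + ["image_encoder.blocks.0.Adapter.norm.weight"]
missing, unexpected = diff_state_dict_keys(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

Separately, in the quoted snippet the except branch passes the whole loaded object `state_dict` to `load_from` rather than `state_dict['model']`. That may be why the fallback then fails with `KeyError: 'image_encoder.pos_embed'`. This is an inference from the snippet, not something verified against the repo.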
set seed as 2023
get 2 datasets
device: cuda
0it [00:00, ?it/s]
Save to union_out_dice.py
/usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py:3432: RuntimeWarning: Mean of empty slice.
  return _methods._mean(a, axis=axis, dtype=dtype,
/usr/local/lib/python3.10/dist-packages/numpy/core/_methods.py:190: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
Done
Hi, very exciting work!! I tried to use this model on my 3D microscopy image dataset, with the shape of (128,128,128). However, when I visualized the results I saw that the predictions were often very blurry. I realized that in the code you downsampled the mask size from (128, 128, 128) to (32, 32, 32) and then upsampled the predictions via interpolation. Am I missing something or is this essentially processing the image with resolution (32, 32, 32)? And if yes are there some ways I can maintain the resolution of the prediction as my input (128, 128, 128)? Thank you!
For reference, these are some of the images I had, where left is the input and right is the prediction.
Have you used the AISD and ISLES 2022 datasets for training? In other words, which datasets did you use for ischemic stroke lesions? I'd be very grateful for a quick reply.
Hi again,
I have a question: you resampled the images and masks to conduct the experiments.
I've trained the model with the 3D pipeline, but I want to resample the predictions back to the original size and evaluate them at that resolution.
Currently I'm using SimpleITK's resampling function and keeping the same origin and direction, but the resampled segmentation mask becomes all zeros, possibly due to dilution. Is there a better way to do this?
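A frequent cause of all-zero masks after resampling is linear interpolation on a label map: the 0/1 values get blended, and casting the result back to an integer type can erase thin structures. Label maps are usually resampled with nearest-neighbour interpolation instead; in SimpleITK that means passing `sitk.sitkNearestNeighbor` as the interpolator and using the original image as the reference, so that size, spacing, origin and direction match. The numpy sketch below only illustrates the nearest-neighbour index mapping on a plain array (the helper name is hypothetical):

```python
import numpy as np


def resize_labels_nearest(labels, out_shape):
    """Nearest-neighbour resize of an integer label volume to `out_shape`.

    Output voxel i along an axis takes the label of the input voxel containing
    its centre, so only label values present in the input can appear.
    Spacing, origin and direction are ignored here; a real pipeline must carry them.
    """
    index = []
    for n_in, n_out in zip(labels.shape, out_shape):
        centres = (np.arange(n_out) + 0.5) * n_in / n_out
        index.append(np.clip(np.floor(centres).astype(np.intp), 0, n_in - 1))
    return labels[np.ix_(*index)]


small = np.array([[[0, 1], [2, 0]], [[1, 1], [0, 2]]], dtype=np.uint8)
big = resize_labels_nearest(small, (4, 4, 4))
print(big.shape, np.unique(big))
```

When downsampling, a structure thinner than one output voxel can still disappear with nearest-neighbour sampling, so evaluating at the original resolution is safest when the prediction is brought back onto the original grid.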
I have edited the code locally so that window_size is a tuple instead of an int. Since the input has shape (Dp, Hp, Wp), the output should have the same shape, so the following line
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Hp, Wp, Dp, -1)
should be rewritten as
x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Dp, Hp, Wp, -1)
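Why the order matters can be checked with a numpy re-implementation of 3D window partitioning. This sketch mirrors the permute pattern quoted above. The function names and the (B, D, H, W, C) layout are assumptions, not the repo's exact code. With non-cubic padded sizes, only the (B, Dp, Hp, Wp, C) view reproduces the input:

```python
import numpy as np


def window_partition_3d(x, ws):
    """Split (B, D, H, W, C) into windows of size ws = (wd, wh, ww).

    Assumes D, H, W are already padded to multiples of the window size.
    Returns an array of shape (B * num_windows, wd, wh, ww, C).
    """
    B, D, H, W, C = x.shape
    wd, wh, ww = ws
    x = x.reshape(B, D // wd, wd, H // wh, wh, W // ww, ww, C)
    return x.transpose(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, wd, wh, ww, C)


def window_unpartition_3d(windows, ws, padded_dhw):
    """Inverse of window_partition_3d; returns (B, Dp, Hp, Wp, C)."""
    Dp, Hp, Wp = padded_dhw
    wd, wh, ww = ws
    B = windows.shape[0] // ((Dp // wd) * (Hp // wh) * (Wp // ww))
    x = windows.reshape(B, Dp // wd, Hp // wh, Wp // ww, wd, wh, ww, -1)
    # Same axis order as permute(0, 1, 4, 2, 5, 3, 6, 7) in the quoted line
    return x.transpose(0, 1, 4, 2, 5, 3, 6, 7).reshape(B, Dp, Hp, Wp, -1)


x = np.arange(2 * 4 * 6 * 8 * 3).reshape(2, 4, 6, 8, 3)
w = window_partition_3d(x, (2, 3, 4))
y = window_unpartition_3d(w, (2, 3, 4), (4, 6, 8))
print(w.shape, np.array_equal(x, y))
```

Viewing the same data as (B, Hp, Wp, Dp, -1) would give shape (2, 6, 8, 4, 3) here, which no longer matches the input. With cubic inputs (Dp == Hp == Wp) the two orders are numerically identical, which is presumably why the original line works for 128 x 128 x 128 volumes.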
Hi,
I would like to ask whether the code can easily be modified for multiclass semantic segmentation.
Please advise.
Thanks!
Hi there!
I tried implementing inference using multiple point prompts but passing them all at once instead of iteratively passing one point and the previous mask. However, the performance degraded with increasing numbers of points. I was wondering if the model actually supports this behaviour? I assumed so only because the shape of the points input is [1,1,3], where I assumed the first dimension was the batch number, and second dimension was the point number; is that maybe wrong? Thanks!
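Original SAM's prompt encoder takes point coordinates shaped (B, N, 2) with labels shaped (B, N), so reading [1, 1, 3] as (batch, points, coordinates) is consistent with that design. Whether SAM-Med3D behaves well with many points in one pass is a separate question: the question notes that the reference approach adds one point per iteration together with the previous mask, so a single pass with many clicks may be outside what the model saw during training (a hypothesis, not confirmed). A numpy sketch of packing N clicks; the helper name, coordinate axis order and label convention are assumptions:

```python
import numpy as np


def stack_point_prompts(clicks):
    """Pack several clicks into one batched prompt.

    clicks: list of ((i, j, k), label), with voxel indices in the same axis order
    as the image tensor, label 1 = foreground and 0 = background (assumed convention).
    Returns coords of shape (1, N, 3) and labels of shape (1, N).
    """
    coords = np.array([point for point, _ in clicks], dtype=np.float32)[None]
    labels = np.array([label for _, label in clicks], dtype=np.int64)[None]
    return coords, labels


coords, labels = stack_point_prompts([((64, 60, 70), 1), ((10, 12, 15), 0), ((66, 62, 71), 1)])
print(coords.shape, labels.shape)
```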
Thanks for your great work! Will the image encoder of SAM-Med3D be replaced by ViT-Huge or ViT-Large for training? Will more medical data be added to the training set?
Hi,
Wonderful work! I have a question from running your project. My data includes some images that do not contain tumors, so their masks are entirely zero. For these images, is the code unable to generate prompt points during training and testing?
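For an all-zero mask there is no foreground voxel on which to place a positive click, so any positive-point sampler has nothing to draw from. One hedged way to handle such cases, modelled on the SAM-style idea of clicking inside the error region, is to fall back to a negative (background) click. This is a sketch with a hypothetical helper, not the repo's sampling code:

```python
import numpy as np


def sample_click(pred, gt, rng):
    """Pick one corrective click from the disagreement between pred and gt.

    Returns ((i, j, k), label): label 1 on a missed foreground voxel,
    label 0 on a false-positive voxel. When there is no error at all
    (e.g. gt and pred are both empty), fall back to a background click.
    """
    pred, gt = pred.astype(bool), gt.astype(bool)
    false_neg = np.argwhere(gt & ~pred)
    false_pos = np.argwhere(pred & ~gt)
    if len(false_neg) == 0 and len(false_pos) == 0:
        candidates, label = np.argwhere(~gt), 0
        if len(candidates) == 0:          # volume is entirely foreground
            candidates, label = np.argwhere(gt), 1
    elif len(false_neg) >= len(false_pos):
        candidates, label = false_neg, 1
    else:
        candidates, label = false_pos, 0
    point = candidates[rng.integers(len(candidates))]
    return tuple(int(v) for v in point), label


rng = np.random.default_rng(0)
empty = np.zeros((8, 8, 8), dtype=np.uint8)
print(sample_click(empty, empty, rng))   # only a background (label 0) click is possible
```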
Hello,
Thank you for your interesting work.
I'm following the instructions you provided for evaluating your checkpoint on the Medical Segmentation Decathlon Task07_Pancreas dataset.
However, the mean Dice scores I obtained for pancreas and tumor segmentation are 35.8 and 35.6, respectively, which are quite low.
Are these results within the expected range? Do you have any suggestions on how to use your provided checkpoint and code to get better results?
Thank you.
Dear Author,
I hope this message finds you well. I am a researcher with a keen interest in your groundbreaking work on the SAM-Med3D model. Your research has garnered significant attention in our community, and we are particularly impressed with the capabilities of your pretrained 3D image encoder.
My team and I are currently working on a project that involves analyzing CT and MRI datasets. We believe that the SAM-Med3D model could greatly enhance our research, specifically in terms of feature extraction from these 3D medical images. However, we are uncertain about the best approach to utilize your pretrained model for our specific datasets.
Could you kindly provide some guidance or instructions on how to effectively use the SAM-Med3D pretrained 3D image encoder with our own CT/MRI datasets? Any insights into preprocessing steps, model integration, or optimal usage practices would be immensely helpful.
We are eager to explore the potential of SAM-Med3D in our research and look forward to potentially contributing to the field with the insights gained from its application.
Thank you very much for your time and consideration. We understand the demands on your schedule and appreciate any assistance you can provide.
Hi, all
Inference on full-volume images is now supported in infer.sh and inference.py.
This means SAM-Med3D now supports different resolutions (e.g., 512 x 512 x 200 or 256 x 256 x 500).
First, we highly recommend resampling the image to 1.5 mm spacing for the best segmentation results.
For targets larger than 128 x 128 x 128, you can enable sliding-window inference in infer.sh.
We hope this new feature is helpful to you.
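The announcement does not spell out the window layout infer.sh uses, so the following is only a sketch of the arithmetic involved. It assumes a 128-voxel cubic window with 50% overlap (both assumptions) and a made-up CT spacing of 0.8 x 0.8 x 2.5 mm. It first computes the grid size at 1.5 mm spacing, then the window start positions along each axis:

```python
def resampled_shape(shape, spacing, new_spacing=(1.5, 1.5, 1.5)):
    """Grid size after resampling to `new_spacing` while keeping the physical extent."""
    return tuple(int(round(n * s / ns)) for n, s, ns in zip(shape, spacing, new_spacing))


def window_starts(length, window=128, overlap=0.5):
    """Start indices of windows of size `window` that cover [0, length)."""
    if length <= window:
        return [0]  # one window; the volume would be padded up to `window`
    stride = max(1, int(window * (1 - overlap)))
    starts = list(range(0, length - window + 1, stride))
    if starts[-1] != length - window:
        starts.append(length - window)  # make the last window touch the far border
    return starts


# Made-up example: a 512 x 512 x 200 CT at 0.8 x 0.8 x 2.5 mm spacing
grid = resampled_shape((512, 512, 200), (0.8, 0.8, 2.5))
print(grid)                                   # (273, 273, 333)
print([window_starts(n) for n in grid])
```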
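Hello, following your steps I can now fine-tune the trained model on my own data, but I have a question: how should I add prompts? Your README does not include any commands for training with prompts.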
In train.py, I found that epoch_dice is never updated; it is only ever assigned 0.
How did you conduct the transfer experiments on UNETR? Did you replace all the encoders (encoder1 to encoder4)? When fine-tuning the pretrained ViT encoder from SAM-Med3D, is the input image size still 128 x 128 x 128?
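Dear author,
Could you tell me which datasets were used for each of the 4 pretrained models provided in the repository, especially the last three? I want to run five-fold experiments on some public datasets, but I'm worried that data leakage could make the results inaccurate, so I would sincerely appreciate a reply.
Best regards~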
Can you provide more information about the datasets used for SAM-Med3D-brain and SAM-Med3D-organ training?
Additionally, can you highlight the differences between SAM-Med3D and the others?
Could you upload the code for cropping around the first click, to solve the problem of obtaining masks for images without labels?
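The repo's own cropping code is not shown in this thread. The following is only a sketch of the general idea, assuming a 128 x 128 x 128 patch and a volume at least that large on every axis (pad it first otherwise). The patch is centred on the click and shifted to stay inside the volume; the returned start indices are what is needed later to paste the prediction back (helper name hypothetical):

```python
import numpy as np


def crop_around_click(volume, click, size=(128, 128, 128)):
    """Crop a `size` patch centred on `click`, shifted to stay inside `volume`.

    Returns (patch, start) where `start` is the patch corner in volume coordinates.
    Assumes every axis of `volume` is at least as long as the patch.
    """
    start = []
    for c, n, s in zip(click, volume.shape, size):
        if n < s:
            raise ValueError("pad the volume to at least the patch size first")
        start.append(min(max(c - s // 2, 0), n - s))  # centre on click, clamp to bounds
    window = tuple(slice(st, st + s) for st, s in zip(start, size))
    return volume[window], tuple(start)


volume = np.zeros((200, 200, 150), dtype=np.float32)
patch, start = crop_around_click(volume, (10, 100, 149))
print(patch.shape, start)
```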
During training, the loss fluctuates within a range and does not decrease even after many epochs. The following warning was also printed during training:
UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`.
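I found that after changing the checkpoint path, the .pth file could not be loaded. In the source code,
self.init_checkpoint(join(self.args.work_dir, self.args.task_name, 'sam_model_latest.pth'))
loads the latest .pth file in the working directory, so I made the following change:
self.init_checkpoint(self.args.checkpoint)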
How do you provide labeled 3D data (3D points)? Could you provide more examples? When I opened the example on the GitHub page, the labeled data appeared black. Also, how can I join the discussion group? The QR code in the README has expired.
When I tried to train on the liver tumor data, I added it to the data path, but running python train.py displays the error below. Why?
The NIfTI file at the end of the path does not exist: why is .nii.gz appended to the end of the file name?
RuntimeError: Exception thrown in SimpleITK ImageFileReader_Execute: D:\a\1\sitk\Code\IO\src\sitkImageReaderBase.cxx:97:
sitk::ERROR: The file "data/train/Liver_Tumor\imagesTr\segmentation-3.nii.zip.nii.gz" does not exist.
Why is Dice = 0 when training on the liver tumor data?
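A Dice of exactly 0 means the predicted and ground-truth foreground never overlap. Typical causes are an all-background prediction, an empty or misaligned label, or the wrong label value being treated as foreground. For instance, in LiTS-style labels (0 background, 1 liver, 2 tumour), binarising with `== 1` would select the liver, not the tumour. Checking the metric on known inputs helps separate a metric problem from a model problem. A minimal numpy sketch (not the repo's `get_dice_score`):

```python
import numpy as np


def dice(pred, gt, empty_value=1.0):
    """Dice coefficient 2|P & G| / (|P| + |G|) between two binary masks.

    When both masks are empty the ratio is 0/0; `empty_value` is returned instead
    (1.0 treats a correct "nothing here" prediction as perfect).
    """
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return empty_value
    return 2.0 * np.logical_and(pred, gt).sum() / denom


gt = np.array([0, 1, 1, 0])
print(dice(np.array([0, 1, 1, 0]), gt))   # identical masks
print(dice(np.array([1, 1, 0, 0]), gt))   # half overlap
print(dice(np.zeros(4), gt))              # empty prediction gives 0
```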
Thanks to the authors for sharing. I have a question about testing on real images:
The model input is (128,128,128). In practice, for example when segmenting the lung parenchyma in a (512,512,368) lung CT image, I tried two schemes:
Scheme 1: center a (128,128,128) block on the first positive point, run the model on that block, and paste the result back into the full volume. The mask predicted by the algorithm is fine, but because the lung parenchyma covers a large area, obvious block edges appear after pasting back, as shown in the following figure.
Scheme 2: resize the image to (128,128,128). Unfortunately, the segmentation is poor after resizing.
I would like to ask two questions:
First, why does the segmentation become much worse after resizing the image?
Second, in this situation, which scheme would you suggest in real applications to get the most out of the algorithm?
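Scheme 1 uses a single patch. The cut edges typically disappear when the volume is covered by several overlapping patches and their predictions are averaged where they overlap. Regarding scheme 2, resizing 512 x 512 x 368 to 128 x 128 x 128 shrinks the in-plane axes by a factor of 4 and the slice axis by about 2.9, which discards much of the detail the model relies on; this is a plausible explanation rather than a confirmed one. A numpy sketch of overlap averaging (helper name hypothetical; weighting patch centres more heavily is a common refinement not shown here):

```python
import numpy as np


def blend_patches(shape, patches):
    """Average overlapping patch predictions into one full-size probability map.

    patches: iterable of (start, prob), where `start` is the patch corner and
    `prob` the patch's foreground probabilities. Voxels covered by several
    patches receive the mean of those predictions, which smooths the seams.
    """
    total = np.zeros(shape, dtype=np.float64)
    count = np.zeros(shape, dtype=np.float64)
    for start, prob in patches:
        region = tuple(slice(s, s + n) for s, n in zip(start, prob.shape))
        total[region] += prob
        count[region] += 1.0
    # Voxels no patch covered stay 0 instead of dividing by zero
    return np.divide(total, count, out=np.zeros_like(total), where=count > 0)


# 1-D toy example: two overlapping 3-voxel patches on a 4-voxel line
print(blend_patches((4,), [((0,), np.array([1.0, 1.0, 1.0])),
                           ((1,), np.array([0.0, 0.0, 0.0]))]))
```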
Hello, I tried to predict labels on my custom dataset, and the Dice score looks normal; however, the resulting mask cannot be opened in Slicer, and it does not show the mask. Can you help me?
How can I restore prediction results from 128 x 128 x 128 back to the original input size? Any guidance or documentation would be helpful.
Thanks,
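How the 128 x 128 x 128 input was produced determines the inverse. If it came from cropping or padding (rather than resizing), keeping the crop origin is enough to paste the prediction back. If the image was also resampled to a different spacing, the pasted mask must then be resampled back to the original grid with nearest-neighbour interpolation. A numpy sketch of the paste step (helper name hypothetical):

```python
import numpy as np


def paste_back(pred_patch, full_shape, start):
    """Write a patch prediction into a zero volume of the original size.

    `start` is the patch corner in original-volume coordinates, recorded when
    cropping. It may be negative or run past the border if the input was padded;
    the padded parts of the patch are simply dropped.
    """
    full = np.zeros(full_shape, dtype=pred_patch.dtype)
    dst, src = [], []
    for s, p, n in zip(start, pred_patch.shape, full_shape):
        lo, hi = max(s, 0), min(s + p, n)
        dst.append(slice(lo, hi))
        src.append(slice(lo - s, hi - s))
    full[tuple(dst)] = pred_patch[tuple(src)]
    return full


patch = np.ones((3, 3, 3), dtype=np.uint8)
restored = paste_back(patch, (5, 5, 5), start=(-1, 1, 3))
print(restored.sum())  # only the 2 x 3 x 2 part of the patch lies inside the volume
```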
When I ran train.py, I got:
AttributeError: 'Namespace' object has no attribute 'rank'
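The quoted train_epoch code uses a rank of -1 for non-distributed runs, but `rank` is only attached to `args` at run time, so any code path that reads `args.rank` before it is set fails in this way. One hedged workaround, assuming train.py builds `args` with argparse, is to give `rank` a default up front and to read it defensively elsewhere (the `--rank` flag below is a hypothetical addition, not an existing option):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--multi_gpu", action="store_true")
# Hypothetical addition: a default rank of -1 means "not running under DDP"
parser.add_argument("--rank", type=int, default=-1)
args = parser.parse_args([])  # empty list: parse defaults only, for illustration

# Defensive read for any code path that receives a Namespace without `rank`
rank = getattr(args, "rank", -1)
print(args.multi_gpu, rank)
```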
Hi 👋, many people have asked how to perform inference on unlabeled images (without mask), including how to generate prompts and more. We're here to provide a unified response 📢:
This repo provides experimental code, so for now, it only offers validation functionality (testing on datasets with annotations) 🧪.
However, we do plan to provide inference code for unlabeled images and a 3D version demo 🌐. But due to the workload and other time conflicts, the official release will have to wait a bit longer ⏳. Thank you for your patience and understanding! 🙏
Hi, thanks again!
I'd like to clarify the image_size argument in the dataset. If I use a 512x512xN 3D image and set image_size to 128, will it resize the image to 128x128xN or just crop a 128x128 region from it?
Traceback (most recent call last):
  File "/content/SAM-Med3D/train.py", line 507, in <module>
    main()
  File "/content/SAM-Med3D/train.py", line 466, in main
    trainer.train()
  File "/content/SAM-Med3D/train.py", line 361, in train
    epoch_loss, epoch_iou, epoch_dice, pred_list = self.train_epoch(epoch, num_clicks)
  File "/content/SAM-Med3D/train.py", line 335, in train_epoch
    epoch_loss /= step
ZeroDivisionError: float division by zero
Why exactly does this happen?
Hi,
Thanks for sharing your work with the community. Excellent work!
I have successfully trained SAM-Med3D on a brain tumor dataset, and now I want to evaluate it on the test set. The paper mentions that SAM-Med3D uses a patch-based inference approach; however, I couldn't find the script or function you used for 3D patch-based inference in the repository. I would appreciate it if you could share it, or let me know if I missed something. :)
Kind regards,
Himashi
Dear team,
I hope this message finds you well. I recently read the SAM-Med3D paper and came across an intriguing detail in Section 3.2. It was mentioned that cases with a physical size below 1 square cm or with any single dimension shorter than 1.5 cm were removed to enhance the visibility of target masks. I found this approach quite interesting and was wondering if there are any plans to release the code for processing the data?
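The following sketch covers only the "any single dimension shorter than 1.5 cm" criterion, measured on the foreground bounding box in physical units. How the authors actually measured size is not stated, so the bounding-box choice is an assumption, and the helper names are hypothetical:

```python
import numpy as np


def target_extent_mm(mask, spacing):
    """Physical size (mm) of the foreground bounding box along each axis."""
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        return (0.0,) * mask.ndim
    span_voxels = coords.max(axis=0) - coords.min(axis=0) + 1
    return tuple(float(n * s) for n, s in zip(span_voxels, spacing))


def shorter_than(mask, spacing, min_dim_mm=15.0):
    """True if the target's bounding box is below `min_dim_mm` along any axis."""
    return min(target_extent_mm(mask, spacing)) < min_dim_mm


mask = np.zeros((40, 40, 40), dtype=np.uint8)
mask[5:15, 5:15, 5:15] = 1                      # 10 voxels per axis
print(target_extent_mm(mask, (1.0, 1.0, 1.0)))  # 10 mm per axis at 1 mm spacing
print(shorter_than(mask, (1.0, 1.0, 1.0)))
print(shorter_than(mask, (2.0, 2.0, 2.0)))      # 20 mm per axis at 2 mm spacing
```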
Thank you for your time and consideration.
Best regards,
You Li
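Hello, I learned about this project from the paper, and it is truly impressive.
However, I'd like to ask how to implement the prompt functionality described in the paper. I found that the project currently generates prompts from random points and refines the result through its own iterations.
I also read issue #31 but didn't find a solution there. Do the authors have a tutorial on this, or is there a way to discuss the prompt module with others?
Thanks!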
Great job! When will the dataset be released? :)
The following error is displayed:
  File "D:\project\SAM-Med3D-main\SAM-Med3D-main\train.py", line 338, in train_epoch
    epoch_loss /= step
UnboundLocalError: local variable 'step' referenced before assignment
step is defined by the for loop, but epoch_loss /= step runs after the loop, so if the loop body never executes (for example, because the dataloader yields no batches), step is never assigned.
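Python's loop variable does remain visible after a for loop, but only if the loop body runs at least once; an empty dataloader (for example, because the data path matches no files) leaves it unassigned. A sketch of averaging with an explicit counter, which does not depend on the loop variable at all and fails with a clearer message when there is no data (the helper name is hypothetical):

```python
def mean_over_batches(values):
    """Average per-batch values with an explicit counter instead of the loop variable."""
    total, count = 0.0, 0
    for value in values:
        total += value
        count += 1
    if count == 0:
        # An empty loader often points to a data-path problem; fail with a clear message
        raise RuntimeError("dataloader yielded no batches; check the dataset paths")
    return total / count


print(mean_over_batches([0.5, 0.25, 0.75]))
```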