mmp's People

Contributors

andrewsong90

mmp's Issues

visualization

Thank you for your wonderful work!

I encountered an issue while running mmp_visualization.ipynb. I successfully ran MMP and obtained the results, but the notebook reads a single file, h5_feats_fpath = f'<path/to/features>/{slide_id}.h5', to access both the coordinates and the features of the WSI. I used CLAM for segmentation and feature extraction, so I ended up with an .h5 file containing the coordinates and a .pt file containing the features. I tried merging these two files into a single .h5 file as input, but received an error saying that a three-dimensional tensor is expected. Do you know how I can obtain this three-dimensional tensor? Thank you!

slide_id = 'TCGA-RZ-AB0B-01Z-00-DX1.0DF1A3A6-3030-4988-AC2C-CAA0F2EBAEB2' 
slide_fpath = f'../DATA_DIRECTORY/{slide_id}.svs'
h5_feats_fpath = '../output_file.h5'
wsi = openslide.open_slide(slide_fpath)  
h5 = h5py.File(h5_feats_fpath, 'r') 

coords = h5['coords'][:]  
feats = torch.Tensor(h5['features'][:])  
custom_downsample = 2  
patch_size = h5['coords'].attrs['patch_size'] * custom_downsample  

with torch.inference_mode():
    out, qqs = panther_encoder.representation(feats).values()  
    tokenizer = PrototypeTokenizer(p=16, out_type='allcat')  
    mus, pis, sigmas = tokenizer.forward(out)  
    mus = mus[0].detach().cpu().numpy()  
    qq = qqs[0,:,:,0].cpu().numpy()  
    global_cluster_labels = qq.argmax(axis=1)  

cat_map = visualize_categorical_heatmap(
    wsi,
    coords, 
    global_cluster_labels, 
    label2color_dict=color_map,
    vis_level=wsi.get_best_level_for_downsample(128),
    patch_size=(patch_size, patch_size),
    alpha=0.4,
)  

display(cat_map.resize((cat_map.width//4, cat_map.height//4)))  
display(get_mixture_plot(mus, colors=list(color_map_hex.values())))  

ValueError                                Traceback (most recent call last)
Cell In[4], line 15
     13 
     14 with torch.inference_mode():
---> 15     out, qqs = panther_encoder.representation(feats).values()
     16     tokenizer = PrototypeTokenizer(p=16, out_type='allcat')
     17     mus, pis, sigmas = tokenizer.forward(out)

File ~/documents/code_notes/visualization/MMP/src/visualization/../mil_models/model_PANTHER.py:29, in PANTHER.representation(self, x)
     28 def representation(self, x):
---> 29     out, qqs = self.panther(x)
     30     return {'repr': out, 'qq': qqs}

File ~/anaconda3/envs/mmp/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/envs/mmp/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/documents/code_notes/visualization/MMP/src/visualization/../mil_models/PANTHER/layers.py:49, in PANTHERBase.forward(self, S, mask)
     48 def forward(self, S, mask=None):
---> 49     B, N_max, d = S.shape
     51     if mask is None:
     52         mask = torch.ones(B, N_max).to(S)

ValueError: not enough values to unpack (expected 3, got 2)
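The traceback ends at B, N_max, d = S.shape in PANTHERBase.forward, which suggests the PANTHER layer expects a batched 3-D input of shape (batch, patches, feature_dim), while torch.Tensor(h5['features'][:]) produces a 2-D (patches, feature_dim) tensor. One way to obtain the 3-D tensor is to add a leading batch dimension of size 1 with unsqueeze, which is the change another issue on this page reports as working (panther_encoder.representation(feats.unsqueeze(dim=0))). The sketch below only reproduces that shape check: unpack_like_panther is a hypothetical stand-in, not MMP code, and the feature sizes are arbitrary placeholders.

```python
import torch


def unpack_like_panther(S: torch.Tensor):
    """Hypothetical stand-in that mimics only the first line of
    PANTHERBase.forward (B, N_max, d = S.shape); it does not run the model."""
    B, N_max, d = S.shape
    return B, N_max, d


# Placeholder bag of patch features: 500 patches x 1024 dims (illustrative sizes).
feats = torch.randn(500, 1024)

try:
    unpack_like_panther(feats)  # 2-D input triggers the same unpacking error
except ValueError as err:
    print("2-D input:", err)  # not enough values to unpack (expected 3, got 2)

# Add a leading batch dimension of size 1: (N, d) -> (1, N, d).
batched = feats.unsqueeze(dim=0)
print("3-D input:", unpack_like_panther(batched))  # (1, 500, 1024)
```

In the notebook cell this would mean passing feats.unsqueeze(dim=0) to panther_encoder.representation; the later indexing qqs[0,:,:,0] and mus[0] appears to already select the first (and only) batch element. Separately, that cell sets patch_size = h5['coords'].attrs['patch_size'] * custom_downsample with custom_downsample = 2, so if the stored patch_size is 224 the overlay is drawn with 448-pixel patches, which may account for the 448*448 size mentioned in another issue here.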

Regarding the csv files

Thank you for your work. May I ask how you obtained these csv files? The data I downloaded from TCGA seems to differ, and I could not find the labels for some cases in the csv files.

About the results

Thank you for your contribution. The c-index I obtain does not reach the value reported in your paper. Setup: feature extraction with CLAM using ResNet-50 on 256×256 patches at 20×; multimodal fusion with model_tuple='PANTHER,default' and model_mm_type='coattn'; dataset: TCGA-BRCA.
Here are my results:

epoch 50: 0.649±0.046
epoch 100: 0.655±0.046
epoch 200: 0.651±0.050

The per-fold c_index_test values of the five-fold cross-validation at epoch 100 are:

k=0  0.572164948453608
k=1  0.645803698435277
k=2  0.663324538258575
k=3  0.691629955947137
k=4  0.703731911652704
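As a quick arithmetic check on the numbers above, the epoch-100 summary 0.655±0.046 is reproduced by the mean and the population standard deviation (dividing by n, as numpy.std does by default) of the five fold values; the sample standard deviation (dividing by n - 1) would give about 0.052. This is only a consistency check of the reported summary and assumes the ± denotes a standard deviation across folds; it says nothing about the gap to the paper's numbers.

```python
from statistics import mean, pstdev, stdev

# Per-fold c_index_test values at epoch 100, copied from the list above.
fold_cindex = [
    0.572164948453608,
    0.645803698435277,
    0.663324538258575,
    0.691629955947137,
    0.703731911652704,
]

avg = mean(fold_cindex)
pop_sd = pstdev(fold_cindex)   # population std: divides by n
samp_sd = stdev(fold_cindex)   # sample std: divides by n - 1

print(f"mean ± population std: {avg:.3f} ± {pop_sd:.3f}")  # 0.655 ± 0.046
print(f"mean ± sample std:     {avg:.3f} ± {samp_sd:.3f}")  # 0.655 ± 0.052
```

Much of the spread comes from fold k=0 (0.572), which sits well below the other four folds.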

Hope to get your reply.

some questions

Thank you for your excellent work! I have some questions regarding this model that I would like to discuss with you:

  1. Regarding the all_dump.h5 file: based on the code, this file contains the survival time, censorship status, and corresponding attention scores for both the training and test sets. When read, it turns out to be a nested dictionary. On extracting the data, I found that in the training set the patient IDs did not match their time and censorship values, whereas in the test set they matched correctly. (That is, for some reason the patient IDs extracted from this file were paired with the wrong survival times and censorship values. I encountered the same issue with the previous version of PORPOISE; to obtain correctly paired risk scores, I had to put all the data into the test set.)

  2. This question is related to the previous one. How do you view evaluating the quality of a prognostic model with the c-index and Kaplan-Meier curves? In my view, the c-index reflects the model's ability to predict survival outcomes, while the K-M curve shows the trend of deaths and censoring over time given the same preset survival probability. In practice, I replicated MMP on the TCGA-UVM dataset and found that patients with higher risk scores often died, while those with lower scores were alive or censored. However, because follow-up start times and lengths differ, some patients with short follow-up were ultimately censored and the model assigned them low risk scores; conversely, some patients with long follow-up eventually died and the model assigned them relatively high risk scores. This resulted in very high p-values for the K-M curves (p-values for each fold exceeding 1).

  3. Regarding mmp_visualization.ipynb: I successfully visualized the heatmap after changing the code to out, qqs = panther_encoder.representation(feats.unsqueeze(dim=0)).values(), but the results differ significantly from those in the paper. The overlaid patches on the heatmap are semi-transparent, so when the tissue color is similar to the patch color it is hard to tell whether a region is actually overlaid or is just tissue. Additionally, with the default parameters the patch size becomes 448*448 instead of the original 224*224.

  4. Also in mmp_visualization.ipynb: the line path2omic = cross_attn_path2omic.loc[omic].sort_values(ascending=False) raises an error saying the by parameter is missing, but I need to sort each column's data separately. How did you solve this?
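One common way this error arises: DataFrame.sort_values requires a by argument, whereas Series.sort_values does not. df.loc[label] returns a Series only when label is a single scalar that is unique in the index; a list of labels (or a duplicated label) returns a DataFrame, and calling sort_values(ascending=False) on it fails with a TypeError about the missing by. To sort every column separately, each column can be sorted as its own Series. The table below is made-up data, and the names proto_0, omic_a, and so on are hypothetical; it is not the notebook's cross_attn_path2omic.

```python
import pandas as pd

# Made-up cross-attention scores: rows = omic groups, columns = prototypes.
cross_attn = pd.DataFrame(
    {"proto_0": [0.1, 0.7, 0.2], "proto_1": [0.5, 0.2, 0.3]},
    index=["omic_a", "omic_b", "omic_c"],
)

# A unique scalar label selects one row as a Series, so no `by` is needed.
row_sorted = cross_attn.loc["omic_b"].sort_values(ascending=False)

# A list of labels returns a DataFrame, and DataFrame.sort_values needs `by`.
try:
    cross_attn.loc[["omic_b"]].sort_values(ascending=False)
except TypeError as err:
    print("TypeError:", err)

# Sort every column independently by treating each one as its own Series.
per_column = {
    col: cross_attn[col].sort_values(ascending=False) for col in cross_attn.columns
}
for col, ranked in per_column.items():
    print(col, list(ranked.index))
# proto_0 ['omic_b', 'omic_c', 'omic_a']
# proto_1 ['omic_a', 'omic_c', 'omic_b']
```

Keeping the per-column results in a dict avoids forcing differently ordered indices back into one DataFrame, where pandas would realign them by label and undo the sort.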

About the visualization

Missing parameters:

def create_embedding_model(args, mode='classification', config_dir='./configs'):
    """
    Create classification or survival models
    """
    config_path = os.path.join(config_dir, args.model_histo_config, 'config.json')
    assert os.path.exists(config_path), f"Config path {config_path} doesn't exist!"
    model_type = args.model_histo_type
    update_dict = {'in_dim': args.in_dim,
                   'out_size': args.n_proto,
                   'load_proto': args.load_proto,
                   'fix_proto': args.fix_proto,
                   'proto_path': args.proto_path}

The attributes args.model_histo_config and args.model_histo_type used by this function are never defined.

def get_panther_encoder(in_dim, p, proto_path, config_dir='../'):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, default='PANTHER')
    parser.add_argument('--proto_model_type', type=str, default='PANTHER')
    parser.add_argument('--model_config', type=str, default='PANTHER_default')
    parser.add_argument('--in_dim', type=int, default=in_dim)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--n_proto', type=int, default=16)
    parser.add_argument('--n_classes', type=int, default=2)
    parser.add_argument('--out_size', type=int, default=p)
    parser.add_argument('--em_iter', type=int, default=1)
    parser.add_argument('--tau', type=float, default=1)
    parser.add_argument('--out_type', type=str, default='allcat')
    parser.add_argument('--n_fc_layers', type=int, default=0)
    parser.add_argument('--load_proto', type=int, default=1)
    parser.add_argument('--ot_eps', type=int, default=1)
    args = parser.parse_known_args()[0]
    args.fix_proto = 1
    args.proto_path = proto_path
    model = create_embedding_model(args, config_dir=config_dir)
    model.eval()
    return model

Could you provide the complete function definition for this part? Thank you very much!
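A minimal sketch of one possible workaround, assuming (this is not confirmed by the repository) that model_histo_type and model_histo_config are meant to hold the same values as model_type and model_config in get_panther_encoder. build_panther_args is a hypothetical helper that keeps only the arguments create_embedding_model visibly reads in the excerpt above, then sets the two missing attributes; under this assumption config_path would resolve to <config_dir>/PANTHER_default/config.json. The proto_path value is a placeholder.

```python
import argparse
import os


def build_panther_args(in_dim, p, proto_path):
    """Hypothetical helper: builds the args namespace that create_embedding_model
    reads in the excerpt, adding the two attributes reported as undefined.

    ASSUMPTION: model_histo_type / model_histo_config mirror
    model_type / model_config; verify against the MMP repository.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', type=str, default='PANTHER')
    parser.add_argument('--model_config', type=str, default='PANTHER_default')
    parser.add_argument('--in_dim', type=int, default=in_dim)
    parser.add_argument('--n_proto', type=int, default=16)
    parser.add_argument('--out_size', type=int, default=p)
    parser.add_argument('--load_proto', type=int, default=1)
    # Parse an empty list so stray notebook-kernel or test-runner argv is ignored.
    args, _ = parser.parse_known_args([])
    args.fix_proto = 1
    args.proto_path = proto_path
    # Assumed aliases for the attributes create_embedding_model expects.
    args.model_histo_type = args.model_type
    args.model_histo_config = args.model_config
    return args


args = build_panther_args(in_dim=1024, p=16, proto_path='<path/to/prototypes>')
# Config location create_embedding_model would check, with config_dir='../'.
print(os.path.join('../', args.model_histo_config, 'config.json'))
```

The in_dim value 1024 is only an example; it should match the feature dimension of the extracted features. Setting the aliases after parsing, rather than adding new command-line flags, keeps the original get_panther_encoder defaults as the single source of truth.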
