
cxrmate's Introduction

CXRMate: Leveraging Longitudinal Data and a Semantic Similarity Reward for Chest X-Ray Report Generation

Paper (arXiv): https://arxiv.org/abs/2307.09758

@misc{nicolson2023longitudinal,
      title={Longitudinal Data and a Semantic Similarity Reward for Chest X-Ray Report Generation}, 
      author={Aaron Nicolson and Jason Dowling and Bevan Koopman},
      year={2023},
      eprint={2307.09758},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

CXRMate is a longitudinal, multi-image CXR report generation encoder-to-decoder model that conditions the report generation process on the report from the patient's previous study, if available. The CXRMate checkpoint trained on MIMIC-CXR is available on the Hugging Face Hub: https://huggingface.co/aehrc/cxrmate.

CXRMate: a longitudinal, multi-image CXR report generator trained with reinforcement learning using the CXR-BERT cosine similarity reward. The findings and impression sections from the reports of the current and previous studies are differentiated by section embeddings and separator tokens. The decoder is prompted with the report of the previous study. [PMT], [PMT-SEP], [BOS], [SEP], and [EOS] denote the prompt, prompt separator, beginning-of-sentence, separator, and end-of-sentence special tokens, respectively.
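To make the prompt layout concrete, here is a minimal sketch of how one decoder sequence could be assembled from a previous and a current report. It assumes whitespace tokenisation and the token order [PMT], previous findings, [SEP], previous impression, [PMT-SEP], [BOS], findings, [SEP], impression, [EOS]; that order and the section indices 0 to 3 are illustrative assumptions, not taken from the CXRMate code.

```python
from typing import List, Optional, Tuple

# Hypothetical section-embedding indices (assumption, not from the CXRMate code).
PREV_FINDINGS, PREV_IMPRESSION, FINDINGS, IMPRESSION = 0, 1, 2, 3


def build_decoder_sequence(
    findings: str,
    impression: str,
    prev_findings: Optional[str] = None,
    prev_impression: Optional[str] = None,
) -> Tuple[List[str], List[int]]:
    """Lay out the prompt (previous report) and target (current report) as one sequence.

    Whitespace tokenisation is used for illustration only. Returns the tokens and a
    parallel list of section indices that a section embedding could be looked up with.
    """
    tokens: List[str] = []
    sections: List[int] = []

    def extend(words: List[str], section: int) -> None:
        tokens.extend(words)
        sections.extend([section] * len(words))

    # Prompt: only present when the patient has a previous study.
    if prev_findings is not None or prev_impression is not None:
        extend(["[PMT]"] + (prev_findings or "").split(), PREV_FINDINGS)
        extend(["[SEP]"] + (prev_impression or "").split() + ["[PMT-SEP]"], PREV_IMPRESSION)

    # Target: findings and impression of the current study.
    extend(["[BOS]"] + findings.split(), FINDINGS)
    extend(["[SEP]"] + impression.split() + ["[EOS]"], IMPRESSION)
    return tokens, sections
```

For a patient without a previous study, the sequence simply starts at [BOS].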

Generated reports:

Generated reports for the single-image, multi-image, and longitudinal, multi-image CXR generators (the latter prompted with either the radiologist's or the generated previous report) are located in the generated_reports directory.

Hugging Face models:

SCST: Self-Critical Sequence Training, TF: Teacher Forcing

Notebook examples:

Notebook examples for the models can be found in the examples directory.

Dataset:

  • The MIMIC-CXR-JPG dataset is available at: https://physionet.org/content/mimic-cxr-jpg/2.0.0/

Installation:

After cloning the repository, install the required packages in a virtual environment. The required packages are listed in requirements.txt:

python -m venv --system-site-packages venv
source venv/bin/activate
python -m pip install --upgrade pip
python -m pip install --upgrade -r requirements.txt --no-cache-dir

Test the Hugging Face checkpoints:

The model configuration for each task can be found in the config directory, e.g. config/test_huggingface/longitudinal_gen_prompt_cxr-bert.yaml. To run testing:

dlhpcstarter -t cxrmate_hf -c config/test_huggingface/longitudinal_gen_prompt_cxr-bert.yaml --stages_module tools.stages --test

See dlhpcstarter==0.1.4 for more options.


Training:

To train with teacher forcing:

dlhpcstarter -t cxrmate -c config/train/longitudinal_gt_prompt_tf.yaml --stages_module tools.stages --train

The model can then be tested with the --test flag:

dlhpcstarter -t cxrmate -c config/train/longitudinal_gt_prompt_tf.yaml --stages_module tools.stages --test

To then train with Self-Critical Sequence Training (SCST) with the CXR-BERT reward:

  1. Copy the path to the checkpoint from the exp_dir of the configuration above and paste it into the SCST configuration as warm_start_ckpt_path. Then run:
  2. dlhpcstarter -t mimic_cxr -c config/train/longitudinal_gen_prompt_cxr-bert.yaml --stages_module tools.stages --train
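When choosing the checkpoint to paste as warm_start_ckpt_path, the checkpoint filenames appear to encode the monitored metric (e.g. epoch=19-step=78380-val_report_chexbert_f1_macro=0.371041.ckpt, as seen in the training log in the issues below). The helper below is a hypothetical sketch that parses such names and returns the path with the highest score; it assumes that naming pattern and that a higher score is better.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Assumed filename pattern, e.g. "epoch=19-step=78380-val_report_chexbert_f1_macro=0.371041.ckpt".
METRIC_RE = re.compile(r"val_report_chexbert_f1_macro=([0-9]*\.?[0-9]+)\.ckpt$")


def best_checkpoint(paths: Iterable[str]) -> Optional[str]:
    """Return the checkpoint path with the highest monitored score, or None if none match."""
    best_path, best_score = None, float("-inf")
    for path in paths:
        match = METRIC_RE.search(Path(path).name)
        if match is None:
            continue  # e.g. "last.ckpt" carries no score in its name
        score = float(match.group(1))
        if score > best_score:
            best_path, best_score = path, score
    return best_path
```

Usage: best_checkpoint(str(p) for p in Path(exp_dir).glob("*.ckpt")), where exp_dir is the experiment directory of the teacher-forcing run.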
    


Help/Issues:

If you need help, or if you run into any issues, please open an issue and we will get back to you as soon as possible.

cxrmate's People

Contributors

anicolson


cxrmate's Issues

Training consultation

When I train with Self-Critical Sequence Training (SCST) with the CXR-BERT reward, I set
devices: 2
mbatch_size: 16
num_workers: 32
but encountered the following error:
'''
(venv) [root@3dc54336e478 home]# dlhpcstarter -t mimic_cxr -c config/train/longitudinal_gen_prompt_cxr-bert.yaml --stages_module tools.stages --train
Seed set to 0
PTL no. devices: 2.
PTL no. nodes: 1.
/usr/local/lib/python3.8/site-packages/lightning/fabric/connector.py:571: precision=16 is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Description, Special token, Index
bos_token, [BOS], 1
eos_token, [EOS], 2
unk_token, [UNK], 0
sep_token, [SEP], 3
pad_token, [PAD], 4
cls_token, [BOS], 1
mask_token, [MASK], 5
additional_special_token, [NF], 6
additional_special_token, [NI], 7
additional_special_token, [PMT], 8
additional_special_token, [PMT-SEP], 9
additional_special_token, [NPF], 10
additional_special_token, [NPI], 11
/home/modules/transformers/longitudinal_model/modelling_longitudinal.py:155: UserWarning: The encoder-to-decoder model was not warm-started before applying low-rank approximation.
warnings.warn('The encoder-to-decoder model was not warm-started before applying low-rank approximation.')
trainable params: 147,456 || all params: 80,916,528 || trainable%: 0.1822
/usr/local/lib/python3.8/site-packages/transformers/models/convnext/feature_extraction_convnext.py:28: FutureWarning: The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ConvNextImageProcessor instead.
warnings.warn(
Warm-starting using: /home/experiments/cxrmate/longitudinal_gt_prompt_tf/trial_0/epoch=19-step=78380-val_report_chexbert_f1_macro=0.371041.ckpt.
/usr/local/lib/python3.8/site-packages/dlhpcstarter/utils.py:347: UserWarning: The "last" checkpoint does not exist, starting training from epoch 0.
warnings.warn('The "last" checkpoint does not exist, starting training from epoch 0.')
You are using a CUDA device ('Z100L') that has Tensor Cores. To properly utilize them, you should set torch.set_float32_matmul_precision('medium' | 'high') which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[rank: 0] Seed set to 0
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Seed set to 0
PTL no. devices: 2.
PTL no. nodes: 1.
Description, Special token, Index
bos_token, [BOS], 1
eos_token, [EOS], 2
unk_token, [UNK], 0
sep_token, [SEP], 3
pad_token, [PAD], 4
cls_token, [BOS], 1
mask_token, [MASK], 5
additional_special_token, [NF], 6
additional_special_token, [NI], 7
additional_special_token, [PMT], 8
additional_special_token, [PMT-SEP], 9
additional_special_token, [NPF], 10
additional_special_token, [NPI], 11
/home/modules/transformers/longitudinal_model/modelling_longitudinal.py:155: UserWarning: The encoder-to-decoder model was not warm-started before applying low-rank approximation.
warnings.warn('The encoder-to-decoder model was not warm-started before applying low-rank approximation.')
trainable params: 147,456 || all params: 80,916,528 || trainable%: 0.1822
/usr/local/lib/python3.8/site-packages/transformers/models/convnext/feature_extraction_convnext.py:28: FutureWarning: The class ConvNextFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please use ConvNextImageProcessor instead.
warnings.warn(
Warm-starting using: /home/experiments/cxrmate/longitudinal_gt_prompt_tf/trial_0/epoch=19-step=78380-val_report_chexbert_f1_macro=0.371041.ckpt.
/usr/local/lib/python3.8/site-packages/dlhpcstarter/utils.py:347: UserWarning: The "last" checkpoint does not exist, starting training from epoch 0.
warnings.warn('The "last" checkpoint does not exist, starting training from epoch 0.')
[rank: 1] Seed set to 0
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0711 11:46:15.886372 31375 ProcessGroupNCCL.cpp:686] [Rank 1] ProcessGroupNCCL initialization options:NCCL_ASYNC_ERROR_HANDLING: 1, NCCL_DESYNC_DEBUG: 0, NCCL_ENABLE_TIMING: 0, NCCL_BLOCKING_WAIT: 0, TIMEOUT(ms): 1800000, USE_HIGH_PRIORITY_STREAM: 0, TORCH_DISTRIBUTED_DEBUG: OFF, NCCL_DEBUG: OFF, ID=226348544
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0711 11:46:15.892076 31223 ProcessGroupNCCL.cpp:686] [Rank 0] ProcessGroupNCCL initialization options:NCCL_ASYNC_ERROR_HANDLING: 1, NCCL_DESYNC_DEBUG: 0, NCCL_ENABLE_TIMING: 0, NCCL_BLOCKING_WAIT: 0, TIMEOUT(ms): 1800000, USE_HIGH_PRIORITY_STREAM: 0, TORCH_DISTRIBUTED_DEBUG: OFF, NCCL_DEBUG: OFF, ID=229697888

distributed_backend=nccl
All distributed processes registered. Starting with 2 processes

I0711 11:46:16.570466 31223 ProcessGroupNCCL.cpp:1340] NCCL_DEBUG: N/A
/usr/local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /home/experiments/mimic_cxr/longitudinal_gen_prompt_cxr-bert/trial_0 exists and is not empty.
/home/data/prompt.py:186: UserWarning: The number of examples is not divisible by the world size. Adding extra studies to account for this. This needs to be accounted for outside of the dataset.
warnings.warn('The number of examples is not divisible by the world size. '
Traceback (most recent call last):
File "/usr/local/bin/dlhpcstarter", line 8, in <module>
sys.exit(main())
File "/usr/local/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 126, in main
submit(args, cmd_line_args, stages_fnc)
File "/usr/local/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 21, in submit
stages_fnc(args)
File "/home/tools/stages.py", line 85, in stages
trainer.fit(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
return function(*args, **kwargs)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 948, in _run
call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 96, in _call_setup_hook
_call_lightning_module_hook(trainer, "setup", stage=fn)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 159, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/modules/lightning_modules/longitudinal/scst/gen_prompt.py", line 66, in setup
self.train_set = PreviousReportSubset(
File "/home/data/prompt.py", line 73, in __init__
self.allocate_subjects_to_rank(shuffle_subjects=False)
File "/home/data/prompt.py", line 212, in allocate_subjects_to_rank
assert len(set(self.examples)) == self.df.study_id.nunique() and
AssertionError
I0711 11:46:24.351401 31223 ProcessGroupNCCL.cpp:874] [Rank 0] Destroyed 1communicators on CUDA device 0
/home/data/prompt.py:186: UserWarning: The number of examples is not divisible by the world size. Adding extra studies to account for this. This needs to be accounted for outside of the dataset.
warnings.warn('The number of examples is not divisible by the world size. '
Traceback (most recent call last):
File "/usr/local/bin/dlhpcstarter", line 8, in <module>
sys.exit(main())
File "/usr/local/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 126, in main
submit(args, cmd_line_args, stages_fnc)
File "/usr/local/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 21, in submit
stages_fnc(args)
File "/home/tools/stages.py", line 85, in stages
trainer.fit(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
return function(*args, **kwargs)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 948, in _run
call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 96, in _call_setup_hook
_call_lightning_module_hook(trainer, "setup", stage=fn)
File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 159, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/modules/lightning_modules/longitudinal/scst/gen_prompt.py", line 66, in setup
self.train_set = PreviousReportSubset(
File "/home/data/prompt.py", line 73, in __init__
self.allocate_subjects_to_rank(shuffle_subjects=False)
File "/home/data/prompt.py", line 212, in allocate_subjects_to_rank
assert len(set(self.examples)) == self.df.study_id.nunique() and
AssertionError
I0711 11:46:25.112917 31375 ProcessGroupNCCL.cpp:874] [Rank 1] Destroyed 1communicators on CUDA device 1
'''
I want to ask how you set the parameters during training. I saw that your paper used 4×16GB NVIDIA Tesla P100 GPUs; I used 2×32GB NVIDIA V100 GPUs. With devices: 1 and mbatch_size: 1 there is no error, but training is too slow. I look forward to your answer, thank you very much!
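The warning in the log above ("The number of examples is not divisible by the world size. Adding extra studies...") refers to a common pattern in distributed data loading: padding the example list so every rank receives the same number of items. The sketch below shows that generic pattern only, similar in spirit to torch.utils.data.DistributedSampler with drop_last=False (repeat examples from the start, then take a strided slice per rank). It is not the code in data/prompt.py, and pad_to_multiple/shard_for_rank are hypothetical names.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def pad_to_multiple(examples: Sequence[T], world_size: int) -> List[T]:
    """Repeat examples from the start of the list until len(result) % world_size == 0."""
    if not examples:
        raise ValueError("examples must be non-empty")
    remainder = len(examples) % world_size
    extra = (world_size - remainder) % world_size
    # Wrap around with modulo in case more padding is needed than there are examples.
    return list(examples) + [examples[i % len(examples)] for i in range(extra)]


def shard_for_rank(examples: Sequence[T], rank: int, world_size: int) -> List[T]:
    """Give each rank every world_size-th example of the padded list (strided split)."""
    padded = pad_to_multiple(examples, world_size)
    return padded[rank::world_size]
```

Because padding duplicates some examples, metrics computed over all ranks will count those studies twice unless the duplicates are removed afterwards, which is what the warning means by "accounted for outside of the dataset".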

Multi-card training problem

Hi!
When I run "dlhpcstarter -t cxrmate -c config/train/longitudinal_gt_prompt_tf.yaml --stages_module tools.stages --train", training is very slow. How should I set the training parameters and batch size to speed up training?
Only 4013MiB / 32768MiB of memory is used on a single GPU, and I have two 32GB GPUs.
Thank you very much!
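When trading mbatch_size against devices, it can help to keep the effective batch size fixed. A minimal sketch, assuming plain data-parallel training where each device processes mbatch_size examples per step and gradients may be accumulated with Lightning's accumulate_grad_batches Trainer option (whether the CXRMate configs expose that option is an assumption):

```python
import math


def effective_batch_size(mbatch_size: int, devices: int, num_nodes: int = 1,
                         accumulate_grad_batches: int = 1) -> int:
    """Number of examples contributing to one optimiser step under data parallelism."""
    return mbatch_size * devices * num_nodes * accumulate_grad_batches


def accumulation_for_target(target: int, mbatch_size: int, devices: int,
                            num_nodes: int = 1) -> int:
    """Smallest accumulate_grad_batches that reaches at least `target` examples per step."""
    per_step = mbatch_size * devices * num_nodes
    return max(1, math.ceil(target / per_step))
```

With memory largely unused, raising mbatch_size (and lowering accumulation accordingly) is usually the first lever for throughput.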

TypeError: forward() got an unexpected keyword argument 'output_attentions'

Hi! Thanks for your contribution.

When I trained with config/train/single_tf.yaml, the following error occurred:

Traceback (most recent call last):
  File "/usr/local/bin/dlhpcstarter", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 126, in main
    submit(args, cmd_line_args, stages_fnc)
  File "/usr/local/lib/python3.8/site-packages/dlhpcstarter/__main__.py", line 21, in submit
    stages_fnc(args)
  File "/home/tools/stages.py", line 85, in stages
    trainer.fit(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 141, in run
    self.on_advance_end(data_fetcher)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 295, in on_advance_end
    self.val_loop.run()
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 410, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 640, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 633, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/home/modules/lightning_modules/single.py", line 455, in validation_step
    output_ids = self.encoder_decoder.generate(
  File "/usr/local/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers/generation/utils.py", line 1597, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  File "/usr/local/lib/python3.8/site-packages/transformers/generation/utils.py", line 523, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'output_attentions'

However, I trained other models, e.g. config/train/multi_tf.yaml, without any problem.
I could not find the cause. I tried switching the Transformers version, but it didn't help.
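One plausible cause, not confirmed here: some Transformers versions add generation flags such as output_attentions and output_hidden_states to the keyword arguments passed to the encoder's forward() inside generate(), so an encoder whose forward() does not accept them raises exactly this TypeError. The plain-Python sketch below mimics that call pattern with hypothetical StrictEncoder/TolerantEncoder classes to show why accepting (and ignoring) the extra keyword arguments avoids the error.

```python
from typing import Any, Dict


class StrictEncoder:
    """Hypothetical encoder whose forward() only knows about pixel_values."""

    def forward(self, pixel_values: Any) -> Dict[str, Any]:
        return {"last_hidden_state": pixel_values}


class TolerantEncoder:
    """Hypothetical encoder that accepts and ignores the extra generation flags."""

    def forward(self, pixel_values: Any, output_attentions: bool = False,
                output_hidden_states: bool = False, return_dict: bool = True,
                **kwargs: Any) -> Dict[str, Any]:
        return {"last_hidden_state": pixel_values}


def call_like_generate(encoder: Any, **encoder_kwargs: Any) -> Dict[str, Any]:
    """Mimic a caller that injects output_attentions/output_hidden_states (assumed behaviour)."""
    encoder_kwargs.setdefault("output_attentions", False)
    encoder_kwargs.setdefault("output_hidden_states", False)
    return encoder.forward(**encoder_kwargs)
```

In the real model, the equivalent change would be to add these parameters (or **kwargs) to the single-image encoder's forward() signature.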

Validation hangs after first validation sample when using deepspeed stage 3 with multi-gpu

Hi, thanks again for this excellent repo. The model trains fine for me on a multi-GPU system using DeepSpeed stage 2, with or without offload, and I see the expected reduction in training/validation time. However, when I train with DeepSpeed stage 3, training_step progresses fine, but the model hangs indefinitely immediately after the validation step starts and one validation sample is completed. By adding logging, I determined that the model generates tokens one by one on the multiple GPUs simultaneously, as expected, but as soon as one GPU finishes generating the sequence for its validation sample, all the other GPUs hang. This makes me suspect the issue is related to synchronization across the GPUs: the other GPUs are not allowed to continue calling the forward function, which generates tokens one by one during validation here, because one GPU has already finished that part. However, I don't understand why this issue would manifest only with stage 3 and not with stage 2.

Has anyone else successfully trained this model, including the validation step, using deepspeed stage 3? I've tried adjusting some of the deepspeed_config parameters, but no success yet. Thanks!
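A likely explanation, offered as a hedge: ZeRO stage 3 partitions the model parameters across GPUs, so every forward pass all-gathers parameters and needs every rank to take part, whereas stage 2 only partitions gradients and optimizer states, so forward passes stay local. During generation, a rank that finishes its sequence early stops calling forward(), and the remaining ranks block in the all-gather. Hugging Face's generate() has a synced_gpus argument for this case: finished ranks keep running dummy forward passes until all ranks are done. generate() only enables it automatically when it detects ZeRO-3 through the Transformers DeepSpeed integration, which a Lightning-managed run may not trigger. The toy simulation below (pure Python, no DeepSpeed) counts forward passes per rank with and without that synchronisation.

```python
from typing import List


def forward_passes_per_rank(seq_lengths: List[int], synced_gpus: bool) -> List[int]:
    """Number of forward() calls each rank makes while generating its sequence.

    Under ZeRO-3 each forward() involves a collective, so all ranks must make the
    same number of calls; otherwise the ranks still generating wait forever.
    """
    if synced_gpus:
        # Finished ranks keep doing dummy steps until the longest sequence is done.
        longest = max(seq_lengths)
        return [longest for _ in seq_lengths]
    return list(seq_lengths)


def would_hang(seq_lengths: List[int], synced_gpus: bool) -> bool:
    """A mismatch in the number of collective calls across ranks means some rank blocks."""
    return len(set(forward_passes_per_rank(seq_lengths, synced_gpus))) > 1
```

In the Lightning module this would correspond to passing synced_gpus=True to self.encoder_decoder.generate(...) when running under stage 3; whether the custom model's generation path honours it is untested.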

[PAD] tokens override text report generation when batch size > 1

Hi, thanks again for this helpful repo. I am implementing this model training code but running into a strange problem. When I set my batch_size to 1 for both TF and SCST, I don't have any problems during training or validation and all the text reports look good. The train step of TF training also works well when batch size >1, but when batch size >1 during validation step, I get text reports that look like this (I am leaving the special tokens in on purpose here):

[BOS][PAD][PAD][PAD][PAD][PAD], catheters and devices: left chest wall pacemaker. lungs: left lung base atelectasis or infiltrate. pleural spaces: left pleural effusion. heart/mediastinum: cardiomegaly. bones/joints: unremarkable.[SEP]left pleural effusion. left lung base atelectasis or infiltrate.[EOS]

These results have the first part of the text report cut off - the beginning should read "Tubes, catheters, and devices". This occurs for roughly half of my samples when batch_size = 2. This also occurs during the training step of SCST (but not TF) if batch size >1, which makes me think it is related to the self.encoder_decoder.generate() function. It looks like it occurs when the prompt lengths of the previous report are different between the 2 samples in the batch, so it adds [PAD] tokens to the shorter one to make them the same length. However these [PAD] tokens do appear to be appropriately masked by the decoder_attention_mask. I don't understand why the [PAD] tokens would override other correct caption words, rather than just have the generated text start after the end of the [PAD] tokens.

Is this a common problem for this model? I see in the paper that the model was trained with a mbatch size of 32 so higher batch sizes must be possible.

As a side question - is there any expected difference in function of the model if the [BOS] token were to come after all those [PAD] tokens in the prompt, rather than before?

Thanks!
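For generation from prompts of different lengths, the Hugging Face documentation recommends left padding for decoder-only style generation (tokenizer.padding_side = 'left'), so every prompt ends at the same position and new tokens follow real tokens rather than padding. Whether this resolves the behaviour reported here is untested. The sketch below only shows left padding of token-id lists with the matching attention mask; the ids 8, 9, 1 and 4 are [PMT], [PMT-SEP], [BOS] and [PAD] from the special-token table printed in the log above, and the other ids are arbitrary placeholders.

```python
from typing import List, Tuple


def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id sequences to equal length and build a 1/0 attention mask."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask
```

With left padding, position ids are typically derived from the attention mask (a cumulative sum minus one) so that the first real token still gets position 0.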

Regarding the selection of Transformer decoder

Hi! Thanks for your contribution. It is an excellent piece of work!

I would like to ask why you chose a randomly initialised Transformer decoder with six layers. Do you have any relevant literature references? I'm very curious about it.

Thank you very much for your time and consideration. I eagerly look forward to your response.

GTPrompt model training problem

Hi! Thanks for your contribution. It is an excellent piece of work!

My task language is Chinese. I have trained the MultiCXR model with my own vocabulary, and I have the following problems when training the GTPrompt model:

I cannot load the multi_ckpt_name: aehrc/cxrmate-multi-tf checkpoint that you trained, because the word embedding dimensions differ. The cxrmate-multi-tf-cn model I trained myself was not saved in the pytorch_model.bin format, so I don't know how to load it.

How should I load the trained MultiCXR model in .ckpt format?

# Load multi checkpoint:
if encoder_decoder_ckpt_name:
    encoder_decoder = AutoModel.from_pretrained(encoder_decoder_ckpt_name, trust_remote_code=True)
    self.load_state_dict(encoder_decoder.state_dict())
else:
    warnings.warn('The encoder-to-decoder model was not warm-started before applying low-rank approximation.')
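A PyTorch Lightning .ckpt file is a pickled dictionary whose weights live under the "state_dict" key, with parameter names prefixed by the LightningModule attribute that holds them. Below is a minimal sketch for extracting the encoder-to-decoder weights from such a file; the "encoder_decoder." prefix is an assumption based on the self.encoder_decoder attribute seen in the tracebacks above, so inspect the actual keys first. strip_prefix and load_lightning_weights are hypothetical helpers.

```python
from typing import Any, Dict, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "encoder_decoder.") -> Dict[str, Any]:
    """Keep only keys starting with `prefix` and remove that prefix."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def load_lightning_weights(ckpt_path: str, prefix: str = "encoder_decoder.") -> Dict[str, Any]:
    """Read a Lightning checkpoint on CPU and return the sub-module's state dict."""
    import torch  # imported lazily so strip_prefix works without PyTorch installed

    # weights_only=False because Lightning checkpoints also hold non-tensor objects;
    # only do this for checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    return strip_prefix(checkpoint["state_dict"], prefix)
```

The result could then be passed to encoder_decoder.load_state_dict(...) on a model built with the same tokenizer and vocabulary size as the one that produced the checkpoint.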

Model migration consultation

Hi! Thanks for your contribution. It is an excellent piece of work!

My task language is Chinese. I have trained a Chinese tokenizer and trained the model from scratch, but I have the following questions:
Can I still use the CheXbert metrics? I am still using monitor: val_report_chexbert_f1_macro for training. Should I change to another monitor?

Thank you very much for your time and consideration. I eagerly look forward to your response.

error when running cxrmate.ipynb

Hi, I am trying to run through the example code in cxrmate.ipynb. When I get to this line:

outputs = encoder_decoder.generate(
    pixel_values=images.to(device),
    decoder_input_ids=prompt['input_ids'],
    special_token_ids=[
        tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index('[PMT-SEP]')
        ],
        tokenizer.bos_token_id,
        tokenizer.sep_token_id,
    ],  
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256 + prompt['input_ids'].shape[1],
    num_beams=4,
)

I get the following error:

Traceback (most recent call last):

  Cell In[11], line 1
    outputs = encoder_decoder.generate(

  File ~\AppData\Local\anaconda3\lib\site-packages\torch\utils\_contextlib.py:115 in decorate_context
    return func(*args, **kwargs)

  File ~\AppData\Local\anaconda3\lib\site-packages\transformers\generation\utils.py:1593 in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(

  File ~\AppData\Local\anaconda3\lib\site-packages\transformers\generation\utils.py:742 in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

  File ~\AppData\Local\anaconda3\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl
    return forward_call(*args, **kwargs)

  File ~\.cache\huggingface\modules\transformers_modules\aehrc\cxrmate\1f014633b98564f21316b32e167b5796381690d8\modelling_longitudinal.py:91 in forward
    return ModelOutputWithProjectionEmbedding(

  File ~\AppData\Local\anaconda3\lib\site-packages\transformers\utils\generic.py:325 in __init__
    raise TypeError(

TypeError: transformers_modules.aehrc.cxrmate.1f014633b98564f21316b32e167b5796381690d8.modelling_longitudinal.ModelOutputWithProjectionEmbedding is not a dataclasss. This is a subclass of ModelOutput and so must use the @dataclass decorator.

How to test the trained model of .ckpt format

Hi! Thanks for your contribution. It is an excellent piece of work!

The trained model is saved as last.ckpt, but it cannot be tested directly with the test command in the repository:

dlhpcstarter -t cxrmate_hf -c config/test_huggingface/longitudinal_gen_prompt_cxr-bert.yaml --stages_module tools.stages --test

What adjustments should I make?

Thank you very much for your time and consideration. I eagerly look forward to your response.

Request: notebook file for training cxrmate

This repo looks very interesting, I would like to try using it to train on a new large XR dataset. However the only training tutorial available is on the main github page, and requires command-line usage of the dlhpcstarter package. Could you please create an .ipynb file showing how to run training within a python IDE, similar to the examples found in cxrmate/examples for how to run inference-only? Thanks!
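Until a dedicated notebook exists, one workaround is to call the same dlhpcstarter command from Python with the standard-library subprocess module. The sketch below builds the argument list from only the flags shown in the README (-t, -c, --stages_module, --train/--test); build_command and run_stage are hypothetical helper names, and the code should be run from the repository root.

```python
import subprocess
from typing import List


def build_command(task: str, config: str, stage: str = "train",
                  stages_module: str = "tools.stages") -> List[str]:
    """Assemble a dlhpcstarter invocation using the flags documented in the README."""
    if stage not in {"train", "test"}:
        raise ValueError("stage must be 'train' or 'test'")
    return ["dlhpcstarter", "-t", task, "-c", config,
            "--stages_module", stages_module, f"--{stage}"]


def run_stage(task: str, config: str, stage: str = "train") -> int:
    """Run the command and echo its combined output line by line; returns the exit code."""
    with subprocess.Popen(build_command(task, config, stage), stdout=subprocess.PIPE,
                          stderr=subprocess.STDOUT, text=True) as proc:
        assert proc.stdout is not None
        for line in proc.stdout:
            print(line, end="")  # print() output appears in the notebook cell
    return proc.returncode
```

Usage: run_stage("cxrmate", "config/train/longitudinal_gt_prompt_tf.yaml", "train"). This still runs training as a separate process, so it does not give interactive access to the model objects.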

Model architecture adjustment problem

Hi! Thanks for your contribution. It is an excellent piece of work!

Your idea is great, and I want to test it on my task. However, my corpus is in Chinese. Do I need to adjust the tokenizer and the pre-trained BERT? Will it work?

Thank you very much for your time and consideration. I eagerly look forward to your response.

Author's name in path causes error in tutorial code

Hi, I tried to download this and run the first example here on GitHub:

dlhpcstarter -t cxrmate_hf -c config/test_huggingface/longitudinal_gen_prompt_cxr-bert.yaml --stages_module tools.stages --test

But I get an error right away that suggests the code author's name is still hardcoded in a path somewhere. Does that need to be fixed? What is the paths.yaml file it's referring to? Thanks!

(base) C:\Users\myusername\Desktop\cxrmate>dlhpcstarter -t cxrmate_hf -c config/test_huggingface/longitudinal_gen_prompt_cxr-bert.yaml --stages_module tools.stages --test
Traceback (most recent call last):
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\myusername\AppData\Local\anaconda3\Scripts\dlhpcstarter.exe\__main__.py", line 7, in <module>
    sys.exit(main())
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\dlhpcstarter\__main__.py", line 49, in main
    load_config_and_update_args(args=args, print_args=True)
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\dlhpcstarter\utils.py", line 79, in load_config_and_update_args
    config = compose(config_name=args.config_name)
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\compose.py", line 38, in compose
    cfg = gh.hydra.compose_config(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\hydra.py", line 594, in compose_config
    cfg = self.config_loader.load_configuration(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\config_loader_impl.py", line 142, in load_configuration
    return self._load_configuration_impl(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\config_loader_impl.py", line 253, in _load_configuration_impl
    defaults_list = create_defaults_list(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 745, in create_defaults_list
    defaults, tree = _create_defaults_list(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 715, in _create_defaults_list
    defaults_tree = _create_defaults_tree(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 356, in _create_defaults_tree
    ret = _create_defaults_tree_impl(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 457, in _create_defaults_tree_impl
    return _expand_virtual_root(repo, root, overrides, skip_missing)
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 280, in _expand_virtual_root
    subtree = _create_defaults_tree_impl(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 573, in _create_defaults_tree_impl
    add_child(children, new_root)
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 520, in add_child
    subtree_ = _create_defaults_tree_impl(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 573, in _create_defaults_tree_impl
    add_child(children, new_root)
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 520, in add_child
    subtree_ = _create_defaults_tree_impl(
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 488, in _create_defaults_tree_impl
    config_not_found_error(repo=repo, tree=root)
  File "C:\Users\myusername\AppData\Local\anaconda3\lib\site-packages\hydra\_internal\defaults_list.py", line 799, in config_not_found_error
    raise MissingConfigException(
hydra.errors.MissingConfigException: In 'single_tf': Could not load '/home/anicolson/config/paths.yaml'.
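The error shows that the `single_tf` configuration refers to `/home/anicolson/config/paths.yaml`. That path looks specific to the author's machine, so it will not resolve on a Windows install. A quick way to find every such reference before editing or overriding it is to scan the YAML configs for home-directory paths. The sketch below assumes the Hydra configs are `*.yaml` files under one directory. The function names and the regular expression are hypothetical and not part of the CXRMate code.

```python
import re
from pathlib import Path

# Hypothetical pattern: any /home/<user> path is likely machine-specific.
HOME_PATH = re.compile(r"/home/[^/\s]+")


def scan_yaml_text(name, text, regex=HOME_PATH):
    """Return (file name, line number, stripped line) for lines matching regex."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        if regex.search(line):
            hits.append((name, lineno, line.strip()))
    return hits


def find_machine_specific_paths(config_dir, regex=HOME_PATH):
    """Scan every *.yaml file below config_dir (assumed Hydra config root)."""
    hits = []
    for path in sorted(Path(config_dir).rglob("*.yaml")):
        hits.extend(scan_yaml_text(path.name, path.read_text(encoding="utf-8"), regex))
    return hits


if __name__ == "__main__":
    # "config" is an assumed location for the repository's Hydra configs.
    for hit in find_machine_specific_paths("config"):
        print(hit)
```

Each hit identifies a line to replace with a path that exists on the local machine.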

How to find the prior images of the current image?

Hi! Thanks for your contribution. It is an excellent piece of work!

I would like to know how the prior images of the current image are identified in this paper.
After examining the MIMIC-CXR-v2 dataset carefully, I noticed that its documentation states: "These study identifiers are completely random, and their order has no implications for the chronological order of the actual studies." In addition, dates in the reports are masked (e.g., "In comparison with the study of ___"). I am therefore unsure how to identify the prior images that correspond to a given current image.
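One common approach, not necessarily the one used in the CXRMate paper, relies on the `StudyDate` and `StudyTime` columns of the MIMIC-CXR-JPG metadata file (`mimic-cxr-2.0.0-metadata.csv`). The dataset documentation describes these dates as de-identified but chronologically consistent within each patient. Sorting a patient's studies by date and time therefore recovers their order, even though the study IDs carry no ordering. The sketch below assumes the rows have already been read, for example with `csv.DictReader`. The column names follow that file, and the function name is hypothetical.

```python
from itertools import groupby


def find_prior_studies(rows):
    """Map each study_id to the same patient's previous study_id (None for a first study).

    Assumes each row has 'subject_id', 'study_id', 'StudyDate' (YYYYMMDD) and
    'StudyTime' (HHMMSS, possibly with fractional seconds), as in the
    MIMIC-CXR-JPG metadata CSV. Several rows (images) can share one study_id.
    """
    # Collapse the images of a study into one timestamp per (patient, study).
    study_time = {}
    for row in rows:
        key = (row["subject_id"], row["study_id"])
        ts = (int(row["StudyDate"]), float(row["StudyTime"]))
        study_time[key] = min(ts, study_time.get(key, ts))

    # Sort by patient, then chronologically. Study IDs themselves carry no order.
    ordered = sorted(study_time, key=lambda key: (key[0], study_time[key]))

    prior = {}
    for _, studies in groupby(ordered, key=lambda key: key[0]):
        previous = None
        for _, study_id in studies:
            prior[study_id] = previous
            previous = study_id
    return prior
```

The images of `prior[study_id]` can then be paired with the images of the current study. A value of `None` marks a patient's first study, for which no previous report exists.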

Thank you very much for your time and consideration. I eagerly look forward to your response.

Unable to reproduce the result in ./examples/cxrmate.ipynb

I downloaded the weights from https://huggingface.co/aehrc/cxrmate and ran the notebooks in ./examples. The output of cxrmate-multi-tf.ipynb matches the notebook, but when I run cxrmate.ipynb, the results differ from those shown in the notebook.
First study result:
Findings: Frontal and lateral views of the chest were obtained. A large bore dual lumen central venous catheter terminates within the right internal jugular central venous catheter terminates within the right atrium. The lungs are unchanged. The heart size and right internal jugular central venous catheter ends in unchanged. The heart size is unchanged. The heart size is unchanged. The heart size is unchanged. The aorta remains mildly tortuous and the upper thoracic aorta remains mildly tortuous and tortuous and hilar contours are unchanged. The cardiac silhouette size is unchanged. The aorta remains mildly tortuous and tortuous and tortuous and tortuous and tortuous and hilar contours are unchanged. The aorta is unchanged. Increased interstitial abnormality is unchanged. The cardiac silhouette size is unchanged. The aorta is unchanged. The cardiac silhouette size is unchanged. There is unchanged. The aorta is unchanged. The aorta is unchanged. The cardiac silhouette size of the osseous structures are within normal. The aorta is unchanged. The cardiac silhouette is unchanged. The aorta is unchanged. The aorta is unchanged with atherosclerotic calcifications are unchanged. The aorta is unchanged. The aorta is tortuous and the osseous structures are within normal. The aorta calcified and tortuous atherosclerotic calcifications are within normal. The aorta is tortuous and tortuous and tortuous and tortuous. The aorta is calcified tortuous atherosclerotic calcifications are diffusely calcified thoracic aorta calcified and tortuous. The aorta
Impression:

Findings: Frontal and lateral views of the chest were obtained. A large bore dual lumen central venous catheter terminates within the right internal jugular central venous catheter terminates within the right atrium. The lungs are unchanged. The heart size and right internal jugular central venous catheter ends in unchanged. The heart size is unchanged. The heart size is unchanged. The heart size is unchanged. The aorta remains mildly tortuous and the upper thoracic aorta remains mildly tortuous and tortuous and hilar contours are unchanged. The cardiac silhouette size is unchanged. The aorta remains mildly tortuous and tortuous and tortuous and tortuous and tortuous and hilar contours are unchanged. The aorta is unchanged. Increased interstitial abnormality is unchanged. The cardiac silhouette size is unchanged. The aorta is unchanged. The cardiac silhouette size is unchanged. There is unchanged. The aorta is unchanged. The aorta is unchanged. The cardiac silhouette size of the osseous structures are within normal. The aorta is unchanged. The cardiac silhouette is unchanged. The aorta is unchanged. The aorta is unchanged with atherosclerotic calcifications are unchanged. The aorta is unchanged. The aorta is tortuous and the osseous structures are within normal. The aorta calcified and tortuous atherosclerotic calcifications are within normal. The aorta is tortuous and tortuous and tortuous and tortuous. The aorta is calcified tortuous atherosclerotic calcifications are diffusely calcified thoracic aorta calcified and tortuous. The aorta
Impression:

Second study result:
Findings: PA and the lungs are within normal. There is moderately tortuous and tortuous and the heart size is moderately tortuous and the lungs are within normal. There is moderately tortuous. There is moderately tortuous and the aortic knob is moderately tortuous. There is moderately tortuous. There is moderately tortuous and tortuous and the aorta is moderately tortuous. There is unchanged. There is a large bore central pulmonary vascularity is moderately tortuous and the aortic knob and tortuous and the aorta is unchanged. There is unchanged. There is unchanged. There is moderately tortuous. There is moderately tortuous. There is moderately tortuous. There is moderately tortuous. The cardiac silhouette is unchanged. Low lung volumes are within normal. The cardiac silhouette is moderately tortuous. There is unchanged. The cardiac silhouette is moderately tortuous. There is moderately tortuous. The cardiac silhouette is moderately tortuous. There is unchanged. There is unchanged. The cardiac silhouette is moderately tortuous and tortuous. Low lung volumes are within normal. There is unchanged. The cardiac silhouette is moderately tortuous. The cardiac silhouette is unchanged. The cardiac silhouette is unchanged. There is moderately tortuous and tortuous. There is unchanged. There is unchanged. There is moderately tortuous. The cardiac silhouette is moderately tortuous and tortuous and the knob calcifications are within normal. There is moderately tortuous and the atherosclerotic calcifications are within normal.
Impression:

Findings: The heart remains moderately enlarged. Low lung volumes are low. Low lung volumes are low. Low lung volumes are low. Low lung volumes are relatively low. Low lung volumes are low. Low lung volumes are low. Low lung volumes are low. Low lung volumes are low. Low lung volumes are low, and there is low, and there is low, and there is grossly clear. Low lung volumes are low, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, and there is low, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however, however,
Impression:

Why might this be? Also, when I run pip install -r requirements.txt, transformers 4.42.1 is installed, and with this version the following error is raised:
Traceback (most recent call last):
  File "/tmp/cxrmate/examples/cxrmate-multi-tf.py", line 59, in
    outputs = encoder_decoder.generate(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/generation/utils.py", line 1597, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  File "/root/miniconda3/lib/python3.8/site-packages/transformers/generation/utils.py", line 523, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'output_attentions'
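Judging from the traceback, the newer transformers release passes `output_attentions` to the encoder during `generate()`, and the encoder's `forward()` does not declare that argument. The plain-Python sketch below, which needs no torch, reproduces the failure with hypothetical stand-in classes. It then shows two generic workarounds: accepting `**kwargs`, or filtering arguments by signature. It illustrates the mechanism only and is not a patch to the CXRMate code. Pinning transformers to a version that works is the less invasive fix.

```python
import inspect


class StrictEncoder:
    """Hypothetical encoder whose forward() lists its arguments explicitly."""

    def forward(self, pixel_values):
        return {"last_hidden_state": pixel_values}


class TolerantEncoder:
    """Same encoder, but forward() also accepts (and ignores) extra keyword arguments."""

    def forward(self, pixel_values, **kwargs):
        return {"last_hidden_state": pixel_values}


def supported_kwargs(fn, kwargs):
    """Keep only the keyword arguments that fn's signature can accept."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {name: value for name, value in kwargs.items() if name in params}


# What a generate() caller might pass to the encoder, per the traceback above.
encoder_kwargs = {"pixel_values": [0.1, 0.2], "output_attentions": False}

try:
    StrictEncoder().forward(**encoder_kwargs)
except TypeError as err:
    print("strict encoder:", err)  # unexpected keyword argument 'output_attentions'

print(TolerantEncoder().forward(**encoder_kwargs))
strict = StrictEncoder()
print(strict.forward(**supported_kwargs(strict.forward, encoder_kwargs)))
```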

After downgrading transformers to 4.30.1, the error no longer occurs.
Could you share the torch and transformers versions you used? Installing directly from requirements.txt pulls the latest versions.
I hope to receive your reply! Thank you very much!
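Until the tested versions are pinned in requirements.txt, a small runtime check can flag a mismatch before `generate()` fails. The sketch below treats 4.30.1 as the known-good version only because this report says it works. That value is an assumption, not an official pin, and the helper names are hypothetical. It uses only the standard library (`importlib.metadata`).

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Version reported to work in this issue (an assumption, not an official pin).
KNOWN_GOOD = "4.30.1"


def parse_version(v):
    """'4.42.1' -> (4, 42, 1); keeps the leading digits of the first three parts."""
    nums = []
    for part in v.split(".")[:3]:
        match = re.match(r"\d+", part)
        nums.append(int(match.group()) if match else 0)
    while len(nums) < 3:
        nums.append(0)
    return tuple(nums)


def check_package(name="transformers", expected=KNOWN_GOOD):
    """Return a warning string if the installed version differs from expected, else None."""
    try:
        installed = version(name)
    except PackageNotFoundError:
        return f"{name} is not installed"
    if parse_version(installed) != parse_version(expected):
        return f"{name} {installed} is installed; {expected} is reported to work with these examples"
    return None


if __name__ == "__main__":
    print(check_package() or "transformers version matches")
```

If the check reports a mismatch, `pip install transformers==4.30.1` reinstalls the version that this report found to work.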

Printing out accuracy metrics during training

I'm trying to reproduce the training on a subset of the MIMIC-CXR-JPG dataset, and my output during training looks like this:

Epoch 10: 100%
PTBTokenizer tokenized 2958 tokens at 43535.30 tokens per second.
PTBTokenizer tokenized 4446 tokens at 62878.10 tokens per second.
{'testlen': 2425, 'reflen': 3839, 'guess': [2425, 2365, 2305, 2245], 'correct': [1180, 456, 223, 112]}
ratio: 0.6316749153423726

Since the CheXbert F1 metric decides which checkpoint is saved as the final model during training, is there a way to print this F1 score after each epoch? That would show whether the model is still improving, or whether many epochs have passed without any gain in validation F1. It could also be useful to print other validation metrics, such as BLEU and METEOR, after each epoch.
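This post does not say which framework drives the training loop. If it is PyTorch Lightning (an assumption), calling `self.log("val_chexbert_f1", value, prog_bar=True)` in the validation hook would surface the metric every epoch. The metric name there is hypothetical. A framework-agnostic alternative is a small tracker that prints the metric each epoch along with the number of epochs since the last improvement. The class below is a hypothetical helper, not part of the CXRMate code.

```python
class ValidationTracker:
    """Hypothetical helper: records a validation metric per epoch and reports progress."""

    def __init__(self, metric_name="chexbert_f1", mode="max"):
        self.metric_name = metric_name
        self.mode = mode  # "max" if higher is better, "min" otherwise
        self.best = None
        self.best_epoch = None
        self.history = []

    def update(self, epoch, value):
        """Record value for epoch, print a summary line, and return True if it improved."""
        self.history.append((epoch, value))
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best
        else:
            improved = value < self.best
        if improved:
            self.best, self.best_epoch = value, epoch
        stale = epoch - self.best_epoch
        print(
            f"epoch {epoch}: {self.metric_name}={value:.4f} "
            f"(best {self.best:.4f} at epoch {self.best_epoch}, "
            f"{stale} epoch(s) without improvement)"
        )
        return improved
```

One tracker per metric (CheXbert F1, BLEU-4, METEOR) keeps the checkpoint-selection metric separate from the metrics that are only reported.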

Also, what is the "ratio" currently being printed out by the logs here? Thanks!
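Judging by their format, these lines appear to come from the pycocoevalcap BLEU scorer; that attribution is an assumption. In that format, `testlen` and `reflen` are the total hypothesis and reference lengths in tokens. `guess` and `correct` are the candidate and matched n-gram counts for n = 1 to 4. `ratio` is testlen / reflen (2425 / 3839 ≈ 0.632), which feeds BLEU's brevity penalty. The sketch below recomputes BLEU-4 from the logged counts with the standard formula. pycocoevalcap adds tiny smoothing constants, so its value may differ in the last digits.

```python
import math


def bleu_from_counts(testlen, reflen, guess, correct, max_n=4):
    """Corpus BLEU-n from aggregate n-gram counts, as printed in the training log.

    testlen / reflen: total hypothesis / reference lengths in tokens.
    guess[k] / correct[k]: candidate and matched (k+1)-gram counts (correct[k] > 0 assumed).
    Returns (ratio, bleu).
    """
    ratio = testlen / reflen  # the "ratio" printed in the log
    # Brevity penalty: exp(1 - reflen / testlen) when hypotheses are shorter than references.
    bp = 1.0 if ratio >= 1.0 else math.exp(1.0 - 1.0 / ratio)
    # Geometric mean of the modified n-gram precisions.
    log_precisions = [math.log(correct[k] / guess[k]) for k in range(max_n)]
    return ratio, bp * math.exp(sum(log_precisions) / max_n)


stats = {"testlen": 2425, "reflen": 3839,
         "guess": [2425, 2365, 2305, 2245], "correct": [1180, 456, 223, 112]}
ratio, bleu4 = bleu_from_counts(**stats)
print(f"ratio={ratio:.4f}, BLEU-4 ~ {bleu4:.4f}")
```

A ratio well below 1, as here, means the generated reports are shorter than the references, so the brevity penalty lowers BLEU noticeably.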
