NIH-Chest-X-rays-Multi-Label-Image-Classification-In-Pytorch

Multi-Label Image Classification of Chest X-Rays in PyTorch

Requirements

  • torch >= 0.4
  • torchvision >= 0.2.2
  • opencv-python
  • numpy >= 1.7.3
  • matplotlib
  • tqdm

Dataset

The NIH Chest X-ray Dataset is used for multi-label disease classification of chest X-rays. There are 15 classes in total: 14 diseases and one "No findings" class. Each image is labelled either "No findings" or with one or more of the following disease classes:

  • Atelectasis
  • Consolidation
  • Infiltration
  • Pneumothorax
  • Edema
  • Emphysema
  • Fibrosis
  • Effusion
  • Pneumonia
  • Pleural_thickening
  • Cardiomegaly
  • Nodule
  • Mass
  • Hernia
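As a rough illustration of what "multi-label" means here, the sketch below turns a label string such as `Cardiomegaly|Edema|Effusion` into a 15-element multi-hot vector. The class order, the `|` separator, the "No Finding" spelling, and the helper name `encode_labels` are assumptions made for illustration, not taken from the repository's code.

```python
# Class order is an assumption for illustration; the repository may use another.
CLASSES = [
    "Atelectasis", "Consolidation", "Infiltration", "Pneumothorax", "Edema",
    "Emphysema", "Fibrosis", "Effusion", "Pneumonia", "Pleural_thickening",
    "Cardiomegaly", "Nodule", "Mass", "Hernia", "No Finding",
]


def encode_labels(label_string, classes=CLASSES, sep="|"):
    """Return a multi-hot list with 1.0 at the index of every listed class."""
    names = {name.strip().lower() for name in label_string.split(sep)}
    unknown = names - {c.lower() for c in classes}
    if unknown:
        raise ValueError(f"unknown labels: {sorted(unknown)}")
    return [1.0 if c.lower() in names else 0.0 for c in classes]


# Three diseases on one image give three ones in the target vector.
print(encode_labels("Cardiomegaly | Edema | Effusion"))
```

A vector like this is the kind of target a per-class sigmoid output can be trained against, which is why the losses listed further down are binary losses applied to each class independently.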

There are 112,120 X-ray images of size 1024x1024 pixels, of which 86,524 are used for training and 25,596 for testing.

Sample X-Ray Images

Atelectasis
Cardiomegaly | Edema | Effusion
No Finding

Model

A pretrained ResNet50 model is used for transfer learning on this dataset.

Loss Function

Two loss functions are available:

  • Focal Loss (default)
  • Binary Cross Entropy (BCE) Loss
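For readers unfamiliar with focal loss, here is a hedged sketch of a binary (per-class) focal loss built on top of BCE, without the optional class-balancing factor. The function name and the `gamma=2.0` default are assumptions; the repository's implementation and hyperparameters may differ.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, gamma=2.0):
    """Binary focal loss averaged over every (sample, class) entry.

    gamma=2.0 is an assumed default; gamma=0 reduces to plain BCE.
    """
    # Per-element BCE equals -log(p_t), where p_t is the probability
    # the model assigns to the true label of that element.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)
    # Down-weight well-classified entries (p_t close to 1).
    return ((1.0 - p_t) ** gamma * bce).mean()


logits = torch.tensor([[2.0, -1.0, 0.5], [-3.0, 0.0, 1.5]])
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
print(focal_loss(logits, targets).item())
print(F.binary_cross_entropy_with_logits(logits, targets).item())
```

Because the modulating factor `(1 - p_t) ** gamma` never exceeds 1, the focal value is never larger than the BCE value on the same inputs; the easy entries contribute less, so rare positive findings weigh more in the gradient.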

Training

  • From Scratch

    The following layers are set to trainable:

    • layer2
    • layer3
    • layer4
    • fc

    Terminal Code:

    python main.py
    
  • Resuming From a Saved Checkpoint

    A saved checkpoint must be loaded. It is a dictionary containing:

    • epochs (number of epochs the model has been trained till that time)

    • model (architecture and the learnt weights of the model)

    • lr_scheduler_state_dict (state_dict of the lr_scheduler)

    • losses_dict (a dictionary containing the following losses)

      • mean train epoch losses for all the epochs
      • mean val epoch losses for all the epochs
      • batch train loss for all the training batches
      • batch val loss for all the val batches
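The checkpoint layout described above can be sketched roughly as follows. The top-level keys come from the list; the key names inside `losses_dict`, the stand-in `nn.Linear` model, and the SGD/StepLR choice are hypothetical and only serve to make the example runnable.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 15)  # stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

checkpoint = {
    "epochs": 2,  # epochs trained so far
    "model": model,  # whole module: architecture plus learnt weights
    "lr_scheduler_state_dict": scheduler.state_dict(),
    "losses_dict": {  # inner key names are hypothetical
        "mean_train_epoch_losses": [],
        "mean_val_epoch_losses": [],
        "batch_train_losses": [],
        "batch_val_losses": [],
    },
}

# An in-memory buffer keeps the sketch free of side effects; a file path
# such as "checkpoint_file.pth" works the same way.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# weights_only=False is needed on recent PyTorch to unpickle a full module;
# drop the argument on old releases that do not accept it.
restored = torch.load(buffer, weights_only=False)
print(sorted(restored.keys()))
```

Storing the whole module is convenient for resuming, but it ties the file to the class definitions available at load time; saving `model.state_dict()` instead is a common alternative.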

Different layers of the model are frozen or unfrozen in different stages, defined at the end of this README.md file, to fit the model well to the data. The stage can be passed from the terminal with the argument --stage STAGE.

Terminal Code:

python main.py --resume --ckpt checkpoint_file.pth --stage 2
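One way the per-stage freezing could work is sketched below. The stage-to-layer mapping is copied from the stage list at the end of this README; the helper name `set_trainable_layers` and the toy stand-in model are assumptions for illustration only.

```python
import torch.nn as nn

# Trainable layers per stage, copied from the stage list in the Result section.
STAGE_LAYERS = {
    1: ["layer2", "layer3", "layer4", "fc"],
    2: ["layer3", "layer4", "fc"],
    3: ["layer4", "fc"],
    4: ["fc"],
}


def set_trainable_layers(model, stage):
    """Unfreeze the top-level children named for `stage`; freeze the rest."""
    trainable = set(STAGE_LAYERS[stage])
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = name in trainable
    # Only these parameters need to be handed to the optimizer.
    return [p for p in model.parameters() if p.requires_grad]


# Tiny stand-in using top-level child names found in torchvision's ResNet.
toy = nn.Sequential()
for name in ["conv1", "layer1", "layer2", "layer3", "layer4", "fc"]:
    toy.add_module(name, nn.Linear(4, 4))

params = set_trainable_layers(toy, stage=3)
print(len(params))  # 4: weight and bias of layer4 and of fc
```

The same function should apply to a torchvision ResNet50, since its top-level children include `layer1` to `layer4` and `fc`.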

Training creates a models directory and saves the checkpoints there.

Testing

A saved checkpoint must be loaded with the --ckpt argument, and the --test argument must be passed to activate test mode.

Terminal Code:

python main.py --test --ckpt checkpoint_file.pth
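The command-line flags used in this README (`--resume`, `--ckpt`, `--stage`, `--test`) could be parsed along these lines. This is only a sketch with `argparse`: the defaults, help strings, and the check that `--ckpt` accompanies `--resume` or `--test` are assumptions, not a description of what `main.py` actually does.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(
        description="Train or test the chest X-ray classifier"
    )
    parser.add_argument("--resume", action="store_true",
                        help="continue training from a checkpoint")
    parser.add_argument("--test", action="store_true",
                        help="run in test mode")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="path to a saved checkpoint")
    parser.add_argument("--stage", type=int, default=1,
                        help="training stage, 1 to 4 (default is assumed)")
    args = parser.parse_args(argv)
    # Both resuming and testing need weights to start from.
    if (args.resume or args.test) and args.ckpt is None:
        parser.error("--ckpt is required with --resume or --test")
    return args


args = parse_args(["--test", "--ckpt", "checkpoint_file.pth"])
print(args.test, args.ckpt, args.stage)
```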

Result

The model achieved an average ROC AUC score of 0.73241 over all classes (excluding the "No findings" class) after training in the following stages:

STAGE 1

  • Loss Function: FocalLoss
  • lr: 1e-5
  • Training Layers: layer2, layer3, layer4, fc
  • Epochs: 2

STAGE 2

  • Loss Function: FocalLoss
  • lr: 3e-4
  • Training Layers: layer3, layer4, fc
  • Epochs: 1

STAGE 3

  • Loss Function: FocalLoss
  • lr: 1e-3
  • Training Layers: layer4, fc
  • Epochs: 3

STAGE 4

  • Loss Function: FocalLoss
  • lr: 1e-3
  • Training Layers: fc
  • Epochs: 2
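To make the reported metric concrete, the sketch below computes a per-class ROC AUC and averages it over all classes except an excluded one, as the Result section describes for "No findings". It uses plain NumPy and toy data; the function names and the pairwise formulation are illustrative assumptions, and the repository may compute the score differently (for example with a library routine).

```python
import numpy as np


def roc_auc(y_true, y_score):
    """AUC as the chance that a positive outscores a negative (ties count half)."""
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score, dtype=float)
    pos, neg = y_score[y_true], y_score[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")  # undefined when only one label value is present
    # Pairwise comparison: fine for a sketch, memory-hungry on large test sets.
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


def mean_auc(y_true, y_score, exclude=()):
    """Average per-class AUC over all columns except those in `exclude`."""
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    cols = [c for c in range(y_true.shape[1]) if c not in exclude]
    return float(np.nanmean([roc_auc(y_true[:, c], y_score[:, c]) for c in cols]))


# Toy data: 4 images, 3 classes; the last column plays the role of "No findings".
y_true = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
y_score = np.array([[0.9, 0.2, 0.1], [0.3, 0.8, 0.2],
                    [0.6, 0.4, 0.3], [0.1, 0.5, 0.7]])
print(mean_auc(y_true, y_score, exclude={2}))  # 0.875
```

In the toy data the first class is ranked perfectly (AUC 1.0) and the second has one misordered pair out of four (AUC 0.75), giving the mean of 0.875.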
