cmpt-762-assignment-2-code's Introduction

CIFAR-100 Classification with DenseNet

Project requirements: https://docs.google.com/document/d/1ZUdQ7c2X7y_KWobwZG0y0uFpwe__K1Hs1I_rv5jR3ws/edit?usp=sharing

Prerequisites

  • Python 3.8 or later
    • You may remove typing.Final type hints to use this project with earlier Python versions
  • Install all requirements in requirements.txt
    • Using pip: pip install -r requirements.txt
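The note about typing.Final can be illustrated with a minimal sketch. The constant name NUM_CLASSES is hypothetical and is not taken from constant.py; the point is only that the Final hint can be dropped without changing the value on interpreters older than 3.8.

```python
import sys

# Hypothetical constant for illustration; the real names live in constant.py.
if sys.version_info >= (3, 8):
    from typing import Final  # typing.Final was added in Python 3.8

    NUM_CLASSES: Final[int] = 100  # CIFAR-100 has 100 classes
else:
    # Earlier Python versions: same value, just without the Final hint.
    NUM_CLASSES = 100
```

The hint only informs static type checkers; it has no effect at runtime, so removing it does not change behaviour.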

Directory Structure

  • constant.py: Constants
  • dataset.py: The CIFAR_SFU dataset, copied from the handout notebook, unmodified
  • densenet.py: The DenseNet model
  • epoch_logger.py: Logs the train/validation losses/accuracies into 2 files: train_log.csv and validation_log.csv
  • epoch_visualizer.py: Generates the training loss and validation accuracy plot
    • Copied from the handout notebook
    • Slightly modified to show fewer x-axis labels (drawing 400 labels for 400 epochs is impractical)
  • infer.py: Generates the Kaggle submission csv using the test set
  • train.py: The training script
  • Files with filenames ending with _test are Python unit tests
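As a rough illustration of what a per-epoch CSV logger such as epoch_logger.py might do, here is a minimal sketch using only the standard library. The function name append_epoch and the column names are assumptions for illustration; the actual log format is defined by the script itself.

```python
import csv
import os
import tempfile


def append_epoch(filename, epoch, loss, accuracy):
    """Append one epoch's metrics, writing a header row first if the file is new."""
    is_new = not os.path.exists(filename) or os.path.getsize(filename) == 0
    with open(filename, "a", newline="") as f:
        writer = csv.writer(f)
        if is_new:
            # Hypothetical column names, not taken from the project.
            writer.writerow(["epoch", "loss", "accuracy"])
        writer.writerow([epoch, loss, accuracy])


# Usage: log two epochs into a temporary directory and read them back.
log_path = os.path.join(tempfile.mkdtemp(), "train_log.csv")
append_epoch(log_path, 1, 4.21, 0.05)
append_epoch(log_path, 2, 3.87, 0.11)
with open(log_path, newline="") as f:
    rows = list(csv.reader(f))
```

Appending one row per epoch keeps the log usable even if training is interrupted partway through.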

Most scripts use the argparse module to parse their command-line arguments. Run a script with -h to list all available arguments.

Train

Run the train.py script with arguments specifying where to save the checkpoints and logs.

python train.py --checkpoint_save_dir=checkpoints --train_log_filename=checkpoints/train_log.csv --validation_log_filename=checkpoints/validation_log.csv
  • Network weights are saved in checkpoint_save_dir
  • Training and validation logs (accuracies, losses, etc.) are saved in train_log_filename and validation_log_filename respectively
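The three flags in the command above could be declared with argparse roughly as follows. This is a sketch, not the parser in train.py: the help strings and default values are assumptions, and the real script may define more arguments (use -h to check).

```python
import argparse


def build_parser():
    """Build a parser with the three flags shown in the training command."""
    parser = argparse.ArgumentParser(description="Training script arguments (sketch)")
    # Defaults below are assumed for illustration only.
    parser.add_argument("--checkpoint_save_dir", type=str, default="checkpoints",
                        help="Directory where network weights are saved")
    parser.add_argument("--train_log_filename", type=str,
                        default="checkpoints/train_log.csv",
                        help="CSV file for training losses and accuracies")
    parser.add_argument("--validation_log_filename", type=str,
                        default="checkpoints/validation_log.csv",
                        help="CSV file for validation losses and accuracies")
    return parser


# Parse the same arguments as the example command line.
args = build_parser().parse_args([
    "--checkpoint_save_dir=checkpoints",
    "--train_log_filename=checkpoints/train_log.csv",
    "--validation_log_filename=checkpoints/validation_log.csv",
])
```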

Training takes about 4 hours on a machine with an Intel i7-9700K CPU and an NVIDIA RTX 2080 GPU.

Visualize Training Loss and Validation Accuracy

The epoch_visualizer.py script reads checkpoints/train_log.csv and checkpoints/validation_log.csv, plots the training losses and validation accuracies, and saves the plot as checkpoints/plot.png.
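One way to read a metric column from such a log and to keep the x-axis readable for long runs is sketched below. The helper names read_column and tick_positions, the column name "loss", and the limit of 10 labels are all assumptions; the actual script may thin its labels differently.

```python
import csv
import io


def read_column(csv_file, column):
    """Read one numeric column from an open CSV file that has a header row."""
    return [float(row[column]) for row in csv.DictReader(csv_file)]


def tick_positions(num_epochs, max_labels=10):
    """Return evenly spaced 1-based epoch numbers, at most max_labels of them."""
    step = max(1, -(-num_epochs // max_labels))  # ceiling division
    return list(range(step, num_epochs + 1, step))


# Usage with an in-memory log; the column names are hypothetical.
sample_log = io.StringIO("epoch,loss\n1,4.2\n2,3.9\n")
losses = read_column(sample_log, "loss")
ticks = tick_positions(400)  # 40, 80, ..., 400
```

The resulting positions could then be passed to a plotting library as the x-axis tick locations instead of labelling every epoch.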

Infer

Run the infer.py script with arguments specifying the checkpoint location and prediction csv location.

python infer.py --checkpoint_filename=../0.4-128batch/400.pth --csv_filename=predictions.csv
  • Loads network weights from checkpoint_filename
  • Saves test set predictions in csv_filename
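Writing the prediction file can be sketched with the standard csv module. The function name write_predictions and the header "id,label" are assumptions, since the required Kaggle submission format is not stated here, and the model inference step is omitted entirely.

```python
import csv
import io


def write_predictions(csv_file, predictions):
    """Write (index, predicted class) rows to an open file-like object."""
    writer = csv.writer(csv_file)
    writer.writerow(["id", "label"])  # hypothetical header, adjust to the competition
    for index, label in enumerate(predictions):
        writer.writerow([index, label])


# Usage with an in-memory buffer and made-up class indices.
buffer = io.StringIO()
write_predictions(buffer, [17, 3, 99])
lines = buffer.getvalue().splitlines()
```

When writing to a real file, open it with newline="" so the csv module controls the line endings.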

Pre-trained Network Weights

Pre-trained network weights are available at: https://github.com/MacJim/CMPT-762-Assignment-2-Checkpoints

The 0.4-128batch/400.pth weights achieve the best test accuracy, 0.76500 (runner-up in our Kaggle competition).
