
DiffAD

[ICCV2023] Unsupervised Surface Anomaly Detection with Diffusion Probabilistic Model

@inproceedings{zhang2023unsupervised,
  title={Unsupervised Surface Anomaly Detection with Diffusion Probabilistic Model},
  author={Zhang, Xinyi and Li, Naiqi and Li, Jiawei and Dai, Tao and Jiang, Yong and Xia, Shu-Tao},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={6782--6791},
  year={2023}
}

Method overview

[Figure: overview of the DiffAD method]

Installation

conda env create -f environment.yaml
conda activate DiffAD

Dataset

Following DRAEM, we use the MVTec AD and DTD datasets. Run the download_dataset.sh script from the project directory to download both datasets into the project's datasets folder:

./scripts/download_dataset.sh
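Before training, it can help to confirm that the downloaded data ended up where the commands below expect it. The following is a minimal sketch, assuming the standard MVTec AD archive layout (`<category>/train/good`, `<category>/test`, `<category>/ground_truth`) and the `dtd/images` folder that the training command points at; `missing_dataset_dirs` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# The 15 MVTec AD categories; the folder layout checked below is an assumption
# based on the standard MVTec AD archive and the DTD "images" directory.
MVTEC_CATEGORIES = (
    "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather",
    "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper",
)


def missing_dataset_dirs(root, categories=MVTEC_CATEGORIES):
    """Return the expected sub-directories (relative to `root`) that do not exist."""
    root = Path(root)
    expected = [root / "dtd" / "images"]
    for name in categories:
        category_dir = root / "mvtec" / name
        expected += [
            category_dir / "train" / "good",
            category_dir / "test",
            category_dir / "ground_truth",
        ]
    return [p.relative_to(root).as_posix() for p in expected if not p.is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs("./datasets")
    print("dataset layout OK" if not missing else f"missing: {missing}")
```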

Training

Reconstruction sub-network

The reconstruction sub-network is based on the latent diffusion model (LDM).
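In an LDM, images are first encoded into a latent z_0 by the auto-encoder, and the diffusion model is trained in that latent space to predict the noise added by the forward process q(z_t | z_0). The sketch below shows that objective in NumPy under stated assumptions: it uses DDPM's linear beta schedule (1e-4 to 0.02 over 1000 steps), whereas LDM-style configs set their own values (keys such as `linear_start` and `linear_end`), so DiffAD's `configs/mvtec.yaml` may differ. The zero "denoiser" is a placeholder for the actual U-Net.

```python
import numpy as np


def make_alphas_cumprod(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products abar_t of (1 - beta_t) for a linear beta schedule (assumed)."""
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def q_sample(z0, t, alphas_cumprod, noise):
    """Draw z_t ~ q(z_t | z_0) = N(sqrt(abar_t) * z_0, (1 - abar_t) * I)."""
    abar = alphas_cumprod[t]
    return np.sqrt(abar) * z0 + np.sqrt(1.0 - abar) * noise


def eps_loss(denoiser, z0, t, alphas_cumprod, rng):
    """Simplified DDPM objective: MSE between the true and the predicted noise."""
    noise = rng.standard_normal(z0.shape)
    zt = q_sample(z0, t, alphas_cumprod, noise)
    return float(np.mean((noise - denoiser(zt, t)) ** 2))


def zero_denoiser(zt, t):
    """Placeholder network that always predicts zero noise."""
    return np.zeros_like(zt)


rng = np.random.default_rng(0)
abar = make_alphas_cumprod()
z0 = rng.standard_normal((4, 32, 32))  # toy latent (channels, height, width)
print(eps_loss(zero_denoiser, z0, t=500, alphas_cumprod=abar, rng=rng))
```

A trained network drives this loss well below the value the zero predictor gets, which is roughly the variance of the noise.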

Training Auto-encoder

cd rec_network
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/kl.yaml -t --gpus 0,  
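The config name `configs/kl.yaml` suggests a KL-regularized auto-encoder like LDM's `AutoencoderKL`, whose encoder outputs the mean and log-variance of a diagonal Gaussian over the latent. The sketch below shows only the sampling and KL terms under that assumption; LDM's full auto-encoder loss also includes perceptual and adversarial terms, which are omitted, and the `kl_weight` mentioned in the comment is a hypothetical value.

```python
import numpy as np


def sample_latent(mean, logvar, rng):
    """Reparameterised sample z = mu + sigma * eps from the diagonal Gaussian posterior."""
    std = np.exp(0.5 * logvar)
    return mean + std * rng.standard_normal(mean.shape)


def kl_to_standard_normal(mean, logvar):
    """KL(N(mu, diag(sigma^2)) || N(0, I)) per sample, summed over all non-batch axes."""
    axes = tuple(range(1, mean.ndim))
    return 0.5 * np.sum(mean ** 2 + np.exp(logvar) - 1.0 - logvar, axis=axes)


rng = np.random.default_rng(0)
mean = 0.1 * rng.standard_normal((2, 4, 32, 32))  # toy encoder output (batch, C, H, W)
logvar = np.full_like(mean, -2.0)
z = sample_latent(mean, logvar, rng)
kl = kl_to_standard_normal(mean, logvar)
# A training step would minimise reconstruction_loss + kl_weight * kl.mean(),
# with a small kl_weight (hypothetical, e.g. 1e-6) so the latent stays expressive.
print(z.shape, kl)
```

Keeping the KL weight small regularizes the latent only lightly, so the diffusion model in the next step operates on a well-scaled but still detailed representation.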

Training LDMs

CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/mvtec.yaml -t --gpus 0, --max_epochs 4000
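In LDM-style `main.py` scripts, trainer options such as `--gpus` and `--max_epochs` are registered as double-dash flags taken from the PyTorch Lightning `Trainer`. The hypothetical parser below (not the repository's `main.py`) illustrates two details we believe apply: in Lightning 1.x a `--gpus` value containing a comma (`0,`) is read as a list of device indices, while a bare `0` means zero GPUs; and a single-dash spelling such as `-max_epochs` is not recognized at all, so `parse_known_args` leaves it unparsed.

```python
import argparse


def gpus_arg(value):
    # Modeled on PyTorch Lightning 1.x: a value containing a comma is kept as a
    # string listing device indices; anything else is parsed as a GPU count.
    return value if "," in value else int(value)


def device_ids(gpus):
    """Turn a --gpus value into the list of CUDA device indices it selects."""
    if isinstance(gpus, str):
        return [int(x) for x in gpus.split(",") if x.strip()]
    return list(range(gpus))


# Hypothetical stand-in for the flags the real script inherits from the Trainer.
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=gpus_arg, default=None)
parser.add_argument("--max_epochs", type=int, default=None)

ok, _ = parser.parse_known_args(["--gpus", "0,", "--max_epochs", "4000"])
bad, leftover = parser.parse_known_args(["--gpus", "0", "-max_epochs", "4000,"])
print(device_ids(ok.gpus), ok.max_epochs)  # GPU index 0, 4000 epochs
print(device_ids(bad.gpus), bad.max_epochs, leftover)  # no GPUs, epochs unset, flag left over
```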

Discriminative sub-network

cd seg_network
CUDA_VISIBLE_DEVICES=<GPU_ID> python train.py --gpu_id 0 --lr 0.001 --bs 32 --epochs 700 --data_path ./datasets/mvtec/ --anomaly_source_path ./datasets/dtd/images/ --checkpoint_path ./checkpoints/obj_name --log_path ./logs/
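The `--anomaly_source_path` argument points at DTD textures, which suggests that, as in DRAEM, the discriminative sub-network is trained on synthetic anomalies: a texture is blended into random regions of a normal image, and the region mask becomes the segmentation target. The sketch below assumes that recipe. It swaps DRAEM's thresholded Perlin noise for simpler upsampled random noise, omits DRAEM's texture augmentations, and uses hypothetical function names.

```python
import numpy as np


def blob_mask(height, width, cell=16, threshold=0.5, rng=None):
    """Binary mask of random blobs from upsampled coarse noise (stand-in for Perlin noise)."""
    rng = np.random.default_rng() if rng is None else rng
    coarse = rng.random((height // cell + 1, width // cell + 1))
    # Nearest-neighbour upsampling of the coarse grid, cropped to the image size.
    fine = np.kron(coarse, np.ones((cell, cell)))[:height, :width]
    return (fine > threshold).astype(np.float32)


def synthesize_anomaly(image, texture, max_beta=0.8, rng=None):
    """Blend a texture into random blobs of a normal image (DRAEM-style, assumed).

    image, texture: float arrays of shape (H, W, C) in [0, 1].
    Returns the augmented image and its (H, W) ground-truth mask.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    mask = blob_mask(h, w, rng=rng)[..., None]
    beta = rng.uniform(0.0, max_beta)  # opacity of the original image inside the mask
    augmented = (1 - mask) * image + mask * ((1 - beta) * texture + beta * image)
    return augmented.astype(np.float32), mask[..., 0]


rng = np.random.default_rng(0)
normal = rng.random((256, 256, 3))
texture = rng.random((256, 256, 3))  # stands in for a resized DTD image
aug, gt_mask = synthesize_anomaly(normal, texture, rng=rng)
print(aug.shape, gt_mask.mean())  # fraction of pixels marked anomalous
```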

Evaluating

Reconstruction performance

After training the reconstruction sub-network, you can test its reconstruction performance on anomalous inputs:

python scripts/mvtec.py

For some samples with severe deformations, such as missing transistors, you can add some noise to the anomalous conditions to adjust the sampling.
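One way to picture "adding noise to the anomalous condition" is a variance-preserving blend of the conditioning latent with Gaussian noise: the weaker the condition, the more the sampler relies on what it learned from normal samples, which can help fill in large missing parts. This is only a sketch under that assumption; we do not know how `scripts/mvtec.py` actually injects the noise, and `perturb_condition` and `strength` are hypothetical names.

```python
import numpy as np


def perturb_condition(cond, strength, rng):
    """Blend the conditioning latent with Gaussian noise (variance-preserving mix).

    strength = 0 returns the condition unchanged; strength = 1 replaces it with noise.
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be in [0, 1]")
    noise = rng.standard_normal(cond.shape)
    return np.sqrt(1.0 - strength) * cond + np.sqrt(strength) * noise


rng = np.random.default_rng(0)
cond = rng.standard_normal((4, 32, 32))  # toy latent of an anomalous input
weaker = perturb_condition(cond, strength=0.3, rng=rng)
print(np.corrcoef(cond.ravel(), weaker.ravel())[0, 1])  # lower means a weaker condition
```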

Anomaly segmentation

cd seg_network
python test.py --gpu_id 0 --base_model_name "seg_network" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/obj_name/
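Results on MVTec AD are commonly reported as image-level and pixel-level AUROC. The sketch below shows what these numbers measure, assuming one anomaly map per test image and binary ground-truth masks; `test.py` may compute them with its own code (for example, scikit-learn), and the image score used here (maximum of the map) is a common choice rather than necessarily DiffAD's.

```python
import numpy as np


def auroc(scores, labels):
    """ROC AUC via the Mann-Whitney U statistic; tied scores get average ranks."""
    scores = np.asarray(scores, dtype=np.float64).ravel()
    labels = np.asarray(labels).ravel().astype(bool)
    n_pos = int(labels.sum())
    n_neg = labels.size - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AUROC needs both normal and anomalous samples")
    order = np.argsort(scores, kind="mergesort")
    ranks = np.empty(scores.size, dtype=np.float64)
    ranks[order] = np.arange(1, scores.size + 1)
    # Replace the ranks of tied scores by their mean rank.
    _, inverse = np.unique(scores, return_inverse=True)
    inverse = inverse.ravel()
    ranks = (np.bincount(inverse, weights=ranks) / np.bincount(inverse))[inverse]
    u = ranks[labels].sum() - n_pos * (n_pos + 1) / 2.0
    return u / (n_pos * n_neg)


def image_and_pixel_auroc(anomaly_maps, gt_masks):
    """anomaly_maps, gt_masks: arrays of shape (N, H, W); masks are 0/1."""
    maps = np.asarray(anomaly_maps, dtype=np.float64)
    masks = np.asarray(gt_masks) > 0
    image_scores = maps.reshape(len(maps), -1).max(axis=1)  # one score per image
    image_labels = masks.reshape(len(masks), -1).any(axis=1)  # defective if any GT pixel
    return auroc(image_scores, image_labels), auroc(maps, masks)
```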
