
colorization-pytorch's Introduction

Interactive Deep Colorization in PyTorch

Project Page | Paper | Video | Talk | UI code

Real-Time User-Guided Image Colorization with Learned Deep Priors.
Richard Zhang*, Jun-Yan Zhu*, Phillip Isola, Xinyang Geng, Angela S. Lin, Tianhe Yu, and Alexei A. Efros.
In ACM Transactions on Graphics (SIGGRAPH 2017).

This is our PyTorch reimplementation for interactive image colorization, written by Richard Zhang and Jun-Yan Zhu.

This repository contains the training code. The original, official GitHub repo (with an interactive GUI, and originally a Caffe backend) is here. The official repo has been updated to support PyTorch models on the backend, which can be trained in this repository.

Prerequisites

  • Linux or macOS
  • Python 2 or 3
  • CPU or NVIDIA GPU + CUDA and cuDNN

Getting Started

Installation

  • Clone this repo:
git clone https://github.com/richzhang/colorization-pytorch
cd colorization-pytorch
  • Install the dependencies:
pip install -r requirements.txt

Dataset preparation

  • Download the ILSVRC 2012 dataset and run python make_ilsvrc_dataset.py --in_path /PATH/TO/ILSVRC12 to prepare the data. This will make symlinks into the training set, and divide the ILSVRC validation set into validation and test splits for colorization.
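For intuition, here is a minimal sketch of what the preparation step amounts to (the paths, the dummy class folder, and the 50/50 split are assumptions for illustration, not the script's exact logic):

    import os, glob

    # Assumed paths for illustration only.
    in_path = '/PATH/TO/ILSVRC12'
    out_path = './dataset/ilsvrc2012'
    os.makedirs(out_path, exist_ok=True)

    # Training images are symlinked rather than copied.
    os.symlink(os.path.join(in_path, 'train'), os.path.join(out_path, 'train'))

    # The ILSVRC validation set is divided into val/test splits for colorization.
    val_imgs = sorted(glob.glob(os.path.join(in_path, 'val', '*.JPEG')))
    half = len(val_imgs) // 2
    for split, imgs in (('val', val_imgs[:half]), ('test', val_imgs[half:])):
        split_dir = os.path.join(out_path, split, '0')  # single dummy class folder
        os.makedirs(split_dir, exist_ok=True)
        for img in imgs:
            os.symlink(img, os.path.join(split_dir, os.path.basename(img)))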

Training interactive colorization

  • Train a model: bash ./scripts/train_siggraph.sh. Training is a two-stage process. First, the network is trained for automatic colorization with a classification loss; results are in ./checkpoints/siggraph_class. Then, the network is fine-tuned for interactive colorization with a regression loss; final results are in ./checkpoints/siggraph_reg2.

  • To view training results and loss plots, run python -m visdom.server and open http://localhost:8097. The following values are monitored (a sketch of how the main terms can be computed follows this list):

    • G_CE is a cross-entropy loss between predicted color distribution and ground truth color.
    • G_entr is the entropy of the predicted distribution.
    • G_entr_hint is the entropy of the predicted distribution at points where a color hint is given.
    • G_L1_max is the L1 distance between the ground truth color and argmax of the predicted color distribution.
    • G_L1_mean is the L1 distance between the ground truth color and mean of the predicted color distribution.
    • G_L1_reg is the L1 distance between the ground truth color and the predicted color.
    • G_fake_real is the L1 distance between the predicted color and the ground truth color (in locations where a hint is given).
    • G_fake_hint is the L1 distance between the predicted color and the input hint color (in locations where a hint is given). It's a measure of how much the network "trusts" the input hint.
    • G_real_hint is the L1 distance between the ground truth color and the input hint color (in locations where a hint is given).
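As a concrete reference for the terms above, here is a minimal sketch of how the main monitored quantities can be computed (shapes, names, and Q = 529 are assumptions for illustration, not the repository's exact code):

    import torch
    import torch.nn.functional as F

    B, Q, H, W = 2, 529, 16, 16              # Q: number of quantized ab bins (assumed)
    logits = torch.randn(B, Q, H, W)          # predicted color distribution (logits)
    target = torch.randint(0, Q, (B, H, W))   # ground-truth quantized color index
    gt_ab = torch.rand(B, 2, H, W)            # ground-truth ab channels
    pred_ab = torch.rand(B, 2, H, W)          # regressed ab prediction

    # G_CE: cross-entropy between the predicted distribution and ground-truth color.
    G_CE = F.cross_entropy(logits, target)

    # G_entr: entropy of the predicted distribution, averaged over pixels.
    probs = F.softmax(logits, dim=1)
    G_entr = (-probs * torch.log(probs + 1e-10)).sum(dim=1).mean()

    # G_L1_reg: L1 distance between the ground-truth and predicted color.
    G_L1_reg = (gt_ab - pred_ab).abs().mean()

    print(G_CE.item(), G_entr.item(), G_L1_reg.item())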

Testing interactive colorization

  • Get a model. Either:

    • (1) Download the pretrained models by running bash pretrained_models/download_siggraph_model.sh, which will give you a few models.
      • Original Caffe weights [Recommended]: ./checkpoints/siggraph_caffemodel/latest_net_G.pth contains the original caffemodel weights, converted to PyTorch. Be sure to set --mask_cent 0 when running it.
      • Retrained model: ./checkpoints/siggraph_retrained/latest_net_G.pth. This model achieves better PSNR but performs qualitatively differently. Note that this repository is an approximate reimplementation of the SIGGRAPH paper.
    • (2) Train your own model (as described in the section above), which will leave a model in ./checkpoints/siggraph_reg2/latest_net_G.pth.
  • Test the model on validation data:

    • python test.py --name siggraph_caffemodel --mask_cent 0 for original caffemodel weights
    • python test.py --name siggraph_retrained for retrained weights.
    • python test.py --name siggraph_reg2 if you retrained your own model. The test results will be saved to an HTML file in ./results/[[NAME]]/latest_val/index.html. For each image in the validation set, it will test (1) automatic colorization, (2) interactive colorization with a few random hints, and (3) interactive colorization with lots of random hints.
  • Test the model by making a PSNR vs. number-of-hints plot: python test_sweep.py --name [[NAME]]. This plot was used in Figure 6 of the paper. This test randomly reveals 6x6 color hint patches to the network and measures how accurate the colorization is with respect to the ground truth (a reference PSNR computation appears after this list).

  • Test the model interactively with the original official repository. Follow installation instructions in that repo and run python ideepcolor.py --backend pytorch --color_model [[PTH/TO/MODEL]] --dist_model [[PTH/TO/MODEL]].
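For reference, PSNR in the sweep above can be computed as follows (a generic implementation; not necessarily what test_sweep.py uses internally):

    import torch

    def psnr(pred, target, max_val=255.0):
        # Standard peak signal-to-noise ratio in dB between two images.
        mse = torch.mean((pred.float() - target.float()) ** 2)
        return 10.0 * torch.log10(max_val ** 2 / mse)

    # Usage: compare a colorized result against the ground-truth image.
    pred = torch.randint(0, 256, (3, 256, 256))
    target = torch.randint(0, 256, (3, 256, 256))
    print('PSNR: %.2f dB' % psnr(pred, target).item())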

Citation

If you use this code for your research, please cite our paper:

@article{zhang2017real,
  title={Real-Time User-Guided Image Colorization with Learned Deep Priors},
  author={Zhang, Richard and Zhu, Jun-Yan and Isola, Phillip and Geng, Xinyang and Lin, Angela S and Yu, Tianhe and Efros, Alexei A},
  journal={ACM Transactions on Graphics (TOG)},
  volume={36},
  number={4},
  year={2017},
  publisher={ACM}
}

Acknowledgments

This code borrows heavily from the pytorch-CycleGAN-and-pix2pix repository.


colorization-pytorch's Issues

Classification output in SIGGRAPHGenerator, #channel=529 ?

Hi,

I'm reading the implementation details of the SIGGRAPHGenerator, but had some trouble understanding the classification output of this network. The classification output in SIGGRAPHGenerator is:

model_class=[nn.Conv2d(256, 529, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias),]

Is 529 the number of quantized colors Q? I would have expected this number to be 313 instead...
Could anyone please explain this number?

Thanks,
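For context, one plausible reading (an observation, not an official answer): 529 = 23 x 23, which is consistent with a full grid over the ab plane, e.g. ab values in [-110, 110] quantized in steps of 10, whereas 313 is the number of in-gamut bins used in the earlier Zhang et al. 2016 colorization paper:

    # Assumed quantization settings (ab_max=110, ab_quant=10).
    ab_max, ab_quant = 110, 10
    A = 2 * ab_max // ab_quant + 1  # bins per axis
    print(A, A * A)                 # 23 529 -- the full grid, vs. 313 in-gamut bins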

Model training Loss graph


Plot.ly Link: https://plot.ly/~Ugness/1/

I am training the model on the ILSVRC2012 training set with the same options as your implementation, and my loss graph looks like the plot linked above. I am worried that my model's loss is not decreasing correctly. Could you check this loss graph or share your own?
Thanks.

I ran this project on Colab, then encountered a problem

The error message is No such file or directory: './dataset/ilsvrc2012/val/1/ILSVRC2012_val_00000059.JPEG', but I can find this image in the directory. Your prompt attention to my question is appreciated. Thanks for your help!

Errors when running CPU only

I'm trying to run using CPU only in a VM. I keep getting the following error during the first part of the batch script:

[Network G] Total number of parameters : 34.187 M

No handlers could be found for logger "visdom"
create web directory ./checkpoints/siggraph_class_small/web...
Traceback (most recent call last):
  File "train.py", line 61, in <module>
    model.optimize_parameters()
  File "/home/testing/Desktop/colorization-pytorch-master/models/pix2pix_model.py", line 193, in optimize_parameters
    self.forward()
  File "/home/testing/Desktop/colorization-pytorch-master/models/pix2pix_model.py", line 123, in forward
    self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
  File "/home/testing/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'module'
mkdir: cannot create directory ‘./checkpoints/siggraph_class’: File exists
cp: cannot stat './checkpoints/siggraph_class_small/latest_net_G.pth': No such file or directory
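For reference, the traceback suggests the generator was not wrapped in nn.DataParallel on CPU, so .module does not exist. A common defensive pattern (a sketch, not the repository's official fix) is:

    import torch.nn as nn

    def unwrap(net):
        # Return the underlying module whether or not the model was wrapped
        # in DataParallel (as it typically is on GPU); this avoids the
        # "object has no attribute 'module'" AttributeError on CPU.
        return net.module if isinstance(net, nn.DataParallel) else net

    # e.g. unwrap(self.netG).upsample4(...) instead of self.netG.module.upsample4(...)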

Where are the local and global hints networks defined?

I cannot find a reference to either the local or global hints network anywhere in the source code. I've had a look at the definition of the SIGGRAPHGenerator class, but that seems to be just the main colorization network. I would be very grateful if you could point me to the relevant lines. Thank you.

decode_ind_ab

In util's decode_ind_ab(), the calculation is

    data_a = data_q/opt.A
    data_b = data_q - data_a*opt.A
    data_ab = torch.cat((data_a, data_b), dim=1)

However, I believe that, according to how the encoding was done, we should instead have something like
data_a = (data_q - data_b)/opt.A

I'm imagining this would have to be solved using linear programming or something.
I was just wondering if this is something you're aware of and whether I am missing something?

My issue is that when I use decode_ind_ab, all my b values currently come through as -1, since with

    data_a = data_q/opt.A (eq1)
    data_b = data_q - data_a*opt.A (eq2)

we can substitute eq1 into eq2 to show that

data_b = data_q - data_q = 0

which then gets scaled and shifted to -1 before being returned.
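For reference, a minimal round-trip sketch assuming the encoding is q = a * A + b (an assumption based on the issue's description); the key point is that the decode must use floor division, which the quoted code only matches for integer tensors:

    import torch

    A = 23                        # assumed number of bins per axis
    a = torch.randint(0, A, (4,))
    b = torch.randint(0, A, (4,))

    q = a * A + b                 # assumed encoding
    a_dec = q // A                # floor division, not true division
    b_dec = q - a_dec * A         # equivalently q % A

    assert torch.equal(a, a_dec) and torch.equal(b, b_dec)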

I found the error in util.py

At line 248, average patch values are not computed appropriately.
In util's add_color_patches_rand_gt() (line 248), the calculation is
torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)
However, I believe this code should be changed to
torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=-1,keepdim=True).view(1,C,1,1)

The previous code averages over the ab channels, while the corrected version averages over the patch values.

Mac running test.py error

I tried running test.py but got this error. Any thoughts?

Traceback (most recent call last):
  File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/test.py", line 61, in <module>
    model.test(True)  # True means that losses will be computed
  File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/models/base_model.py", line 56, in test
    self.forward()
  File "/Users/spikeyuan/PycharmProjects/pythonProject6/colorization-pytorch/models/pix2pix_model.py", line 123, in forward
    self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
  File "/opt/anaconda3/envs/pythonProject6/lib/python2.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'module'

Discriminator in the pretrained model

So I can see that there is a discriminator defined in the graph if lambda_GAN > 0. Do you have pretrained models that use the discriminator? I am curious about the difference in performance with and without the discriminator. Or, if providing a pretrained model is not possible, what hyperparameters should be used?

Custom Dataset

The model is trained on ILSVRC2012. Could someone please help me out by clarifying how to train it on my custom dataset?

test.py is not compatible with python 3

Hi, I am trying to test your model in Google Colab. I put a dummy image in a dummy folder:
'/content/colorization-pytorch/dataset/ilsvrc2012/val/MyImages'

I tested these two following commands:
#!python test.py --name siggraph_caffemodel --mask_cent 0
!python test.py --name siggraph_retrained

Both led to the same error:

Traceback (most recent call last):
File "test.py", line 53, in
data_raw[0] = util.crop_mult(data_raw[0], mult=8)
File "/content/colorization-pytorch/util/util.py", line 277, in crop_mult
return data[:,:,h:h+Hnew,w:w+Wnew]
TypeError: slice indices must be integers or None or have an index method

I fixed it by casting the indices to integers. Then I got another error at test.py line 57:
img_path = [string.replace('%08d_%.3f' % (i, sample_p), '.', 'p')]

string.replace is deprecated in Python 2.7 and does not exist in Python 3; see https://docs.python.org/2.7/library/string.html?highlight=string%20replace#string.replace

However, when I change my google colab execution to python 2, test.py works like a charm.
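For reference, the two fixes described above might look like the following sketch (the crop logic is assumed from the traceback, not copied from the repository):

    import torch

    def crop_mult(data, mult=8):
        # Python-3-safe version: floor division keeps the slice indices integers.
        H, W = data.shape[2], data.shape[3]
        Hnew, Wnew = (H // mult) * mult, (W // mult) * mult
        h, w = (H - Hnew) // 2, (W - Wnew) // 2
        return data[:, :, h:h + Hnew, w:w + Wnew]

    # test.py line 57: use the str method instead of the removed string.replace.
    i, sample_p = 0, 0.125
    img_path = [('%08d_%.3f' % (i, sample_p)).replace('.', 'p')]

    print(crop_mult(torch.zeros(1, 3, 100, 100)).shape)  # torch.Size([1, 3, 96, 96])
    print(img_path)                                      # ['00000000_0p125']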

num_threads or nThreads?

I've noticed that the base options for this repository define opt.nThreads, but I can't actually see anywhere this is used.

I've also done some digging, and in the data folder's __init__.py the CustomDatasetLoader calls opt.num_threads, which I believe is what the pix2pix repo uses in its base options, although I don't think this is used in this project either.

Is the data loading for this project actually multi-threaded? In train.py, torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True) is called; should I change this to include num_workers=int(opt.nThreads) if I want to speed up the data loading? (A sketch of that change follows below.)

Thanks for the amazing code/project! :)
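For what it's worth, the suggested change might look like this sketch (the dataset here is a stand-in; opt.nThreads is the flag this repo's base options define):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 3, 32, 32))  # stand-in dataset

    # Forward the thread count to num_workers so batches are loaded by
    # worker processes instead of the main process.
    loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
    for (batch,) in loader:
        pass  # training step would go here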

Global hints network

Hello, Mr. Zhang. I've read this source code, but I can't find the global hints module. In the pix2pix_model.py file, the forward method is as follows:

def forward(self):
        (self.fake_B_class, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
        self.fake_B_dec_max = self.netG.module.upsample4(util.decode_max_ab(self.fake_B_class, self.opt))
        self.fake_B_distr = self.netG.module.softmax(self.fake_B_class)

        self.fake_B_dec_mean = self.netG.module.upsample4(util.decode_mean(self.fake_B_distr, self.opt))

        self.fake_B_entr = self.netG.module.upsample4(-torch.sum(self.fake_B_distr * torch.log(self.fake_B_distr + 1.e-10), dim=1, keepdim=True))
        # embed()

The local hints module is applied, but I can't find the global hints module, and the ground-truth color distribution is also not calculated. Maybe I am missing something important; could you provide me with some information? Thank you!

Training tends to yellowish colors

Hi,
Thanks for this PyTorch implementation. I'm following your exact tutorial to train the model on my Mac's CPU. The problem is that, as I monitor the visdom console, I see the output fake_reg image become more and more yellowish as the loss curves go down. I had this same exact issue while training the original colorization model (i.e., the one without hints) introduced by the paper authors. I don't know why this actually happens; any ideas?

Thank you!

'SIGGRAPHGenerator' object has no attribute 'model'

When running on a Mac (Python 3.5.4), I encountered the following problem: 'SIGGRAPHGenerator' object has no attribute 'model'.
I wonder how to fix it?
Traceback (most recent call last):
  File "test.py", line 40, in <module>
    model.setup(opt)
  File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 42, in setup
    self.load_networks(opt.which_epoch)
  File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 136, in load_networks
    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
  File "/Users/joanna/Desktop/colorization/colorization-pytorch-master/models/base_model.py", line 116, in __patch_instance_norm_state_dict
    self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
  File "/Users/joanna/opt/anaconda3/envs/color/lib/python3.5/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'SIGGRAPHGenerator' object has no attribute 'model'
