leedoyup / deblurgan-tf
Unofficial tensorflow (tf) implementation of DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks
License: MIT License
The program stops when I try to initialize the global variables.
Both
global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
and
for iter in range(epoch)
are used. They have to be consistent (for managing checkpoint files).
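One way to make the two consistent is to treat global_step (the value saved with the checkpoint) as the single source of truth and derive the epoch and batch position from it, instead of counting epochs separately. A minimal sketch in plain Python; resume_position and the numbers are hypothetical, not taken from the repository:

```python
def resume_position(global_step, steps_per_epoch):
    """Map one saved step counter to (epoch, batch index within that epoch)."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return divmod(global_step, steps_per_epoch)


# Hypothetical numbers: 1000 image pairs with batch size 1 -> 1000 steps per epoch.
epoch, batch = resume_position(2503, 1000)
print(epoch, batch)  # 2 503
```

After restoring a checkpoint, the epoch loop would start at range(epoch, num_epochs) and skip the first batch items of the resumed epoch, so the loop counter can never drift away from the saved step.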
The form of the dataset has to be determined.
It has to support the code below:
from data.dataloader import dataloader
dataset = dataloader(path)
# dataset: [Num of Pair Images, Image Pair]
for i, data in enumerate(dataset):
Here, "data" means one image pair. Thus,
data.blur_img, data.real_img
have to be available.
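A list of named tuples is one simple structure that satisfies this interface: it has a length (number of pairs), supports enumerate, and exposes each pair's images as attributes. A minimal sketch; ImagePair and make_dataset are hypothetical names, and the file names stand in for loaded images:

```python
from collections import namedtuple

# Hypothetical record type: one blurred/sharp training pair.
ImagePair = namedtuple("ImagePair", ["blur_img", "real_img"])


def make_dataset(blur_images, real_images):
    """Pair blurred and sharp images position by position."""
    if len(blur_images) != len(real_images):
        raise ValueError("blur and real image lists differ in length")
    return [ImagePair(b, r) for b, r in zip(blur_images, real_images)]


dataset = make_dataset(["b0.png", "b1.png"], ["s0.png", "s1.png"])
for i, data in enumerate(dataset):
    # Each item exposes the attributes the training loop expects.
    print(i, data.blur_img, data.real_img)
```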
Will you upload a pre-trained model? I saw the old issue, but it was a long time ago.
Currently, the conv-net functions in model/ops.py follow different policies for handling variable scopes.
This has to be fixed.
Is the argument batch_num=1 in train.py equivalent to batch_size?
In test.py, the line
logging.info("%s image deblur starts", %data)
fails with
SyntaxError: invalid syntax
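The SyntaxError comes from the stray % in front of data: inside a call, %data is not a valid expression. The logging functions take the format string and its arguments separately and substitute them lazily. A minimal sketch; log_start and the logger name are hypothetical:

```python
import logging

logger = logging.getLogger("deblur_test")  # hypothetical logger name


def log_start(path):
    # Format string and argument are passed separately; logging substitutes
    # "%s" only when the record is actually emitted.
    logger.info("%s image deblur starts", path)
```

With INFO enabled, log_start("0001.png") logs the message "0001.png image deblur starts".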
Some code in train.py is commented out. Is it still useful?
```python
logging.info("[!] Generator Optimization Start")
#for j in range(args.iter_gen):
if i % 5 == 0:
    feed_dict_G = {model.input['blur_img']: blur_img,
                   model.input['real_img']: real_img,
                   model.learning_rate: learning_rate}
    loss_G, adv_loss, perceptual_loss, G_out = model.run_optim_G(feed_dict=feed_dict_G,
                                                                 with_loss=True, with_out=True)
    logging.info('%d epoch, %d batch, Generator Loss: %f, add loss: %f, perceptual_loss: %f', iter, i, loss_G, adv_loss, perceptual_loss)
    batch_loss_G += loss_G
    #logging: time, loss

#Ready for Training Discriminator
feed_dict_G = {model.input['blur_img']: blur_img}
G_out = model.G_output(feed_dict=feed_dict_G)
x_hat = model.sess.run(get_x_hat(G_out, real_img, args.batch_num))
feed_dict_D = {model.input['gen_img']: G_out,
               model.input['real_img']: real_img,
               model.input['x_hat']: x_hat,
               model.learning_rate: learning_rate}
logging.info("[!] Discriminator Optimization Start")
#for j in range(args.iter_disc):
loss_D = model.run_optim_D(feed_dict=feed_dict_D, with_loss=True)
print(loss_D)
batch_loss_D += loss_D
#logging: time, loss
logging.info('%d epoch, %d batch, Discriminator Loss: %f', iter, i, loss_D)
batch_time = time.time() - start_time
#logging
```
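The two schedules differ in direction. With if i % 5 == 0, the generator is updated once every five batches while the discriminator is updated on every batch, i.e. five critic steps per generator step (the n_critic = 5 setting used in the WGAN-GP paper). The commented-out for j in range(args.iter_gen) loop with the default iter_gen=5 would instead run five generator steps per batch. The sketch below only counts updates under each scheme; both function names are hypothetical:

```python
def count_updates_modulo(num_batches, gen_every=5):
    """Generator only on batches where i % gen_every == 0; discriminator every batch."""
    gen_steps = sum(1 for i in range(num_batches) if i % gen_every == 0)
    disc_steps = num_batches
    return gen_steps, disc_steps


def count_updates_inner_loop(num_batches, iter_gen=5, iter_disc=1):
    """Commented-out variant: iter_gen G steps and iter_disc D steps per batch."""
    return num_batches * iter_gen, num_batches * iter_disc


print(count_updates_modulo(10))      # (2, 10)
print(count_updates_inner_loop(10))  # (50, 10)
```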
Will you upload the pre-trained model?
Thank you.
When I run python train.py, an error occurs: IndexError: list index out of range. How can I fix it? Could you show me your folder structure?
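One common cause of this kind of IndexError at startup is an empty list of image paths: the data path does not match the actual folder layout, glob returns nothing, and a later index fails. That is an assumption about this case, not a confirmed diagnosis. A check like the following makes the failure explicit; list_images is a hypothetical helper:

```python
import glob
import os


def list_images(data_dir, img_type="png"):
    """Return sorted image paths under data_dir, failing loudly if none match."""
    paths = sorted(glob.glob(os.path.join(data_dir, "*." + img_type)))
    if not paths:
        raise FileNotFoundError(
            "no *.%s files found in %r; check the data path" % (img_type, data_dir))
    return paths
```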
Thank you for your code. It is very helpful for me, but I have a question. I downloaded the GOPRO dataset, and there was a problem when I used it:
“A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 253f8cb515780f3b799900260a226db6 so we will re-download the data.”
I downloaded this dataset from your code page. I'm not sure where I went wrong.
I look forward to your reply.
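The message appears to come from Keras' get_file(), which compares a hash of the local file with an expected value and re-downloads on a mismatch. A 32-digit hex value such as 253f8cb515780f3b799900260a226db6 has the length of an MD5 digest, so checking the MD5 of the local archive shows whether it is the expected file (for example, not truncated by an interrupted download). A minimal sketch with the standard hashlib module; file_md5 is a hypothetical helper:

```python
import hashlib

EXPECTED_MD5 = "253f8cb515780f3b799900260a226db6"  # value quoted in the message


def file_md5(path, chunk_size=1 << 20):
    """MD5 hex digest of a file, read in chunks so large archives fit in memory."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

If file_md5(path) differs from EXPECTED_MD5, the local copy is not the file that hash was computed from.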
Hello,
Why are there fully connected layers at the end of the discriminator? Isn't the original PatchGAN architecture fully convolutional? I looked at the PyTorch implementation of the original DeblurGAN and could not see the fc layers. Sorry in advance if I am overlooking something in your implementation or in the original DeblurGAN implementation.
Thank you.
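On the "fully convolutional" point: a PatchGAN discriminator ends in a convolution, so its output is a spatial map of scores, one per overlapping input patch, rather than a single value from a dense layer. The sketch computes the output size and the per-unit receptive field for an assumed pix2pix-style 70x70 PatchGAN (kernel 4, padding 1); DeblurGAN's own layer settings may differ, so the numbers are illustrative only:

```python
def conv_output_size(n, kernel, stride, padding):
    """Spatial size after one convolution (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


def receptive_field(layers):
    """Receptive field of one output unit for a stack of (kernel, stride) convs."""
    r = 1
    for kernel, stride in reversed(layers):
        r = (r - 1) * stride + kernel
    return r


# Assumed configuration: three stride-2 convs, then two stride-1 convs,
# all with kernel 4 and padding 1, given as (kernel, stride, padding).
PATCHGAN = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size = 256
for k, s, p in PATCHGAN:
    size = conv_output_size(size, k, s, p)
print(size)  # 30 -> a 30x30 map of patch scores, not one scalar
print(receptive_field([(k, s) for k, s, _ in PATCHGAN]))  # 70
```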
Thank you for your excellent work! I had been searching for a TF version of DeblurGAN for months.
I found a bug in "models/cgan_model.py", Line 44:
self.D = discriminator(tf.concat([self.G, self.input['real_img']], axis=0))
This means that the input of the discriminator is "[deblurred image from G, sharp image]", where real_img is the sharp image.
It affects Line 100, which computes the adversarial loss:
self.adv_loss = adv_loss(self.D)
However, in Kupyn's PyTorch code ("models/conditional_gan_model.py", Line 95), the code corresponding to Lines 44 and 100 is
self.loss_G_GAN = self.discLoss.get_g_loss(self.netD, self.real_A, self.fake_B)
which means that the input of the discriminator is "[deblurred image from G, blurry image]", since self.real_A is the blurry image and self.fake_B is the deblurred image from G.
That is how Kupyn computes the adversarial loss.
So the correct Line 44 should be:
self.D = discriminator(tf.concat([self.G, self.input['blur_img']], axis=0))
From the point of view of a conditional GAN, the blurry image is the auxiliary information for both the generator and the discriminator. In other words, the blurry image is the information we condition on.
The input of the discriminator should therefore be [G(blurry), blurry] or [real, blurry].
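One detail worth checking with this fix is the concatenation axis. Assuming NHWC tensors (TensorFlow's default channels-last layout), concatenating along axis=0 stacks the two inputs as extra batch entries, so the discriminator scores each image on its own; conditioning in the pix2pix style pairs the images along the channel axis so that every discriminator input contains both the candidate and the blurry image. np.concatenate follows the same axis convention as tf.concat, so a NumPy sketch with hypothetical shapes shows the difference:

```python
import numpy as np

# Hypothetical NHWC batches: 2 images of 4x4 pixels, 3 channels each.
fake = np.zeros((2, 4, 4, 3), dtype=np.float32)  # stands in for G(blurry)
blur = np.ones((2, 4, 4, 3), dtype=np.float32)   # conditioning image

batch_concat = np.concatenate([fake, blur], axis=0)     # 4 separate images
channel_concat = np.concatenate([fake, blur], axis=-1)  # 2 images, 6 channels

print(batch_concat.shape)    # (4, 4, 4, 3)
print(channel_concat.shape)  # (2, 4, 4, 6)
```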
Hi, thank you for sharing the code.
When I run the code, it occupies all of the memory of 4 GPUs, but only one GPU is involved in the computation (my work computer has four 1080 Ti cards).
Could you answer this question?
Thank you.
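By default, TensorFlow 1.x maps nearly all of the memory of every GPU visible to the process, even when the graph runs on a single device, which matches the behavior described. Two standard remedies are hiding the other GPUs with the CUDA_VISIBLE_DEVICES environment variable and enabling allow_growth in the session config. A sketch, assuming TF 1.x (tf.ConfigProto and tf.Session, as used in this repository's test.py); the helper names are hypothetical:

```python
import os


def restrict_to_gpu(index):
    """Expose a single GPU; call this before TensorFlow is imported."""
    os.environ["CUDA_VISIBLE_DEVICES"] = str(index)


def make_session():
    """TF 1.x session that allocates GPU memory on demand instead of all at once."""
    import tensorflow as tf  # imported late so restrict_to_gpu() can run first
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)
```

From the shell, CUDA_VISIBLE_DEVICES=0 python train.py has the same effect as calling restrict_to_gpu(0) before the import.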
```python
from __future__ import print_function
import glob
import json
import logging
import os
import re
import sys
import time

import cv2
import numpy as np
import tensorflow as tf

from data.data_loader import *
from models.cgan_model import cgan


def build_model(args):
    sess = tf.Session()
    model = cgan(sess, args)
    model.build_model()
    return model


def main(args):
    with open(args.conf, 'r') as f:
        config = json.load(f)

    model = build_model(args)
    model.load_weights(args.checkpoint_dir)

    dataset = glob.glob(os.path.join(args.data_path_t, '*.' + args.img_type))

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    # Load checkpoint files, run the generator on every blurred image,
    # and save the input and the result until all images are processed.
    for i, data in enumerate(dataset):
        logging.info("%s image deblur starts", data)
        blur_img = read_image(data)
        logging.debug("%s image was loaded", data)
        feed_dict_G = {model.input['blur_img']: blur_img}
        G_out = model.G_output(feed_dict=feed_dict_G)
        logging.debug("The image was converted")
        logging.debug(G_out)
        cv2.imwrite(os.path.join(args.result_dir, str(i) + '_blur.png'), blur_img)
        cv2.imwrite(os.path.join(args.result_dir, str(i) + '_result.png'), G_out)
        logging.info("Image save was completed")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--is_training', action='store_true')
    parser.add_argument('-c', '--conf', type=str, default='configs/ilsvrc_sample.json')
    parser.add_argument('--iter_gen', type=int, default=5)
    parser.add_argument('--iter_disc', type=int, default=1)
    parser.add_argument('--batch_num', type=int, default=1)
    parser.add_argument('--data_path', type=str, default='/data/private/data/GOPRO_Large/train/')
    parser.add_argument('--data_path_t', type=str, default='./test_data/')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/')
    parser.add_argument('--model_name', type=str, default='DeblurGAN.model')
    parser.add_argument('--summary_dir', type=str, default='./summaries/')
    parser.add_argument('--data_name', type=str, default='GOPRO')
    parser.add_argument('--result_dir', type=str, default='./result_dir')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--img_type', type=str, default='png')
    args = parser.parse_args()

    log_format = '[%(asctime)s %(levelname)s] %(message)s'
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level, format=log_format, stream=sys.stderr)
    logging.getLogger("cgan.*").setLevel(level)
    main(args)
```
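Issue:
Calling parser.parse_args('--is_training', action='store_true') raises
TypeError: parse_args() got an unexpected keyword argument 'action'
The flag has to be declared with parser.add_argument('--is_training', action='store_true') instead.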
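In argparse, flags are declared with add_argument(), which accepts action=; parse_args() only takes the list of strings to parse (defaulting to sys.argv[1:]). A minimal sketch with two of the flags from test.py:

```python
import argparse

parser = argparse.ArgumentParser(description="")
# Declare flags with add_argument(); store_true makes the flag default to False.
parser.add_argument("--is_training", action="store_true")
parser.add_argument("--img_type", type=str, default="png")

args = parser.parse_args(["--is_training"])  # explicit list instead of sys.argv
print(args.is_training, args.img_type)  # True png
```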
Hello @LeeDoYup,
I have 1000 image pairs and a GTX 1060 GPU. How long will it take to finish training?
Thank you for sharing your implementation!! Can you also share the checkpoints you got?
The variable names of the generator and the discriminator have to be distinguished from each other.
The discriminator loss becomes nan. Why?
[2018-11-22 16:53:22,014 INFO] 0 epoch, 152 batch, Discriminator Loss: 11.182838
[2018-11-22 16:53:22,109 INFO] [!] Generator Optimization Start
[2018-11-22 16:53:22,978 INFO] [!] Discriminator Optimization Start
[2018-11-22 16:53:23,113 INFO] 0 epoch, 153 batch, Discriminator Loss: nan
nan
[2018-11-22 16:53:23,213 INFO] [!] Generator Optimization Start
[2018-11-22 16:53:24,131 INFO] [!] Discriminator Optimization Start
nan
[2018-11-22 16:53:24,265 INFO] 0 epoch, 154 batch, Discriminator Loss: nan
[2018-11-22 16:53:24,366 INFO] [!] Generator Optimization Start
[2018-11-22 16:53:24,435 INFO] 0 epoch, 155 batch, Generator Loss: nan, add loss: nan, perceptual_loss: 0.003620
[2018-11-22 16:53:25,444 INFO] [!] Discriminator Optimization Start
[2018-11-22 16:53:25,580 INFO] 0 epoch, 155 batch, Discriminator Loss: nan
nan
[2018-11-22 16:53:25,685 INFO] [!] Generator Optimization Start
[2018-11-22 16:53:26,591 INFO] [!] Discriminator Optimization Start
nan
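In this log, the discriminator loss is finite at batch 152 and nan from batch 153 on; by batch 155 the generator's adversarial loss is also nan, while the perceptual loss stays finite. Stopping at the first non-finite value makes it easier to inspect the inputs of that batch. A hedged sketch; check_finite is a hypothetical helper that could wrap loss_D right after model.run_optim_D:

```python
import math


def check_finite(name, value, epoch, batch):
    """Return value unchanged, or raise as soon as it is NaN or infinite."""
    if not math.isfinite(value):
        raise FloatingPointError(
            "%s became %r at epoch %d, batch %d" % (name, value, epoch, batch))
    return value
```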
The critic loss function is not valid.
It has to use WGAN-GP, but it currently uses only the improved-WGAN loss.
This has to be fixed.
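For reference, the WGAN-GP critic loss (Gulrajani et al.) is L_D = E[D(x_fake)] - E[D(x_real)] + lambda * E[(||grad D(x_hat)||_2 - 1)^2], with x_hat = eps * x_real + (1 - eps) * x_fake, eps ~ U[0, 1], and lambda = 10 in the paper. The NumPy sketch evaluates that formula from critic scores and the gradient of the critic at x_hat; in the TF graph that gradient would come from tf.gradients. interpolate is assumed to correspond to what get_x_hat does in train.py; both function names are hypothetical:

```python
import numpy as np


def interpolate(real, fake, rng):
    """x_hat = eps * real + (1 - eps) * fake, with one eps ~ U[0, 1] per sample."""
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0],) + (1,) * (real.ndim - 1))
    return eps * real + (1.0 - eps) * fake


def wgan_gp_critic_loss(d_real, d_fake, grad_x_hat, lam=10.0):
    """Critic loss from scores on real/fake batches and dD/dx_hat per sample."""
    grad_flat = grad_x_hat.reshape(grad_x_hat.shape[0], -1)
    grad_norm = np.sqrt(np.sum(grad_flat ** 2, axis=1))
    penalty = np.mean((grad_norm - 1.0) ** 2)
    return np.mean(d_fake) - np.mean(d_real) + lam * penalty
```

When the norm is computed inside an autodiff graph, the derivative of the square root is unbounded at zero, so implementations often add a small constant inside the root; ruling this out is one possible step when the critic loss turns into nan.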
When I call list(pair_path) after the for statement, the result is empty.
Currently, the code assigns
result = list(pair_path)
before
for blur, real in iter_pair_path:
Why?
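Assuming pair_path is built with zip() (or another iterator, such as map() or a generator) under Python 3, it can be consumed only once: the for loop reads every pair, so a later list() finds nothing left. Materializing the pairs into a list first allows iterating them any number of times. A minimal sketch with hypothetical file names:

```python
blur_paths = ["b0.png", "b1.png"]
real_paths = ["s0.png", "s1.png"]

pair_path = zip(blur_paths, real_paths)  # Python 3: a one-shot iterator
for blur, real in pair_path:
    pass                                 # this loop consumes every pair

print(list(pair_path))  # [] -> nothing left to read

# Materialize once, then iterate the list as often as needed.
pair_list = list(zip(blur_paths, real_paths))
for blur, real in pair_list:
    pass
print(len(pair_list))  # 2
```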