
deep-image-matting-pytorch's Introduction

Deep Image Matting

Deep Image Matting paper implementation in PyTorch.

Differences

  1. "fc6" is dropped.
  2. Indices pooling (see the sketch below).

"fc6" is clumpy, over 100 millions parameters, makes the model hard to converge. I guess it is the reason why the model (paper) has to be trained stagewisely.

Performance

  • The Composition-1k testing dataset.
  • Evaluate with whole image.
  • SAD normalized by 1000.
  • Input images are normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] (the standard ImageNet statistics; see the snippet below).
  • Both erosion and dilation are used to generate the trimap (see the sketch after the table below).
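
A one-liner for the normalization, using standard torchvision transforms (a sketch; the repo may normalize inline instead):

```python
from torchvision import transforms

# Standard ImageNet statistics, matching the values listed above.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
```
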
| Models       | SAD  | MSE   | Download |
|--------------|------|-------|----------|
| paper-stage0 | 59.6 | 0.019 |          |
| paper-stage1 | 54.6 | 0.017 |          |
| paper-stage3 | 50.4 | 0.014 |          |
| my-stage0    | 66.8 | 0.024 | Link     |
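
A minimal sketch of erosion/dilation-based trimap generation (the repo's own generate_trimap may randomize kernel sizes and differ in detail):

```python
import cv2
import numpy as np

def gen_trimap(alpha, kernel_size=10):
    """alpha is a uint8 matte in [0, 255]. Erosion shrinks the definite
    foreground, dilation grows it; the band in between becomes 'unknown'."""
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    eroded = cv2.erode(alpha, kernel)
    dilated = cv2.dilate(alpha, kernel)
    trimap = np.full(alpha.shape, 128, dtype=np.uint8)  # unknown band
    trimap[eroded >= 255] = 255  # definite foreground
    trimap[dilated <= 0] = 0     # definite background
    return trimap
```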

Dependencies

  • Python 3.5.2
  • PyTorch 1.1.0

Dataset

Adobe Deep Image Matting Dataset

Follow the instructions to contact the author for the dataset.

MSCOCO

Go to the MSCOCO website to download the dataset.

PASCAL VOC

Go to the PASCAL VOC website to download the dataset.

Usage

Data Pre-processing

Extract training images:

$ python pre_process.py
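
Training data for DIM is created by compositing each extracted foreground over new backgrounds (MSCOCO images) with the matting equation I = αF + (1 − α)B. Whether this happens during pre-processing or on the fly in the data loader, the core step looks roughly like this sketch (names are illustrative, not the repo's exact pipeline):

```python
import cv2
import numpy as np

def composite(fg, bg, alpha):
    """I = alpha * F + (1 - alpha) * B, with alpha scaled to [0, 1]."""
    a = alpha.astype(np.float32)[:, :, None] / 255.0
    bg = cv2.resize(bg, (fg.shape[1], fg.shape[0]))  # match foreground size
    out = a * fg.astype(np.float32) + (1.0 - a) * bg.astype(np.float32)
    return out.astype(np.uint8)
```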

Train

$ python train.py

If you want to visualize during training, run in your terminal:

$ tensorboard --logdir runs

Experimental results

The Composition-1k testing dataset

  1. Test:
$ python test.py

It prints out average SAD and MSE errors when finished.
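
For reference, a minimal sketch of the two reported metrics (SAD divided by 1000, as noted under Performance; the repo's implementation may restrict them to the unknown trimap region):

```python
import numpy as np

def compute_sad(pred, gt):
    """Sum of absolute differences over alpha values in [0, 1], in thousands."""
    return np.abs(pred - gt).sum() / 1000.0

def compute_mse(pred, gt, trimap):
    """Mean squared error, averaged over the unknown region of the trimap."""
    mask = trimap == 128
    return np.mean((pred - gt)[mask] ** 2)
```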

The alphamatting.com dataset

  1. Download the evaluation datasets: go to the Datasets page and download the evaluation datasets. Make sure you pick the low-resolution dataset.

  2. Extract evaluation images:

$ python extract.py

  3. Evaluate:

$ python eval.py

(Result grid omitted: one row per test image, with columns Image, Trimap1, Trimap2, Trimap3.)

Demo

Download the pre-trained Deep Image Matting model (Link), then run:

$ python demo.py
(Demo result grid omitted: one row per sample, with columns Image/Trimap, Output/GT, New BG/Compose.)

Sponsorship

If this project helps you, you are welcome to give a small donation.




deep-image-matting-pytorch's People

Contributors

foamliu, wrrjasmine


deep-image-matting-pytorch's Issues

Run trained model on CPU

I tried to run eval.py on my laptop, which doesn't have CUDA, and got the following error:

Traceback (most recent call last):
  File "eval.py", line 62, in <module>
    pred = model(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/data_parallel.py", line 146, in forward
    "them on device: {}".format(self.src_device_obj, t.device))
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

I also modified the load-state code to

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)

but it doesn't fix all the issues.
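
A sketch of a loading path that works on a CPU-only machine, assuming the checkpoint stores the DataParallel-wrapped model under a 'model' key (adjust the key to the actual checkpoint layout):

```python
import torch

# Map all storages to the CPU, regardless of where they were saved.
checkpoint = torch.load('BEST_checkpoint.tar', map_location='cpu')

model = checkpoint['model']  # hypothetical key; adjust to the checkpoint layout
# Unwrap DataParallel so forward() stops insisting on cuda:0.
if isinstance(model, torch.nn.DataParallel):
    model = model.module
model = model.to('cpu').eval()
```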

Performance

Can you achieve the performance that the paper claims in your experiment?

evaluate problem

Is the evaluation process as follows?

input 1920×1080, alpha 1920×1080, trimap 1920×1080
resize input and trimap to 320×320
feed the 320×320 input and trimap to the model, and get a 320×320 prediction
resize the 320×320 prediction to 1920×1080
finally, compare the resized 1920×1080 prediction against the 1920×1080 ground-truth alpha

I tested the TensorFlow version of DIM and cannot reproduce the numbers in the paper, because SAD, MSE, Conn and Grad are affected by the image size.
Thanks

trimaps used for testing

I saw you used dilation/erosion to generate trimaps during testing. However, Adobe provides trimaps that should be used for testing for a fair comparison. Correct me if I am wrong. Thanks.

How can i get the Trimaps of my pictures?

Now I have a trained model and want to use it, but I can't, because I don't have trimaps for my pictures. Is there a script to build trimaps? How can I get the trimaps for my own pictures?

Regarding Compositional Loss

Did you observe any degradation in performance by just using alpha_loss, not total loss or compositional loss?

unable to start training using pretrained weights

Whenever pre-trained weights are used to train the model on my own dataset, the following error occurs.

python3 train.py --batch-size 4 --checkpoint checkpoint/BEST_checkpoint.tar

/usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
  File "train.py", line 180, in <module>
    main()
  File "train.py", line 176, in main
    train_net(args)
  File "train.py", line 71, in train_net
    logger=logger)
  File "train.py", line 112, in train
    alpha_out = model(img)  # [N, 3, 320, 320]
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
    if t.device != self.src_device_obj:
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 539, in __getattr__
    type(self).__name__, name))
AttributeError: 'DataParallel' object has no attribute 'src_device_obj'

Three changes compared to keras version v1

I noticed 3 interesting changes in this version compared to v1:

  1. "compositional_loss" and "overall_loss" are removed, and only "alpha_prediction_loss" remains (see the sketch below).
  2. The "model refinement" step with dense layers ("torch.nn.Linear") is removed, and only the encoder and decoder layers remain.
    So maybe this is a light version of v1.
  3. The "generate_trimap(alpha)" function changes a lot; it can generate much larger and more varied unknown-region sizes, which is perhaps a key reason for the improved results.

How many epochs does it take to train?

The training code occupies three blocks of GPU memory, so my classmates can't run it. I would like to ask how many epochs are needed for training?

can not unpack the 'BEST_checkpoint.tar'

When I download the file "BEST_checkpoint.tar" successfully, I can't unpack it. When I try to unpack 'BEST_checkpoint.tar', it throws an error. Is it my fault, or is the file broken?

v2 doesn't perform as well as v1?

Hi,
thanks for your pretrained models!
I tested both your v1 and v2 pretrained models. v2 is much faster than v1, but I found it doesn't perform as well as v1.
the image:
WechatIMG226
the origin tri map:
test7_tri
the v1 output:
WechatIMG225
the v2 output:
test7_result

Do you know what the problem is?

Thanks,

run demo.py question!

File "demo.py", line 84, in
new_bgs = random.sample(new_bgs, 10)
File "C:\Users\15432\AppData\Local\conda\conda\envs\python34\lib\random.py", line 324, in sample
raise ValueError("Sample larger than population")
ValueError: Sample larger than population
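
A minimal workaround, assuming the background folder simply contains fewer than 10 images: cap the sample size at the population size (the variable name follows the traceback above):

```python
import random

# Never ask for more backgrounds than actually exist on disk.
new_bgs = random.sample(new_bgs, min(10, len(new_bgs)))
```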

Questions about the PyTorch version and an issue in training regarding to the batch size

Hi,

Thank you for sharing your PyTorch reimplementation. Would you mind sharing the PyTorch version you used for development?

I am using PyTorch 1.0.1, CUDA 9, and two RTX 2080 Ti cards to run train.py, since I see you use the DataParallel module to support multi-GPU training. However, I encountered an error; the traceback is here:

Traceback (most recent call last):
  File "train.py", line 171, in <module>
    main()
  File "train.py", line 167, in main
    train_net(args)
  File "train.py", line 64, in train_net
    logger=logger)
  File "train.py", line 103, in train
    alpha_out = model(img)  # [N, 3, 320, 320]
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 127, in forward
    up4 = self.up4(up5, indices_4, unpool_shape4)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 87, in forward
    outputs = self.conv(outputs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 43, in forward
    outputs = self.cbr_unit(inputs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

I have tested data parallelism using the example here, and it works well.

'train_names.txt' what it is?

After I run pre_process.py successfully, I run train.py, but it fails. The error says 'FileNotFoundError: [Errno 2] No such file or directory: 'train_names.txt''. I don't have this file, and worse, I don't know anything about 'train_names.txt'. What should I do?

Deep-Image-Matting-v2 implemetation on Android

Hi,
Thanks for your work!
The output looks awesome.
I want to integrate your demo into an Android project. Is it possible to integrate the model into an Android project? If so, how can I do it?
Can you please give some suggestions?
Thanks in advance.

about the result

Thanks for your great work!
I saw that your result is much better than the one the paper reports. Does the dropped fc6 matter?

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

When I run train.py, I get the error "RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED". What caused this error, and what should I do? My environment is Python 3.5, PyTorch 1.1.0, CUDA 9.0, cuDNN 7.1.3, torchvision 0.3.0.

Invalid BEST_checkpoint.tar ?

Hi, thank you for the code.
I tried to download the pretrained model and extract it, but it doesn't work.

tar xvf BEST_checkpoint.tar BEST_checkpoint

results in

tar: This does not look like a tar archive
tar: Skipping to next header
tar: BEST_checkpoint: Not found in archive
tar: Exiting with failure status due to previous errors

Am I doing something the wrong way, or is the provided tar not valid?
Kind regards
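
For what it's worth, the checkpoint is not a tar archive despite its extension: it is a serialized PyTorch file (the demo code loads it directly with torch.load), so no extraction step is needed. A sketch, with map_location added for CPU-only machines:

```python
import torch

# No "tar xvf" step: torch.save wrote this file directly.
checkpoint = torch.load('BEST_checkpoint.tar', map_location='cpu')
```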

question on alpha as input to gen_trimap

Hi

I have a question regarding the alpha needed during the testing phase. As I understand it, the paper states that for testing/inference only the original image and the corresponding trimap are needed, and the encoder/decoder framework predicts the alpha. On the other hand, a ground-truth alpha (generated in Photoshop) is required for training.

I am aware that the authors originally calculated alpha manually for all 431 objects and, after compositing, distributed them among the training and testing sets. So both training and testing get high-quality, manually produced alpha. This is fine if the authors' testing images are used.

However, in the case where I choose another image from the internet:
why is alpha being passed during the testing phase to generate the trimap, and how do I get that alpha? Can a mask generated by any segmentation framework be used as an initial/rough alpha input to generate_trimap, so that running it through the network will produce a relatively accurate alpha?

Thank you.

Automatic Background Removal technology

I am looking for a deep learning library/sdk which can be used to remove the background from any image automatically (with quality as good as www.remove.bg).

I tried some image segmentation SDKs with pre-trained models such as Tensorflow Lite & Fritz AI, but the accuracy of the cutout mask was very low, amongst other issues.

Criteria:

  1. Background removal rather than just human/portrait segmentation

If the foreground consists of a person holding a balloon, sitting on a chair, with a pet at his side, then I want all of this to get extracted, not just the human cutout. The segmentation SDKs I tried only extract humans (the chair vanishes), and with a very low-quality mask at that (hair gets cut, parts of the ear get cut, etc.).

  2. Mask quality should be super-accurate

I want even the finer details like hair, delicate clothes, etc. to be extracted perfectly.

  3. Fast and lightweight (for mobile phones)

I want to use this technology on mobile phones (in an Android app), ideally working even in an offline environment. If this is difficult to achieve, plan B would be to install the technology on our server.

  4. Technology
    What technology should I be exploring to achieve this? Is it called image segmentation, or would the better term be image matting? (e.g. http://alphamatting.com/eval_25.php)

I have been reading a lot and I am currently lost in the sea of various technologies out there (OpenCV, Deep Matting, Mask RCNN, Instance Segmentation, Detectron2, Tensorflow, Pytorch, etc). I wonder what magic is happening behind the curtains of www.remove.bg

Would this library help me to achieve what I am looking for? Any help you could provide would be awesome.

Thanks a ton!

my own datasets are all full human body images

Hi, thanks for your excellent work. Now I am preparing my own dataset, which consists of thousands of high-resolution images (average 4000×4000). They are all full human body images. When I process these images, I have a question:
when I crop the trimap (generated from the alpha), the crop often covers places that do not include hair, such as a foot or a leg. Is it OK to feed these images into the model? @foamliu

Pre-Trained model can be reused for "Mobile-Image-Matting"?

Hi,

Will the pretrained model of this "Deep-Image-Matting-PyTorch" project work on the "Mobile-Image-Matting" library too?

In the Mobile-Image-Matting library,

  1. In the Demo section of the ReadMe, it gives the following link for the pretrained model, but it returns a 404 error:

https://github.com/foamliu/Deep-Mobile-Matting/releases/download/v1.0/BEST_checkpoint.tar

  2. In the Performance section of the ReadMe, it gives the following link in the table against my-stage0:

https://github.com/foamliu/Deep-Image-Matting-v2/releases/download/v1.0/BEST_checkpoint.tar

But the above points to the large checkpoint file from the main "Deep-Image-Matting-v2" project.

Trying to evaluate with this second link gives the following error:

AttributeError: Can't get attribute 'DIMModel' on <module 'models' from '/content/gdrive/My Drive/MobileImageMatting/Mobile-Image-Matting/models/__init__.py'>

Any help?

Demo error

/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
warnings.warn(msg, SourceChangeWarning)
Traceback (most recent call last):
  File "demo.py", line 69, in <module>
    checkpoint = torch.load(checkpoint)
  File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 368, in load
    return _load(f, map_location, pickle_module)
  File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
  File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 505, in persistent_load
    data_type(size), location)
  File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 114, in default_restore_location
    result = fn(storage, location)
  File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 95, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 79, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

How do I extract the foreground?

Hello @foamliu,
the matting results are excellent!

My task is to extract the foreground.
I'd like to ask: how can I extract the foreground image with this matting library? I expect it is fairly simple, but I don't know matting well, so I'm not sure how.
Could you give a brief explanation? Thank you very much!
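
A quick way to get a cut-out is to multiply the input image by the predicted alpha; a rough sketch with hypothetical file names (strictly speaking, exact foreground colors would need a separate foreground-estimation step):

```python
import cv2
import numpy as np

image = cv2.imread('input.png').astype(np.float32)          # hypothetical name
alpha = cv2.imread('pred_alpha.png', cv2.IMREAD_GRAYSCALE)  # model output
alpha = alpha.astype(np.float32)[:, :, None] / 255.0        # [H, W, 1] in [0, 1]

# I = alpha * F + (1 - alpha) * B; compositing over black gives alpha * I.
foreground = (image * alpha).astype(np.uint8)
cv2.imwrite('foreground.png', foreground)
```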

Question about data(alpha or trimp)

Thank you very much for making this project open source. I have a question and I hope to get your answer.

In demo.py we should provide two images:

  1. an image (put in data/fg_test); I'll temporarily call it 1.png
  2. the image's alpha matte (put in data/mask_test); I'll temporarily call it 1_alpha.png

Must 1_alpha.png be obtained manually in Photoshop?
Is there any way to get it automatically,
so that we only need to upload 1.png?

how to extract training images

7plus@7pluss-MacBook-Pro ~ % python pre_process.py
python: can't open file 'pre_process.py': [Errno 2] No such file or directory

How should I deal with this?

Generate Own BEST_checkpoint.tar

hello,

Using your checkpoint on other objects like shoes 👠, etc., it doesn't work and the objects come out fully transparent.

I want to train on my own objects from the COCO dataset, or my own set, to generate the .tar file.

How can I achieve this?

Automation of Mask Creation ?

How do I automate the process of creating the mask, since the demo relies on masks? Masks created using DeepLab are not as refined as those in the demo.
