Code Monkey home page Code Monkey logo

deep-image-matting's Introduction

This is an implementation for paper Deep Image Matting

  • Deep image matting is a learning method to estimate the alpha matting params for fg,bg,merged and trimap
  • 该项目基于pytorch实现,主要的数据,模型,损失函数,训练策略介绍如下:
    -- data/py_adobe_data.py .the online fg/bg alpha merge data,compose COCO 2014 train and Matting Datasets
    -- models/py_encoder_decoder.py. The model define ,vgg encoder and unpooling/conv decoder.
    -- train_encoder_decoder.py. The train stage define, encoder-decoder/refine-head/over-all,totally three stages.
    -- utils/visulization.py. The loss and image vis module

项目简介

-- 数据集,数据集使用在线合成的方法。具体存放路径如下所示:
需要修改 data/py_adobe_data.py 中数据位置 1.将CombineDataset的前景数据和背景数据文件夹拷贝到一起。

            self.a_path = './data/adobe_data/trainval/alpha' #alpha 存放路径,将Train数据的alpha和Others数据的alpha收集一起存放这里
            self.fg_path = './data/adobe_data/trainval/fg/' #同上, 存储前景数据,共439张
            self.bg_path = '/data/jh/notebooks/hehao/datasets/coco/train2014/' #the coco path 指向coco数据集地址
            

单模块测试

数据模块

数据模块的功能性测试,可单独测试

DeepImageMatting$ python data/py_adobe_data.py
        data[:,:,0:3] = image # firset three rgb channel ,经过归一化的数据sub-mean,div-std
        data[:,:,3] = torch.tensor(trimap)  # last channel is trimap,value(0,128,255) 255为前景

        label[:,:,0:3] = torch.tensor(bg)   #前景,(0,255)之间取值
        label[:,:,3:6] = torch.tensor(fg)   #背景,(0,255)之间
        label[:,:,6:9] = torch.tensor(merged) #合成图(0,255),该值和data[0:3]一样,只不过没经过归一化,为便于计算loss缘故
        label[:,:,9:10] = torch.tensor(alpha.reshape(self.size,self.size,1)) #ground_truth alpha matting值
        label[:,:,10] = torch.tensor(mask) #unknown region区域的掩码
        
模型模块

模型部分包括encoder-decoder,encoder-decoder-refinehead 残差结构,都可以使用以下语句测试

DeepImageMatting$ python models/py_encoder_decoder.py #测试模型计算
论文种对于Unpooling 和Deconv介绍不是很清晰,针对SegNet网络种的介绍,可以在vgg maxpooling时保留Max的位置索引,在上采样时进行赋值。  
上采样有若干种方法:双线性插值(可学习),双线性插值(不可学习),转置卷积(反卷积,stride大于1),reverse maxpooling等方式,当前实现种采用反向maxpooling的方式
模型与训练参数加载

matting网络的encoder部分采用的vgg_bn16的参数,需要将对应参数灌入当前模型,具体参见models/py_encoder_decoder.py 中load_vggbn 函数

损失函数

损失函数分为两种,参考论文中提出的alpha-prediction loss 和 Compositional loss 计算方式可参见内部loss类型

DeepImageMatting$ python models/py_loss.py #测试损失函数
对于论文中提出的损失函数有两个:  
- alpha 损失,预测的matting参数和groud_truth alpha之间的差值平方均值 mse。 具体参见 AlphaPredLoss  
- compose损失,使用matting预测参数合成的图片和原本真实合成图片的mse差值 。具体参见 ComposeLoss
训练模块

所有的三个阶段的训练模块都集中在自定义类Trainer中,分为初始化:设置数据模型损失函数训练策略,训练,验证主要的模块

DeepImageMatting$ python train_encoder_decoder.py --params  
其中stage为阶段参数: 
- 第一阶段:训练encoder-decoder 类似于SegNet结构
- 第二阶段:训练refine_head,一个alpha细化模块 。
- 第三阶段:整体进行训练,整体结构类似一个残差块。

训练过程记录:

Doing list 需要实验的任务

  • 在第一阶段encoder-decoder训练时,因为alpha prediction loss和compositional loss 的数量级不一样,采用文中作者提出的两种loss加起来, 会导致无法收敛,一直震荡的状态,通过观察梯度发现梯度很小,根据论文中实验部分描述,采用alpha-prediction loss来train,会收敛。
    avatar
    avatar avatar

  • 训练过程中可视化 前景背景,合成图和alpha数值

  • 需要尝试双线性插值和unpooling对训练的影响
    medium 博客,unpooling和deconv unpooling介绍

    unpooling deconv 的使用在SegNet等网络中
  • 需要训练encoder-decoder中间衔接的trans模块,卷积核大小,当前采用一个3x3和1x1(具体参见py_encoder_decoder.py的trans部分)

    class transMap(nn.Module):
        def __init__(self):
            super(transMap,self).__init__()
            self.conv1 = conv2DBatchNormRelu(512,512,3,1,1)
            self.conv2 = conv2DBatchNormRelu(512,512,1,1,0)
    
        def forward(self, x):
            return self.conv2(self.conv1(x))
  • encoder的输入为合成图RGB和trimap通道,RGB通道按照之前totensor和normalize操作, trimap在当前数据集为(0,128,255)三个取值,其数值量级不一致,需要对trimap采用归一化的方法,当前采用 除以255,减去0.5均值,除以0.5的方式

  • 需要尝试使用skip-connection之后的区别
    结合U-Net和SegNet等分割网络中跳层连接的形式,结合encoder浅层的语义,提升matting效果

  • 需要结合最新的BiSeNet等分割网络的技巧来提升matting效果

Note

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.