
pwcnet-tf2's Introduction

Status: the code implemented with 'tf.compat.v1' is still under testing.

pwcnet-tf2

This repository provides a TensorFlow implementation of the paper "PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume", presented at CVPR 2018 (Oral).

I implemented "PWCDC-Net", the latest version of PWC-Net, in both TensorFlow 2.x and TensorFlow 1.x style (based on "tf.compat.v1"), so that you can easily compare the differences between TF 2.x and TF 1.x. The TF 2.x code is adapted from the official PyTorch version, while the TF 1.x code is implemented in a different way.

Quick Start

Install dependency

The code is tested on Python 3.7. Run the following command to install the required packages.

pip install -r requirements.txt

Download pretrained model

Run the following script to download the provided pretrained model from Google Drive.

./download_models.sh

Or directly get the pretrained model from Google Drive.

Run Inference

Run the following sample command for inference:

python inference.py --image_0 sample_images/00003_img1.ppm --image_1 sample_images/00003_img2.ppm \
--ckpt_path ./checkpoints/ckpt-1200000 

An output file will then be generated.

Other Usages

TF 2.x

# Training
python train.py --data_dir $DATASET_ROOT_DIR --model_dir ./checkpoints \
--random_scale --random_flip

# Evaluation on FlyingChairs dataset
python evaluate.py --data_dir $DATASET_ROOT_DIR --ckpt_path ./checkpoints/ckpt-1200000 \
--val_list ./list/FlyingChairs_val_list.txt

# Evaluation on MPI Sintel dataset
python evaluate.py --data_dir $DATASET_ROOT_DIR --ckpt_path ./checkpoints/ckpt-1200000 \
--val_list ./list/MPI_Sintel_Clean_train_list.txt

TF 2.x + tf.compat.v1

This has not been tested yet.

python train_v1_compat.py --data_dir $DATASET_ROOT_DIR --model_dir ./checkpoints \
--random_scale --random_flip

Results on validation set

The checkpoint is trained on a mixture of the FlyingChairs and FlyingThings3D-HalfRes datasets (see the description here). Please note that this checkpoint has not been fine-tuned on the MPI Sintel dataset yet.

For your information,

  • Training from scratch with batch size 8 on TitanV takes about 4-5 days.
  • Training from scratch with batch size 32 on V100 takes about 2-3 days.

Average End Point Error (AEPE)

Dataset                       AEPE
FlyingChairs                  1.716
MPI Sintel Clean (train set)  2.852
MPI Sintel Final (train set)  3.988

Optical Flow Prediction

FlyingChairs

MPI Sintel

Inference time

Dataset                 TitanV
FlyingChairs (384x512)  0.026 sec
MPI Sintel (436x1024)   0.038 sec

File hierarchy

To use the pre-generated lists for training and validation, you need to download and arrange the datasets in the following hierarchy:

- Datasets ($DATASET_ROOT_DIR)
  - FlyingChairs_release
    - data
  - FlyingThings3D
    - frames_cleanpass
    - optical_flow
  - MPI_Sintel
    - training
    - test

You can download the related datasets from their official websites.

Clarification of implementation details

Mixed-use of 'FlyingChairs' and 'FlyingThings3D-Half-Res' datasets.

"We got more stable results by using a half-resolution version of the FlyingThings3D dataset with an average flow magnitude of 19, much closer to FlyingChairs and MPI-Sintel in that respect. We then trained on a mix of the FlyingChairs and FlyingThings3DHalfRes datasets. This mix, of course, could be extended with additional datasets.", as refered to the explaination by Philferriere.

Thus, in this implementation we trained on the mixed dataset as suggested.

Please see this issue in the officially released repository and the explanation here for further details.
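
As a rough illustration (a minimal sketch, not code from this repository; the helper name is an assumption), producing a half-resolution FlyingThings3D sample amounts to downsampling both frames and the flow field by 2 and also halving the flow vectors, which is why the average flow magnitude drops to roughly half:

    import cv2
    import numpy as np

    def make_half_res(img, flow):
        """Downsample an image and its flow field by 2. The flow vectors
        themselves must also be halved to stay consistent with the new
        resolution."""
        h, w = img.shape[:2]
        img_half = cv2.resize(img, (w // 2, h // 2), interpolation=cv2.INTER_AREA)
        flow_half = cv2.resize(flow, (w // 2, h // 2), interpolation=cv2.INTER_LINEAR) * 0.5
        return img_half, flow_half.astype(np.float32)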

We scale the supervision signal at each level, which is different from the official version.

"We scale the ground truth flow by 20 and downsample it to obtain the supervision signals at different levels. Note that we do not further scale the supervision signal at each level.", as refered to Sec. 4 in the original paper.

I tried to implement the loss function as described above; however, the training loss did not converge. Instead, I refer to this implementation for the design of the loss function: (1) rescale the supervision signals at each level, and (2) still scale the estimated optical flow before the warping function. Please note that the flow scale at each level can be learned and adjusted by the transposed convolution layers.
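
As a rough sketch of this loss design (a minimal illustration, not the exact code in train.py), the ground-truth flow is downsampled to each pyramid level and rescaled by the same factor before computing the per-level end-point error; the level weights follow the original paper:

    import tensorflow as tf

    # Per-level weights from the PWC-Net paper, coarsest level first.
    LEVEL_WEIGHTS = [0.32, 0.08, 0.02, 0.01, 0.005]

    def multiscale_loss(gt_flow, pred_flows, weights=LEVEL_WEIGHTS):
        """gt_flow: [B, H, W, 2] full-resolution ground truth.
        pred_flows: list of per-level predictions [B, h, w, 2], coarsest first."""
        full_h = tf.cast(tf.shape(gt_flow)[1], tf.float32)
        total = 0.0
        for weight, pred in zip(weights, pred_flows):
            level_h, level_w = tf.shape(pred)[1], tf.shape(pred)[2]
            # Downsample the supervision signal to this level and rescale its
            # magnitude accordingly (this rescaling is what differs from the paper).
            scale = tf.cast(level_h, tf.float32) / full_h
            gt_level = tf.image.resize(gt_flow, [level_h, level_w]) * scale
            total += weight * tf.reduce_sum(tf.norm(pred - gt_level, axis=-1))
        return total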

Cannot always converge with 'FlyingChairs' and 'FlyingThings3D' datasets.

"However, we observe in our experiments that a deeper optical flow estimator mightget stuck at poor local minima, which can be detected by checking the validation errors after a few thousand iterations and fixed by running from a different random initialization.", as refered to Sec. 4-2 in the original paper.

I also encountered this problem several times. The workaround is to check whether the validation AEPE is decreasing; if not, restart the whole training process from a different random initialization. In general, the problem can be detected within the first 8000 steps.

Acknowledgment & Reference

Citation

@InProceedings{Sun2018PWC-Net,
  author    = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz},
  title     = {{PWC-Net}: {CNNs} for Optical Flow Using Pyramid, Warping, and Cost Volume},
  booktitle = {CVPR},
  year      = {2018},
}

If you find this implementation or the pre-trained models helpful, please consider citing:

@misc{Yang2020,
  author = {Hsuan-Kung, Yang},
  title = {PWCNet-tf2},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/hellochick/PWCNet-tf2}}
}


pwcnet-tf2's Issues

Backward Pass of Warping layer

Hi, thanks for implementing and sharing this. I am wondering how the backward pass of the warping layer works. As far as I can tell, the Floor() operation is non-differentiable. Do we need to use gradient_override_map to substitute the gradient of the Identity op for Floor?
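
For reference, here is a minimal 1-D sketch (not this repository's warp layer) of why bilinear warping is typically trainable without a gradient override: floor() is only used to pick the gather indices, and the gradient with respect to the sampling position flows through the continuous interpolation weights.

    import tensorflow as tf

    def bilinear_sample_1d(values, pos):
        """Sample a 1-D tensor `values` at fractional positions `pos`.
        floor() only selects indices (no gradient flows through it);
        the gradient w.r.t. `pos` comes from the interpolation weights."""
        n = tf.shape(values)[0]
        i0 = tf.clip_by_value(tf.cast(tf.floor(pos), tf.int32), 0, n - 1)
        i1 = tf.clip_by_value(i0 + 1, 0, n - 1)
        w1 = pos - tf.floor(pos)   # fractional part, differentiable w.r.t. pos
        w0 = 1.0 - w1
        return w0 * tf.gather(values, i0) + w1 * tf.gather(values, i1)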

List of requirements

Hello, first I would like to thank you for sharing the code.
I would like to ask if you have a list of the requirements to install in an environment to run this project.
Thanks

Cannot download checkpoint from git-lfs: This repository is over its data quota.

Hi there,
Thanks for your implementation of PWC-Net. Unfortunately, I'm not able to download the checkpoint from git-lfs since the data quota for your repository has been exceeded.

➜  PWCNet-tf2 git:(master) ✗ git lfs pull
Git LFS: (0 of 1 files) 0 B / 107.28 MB                                        
batch response: This repository is over its data quota. Account responsible for LFS bandwidth should purchase more data packs to restore access.
error: failed to fetch some objects from 'https://github.com/hellochick/PWCNet-tf2.git/info/lfs'

Would it be possible to upload the checkpoint somewhere else? I Would really appreciate it!
Thanks in advance!

Unable to run inference.py

I get the following error when running inference.py:

tensorflow/core/framework/op_kernel.cc:1763] OP_REQUIRES failed at save_restore_v2_ops.cc:205 : Out of range: Read less bytes than requested
Seems to be something wrong with the checkpoint files?

Writing custom data generator

Hi All,

Thank you for sharing your work. The code works perfectly and I was able to train and run inference.
But I have one simple question.
I don't know if this is the right platform for this query, but it would be a great help if you could respond to it.

I was customizing the data generator script, but then the model fails to converge. Any idea what the reason could be?

Here is my code,

    from flowUtils import read_flow
    import tensorflow as tf
    from tensorflow.keras.utils import Sequence
    import numpy as np
    import cv2

    class DataGenerator(Sequence):
        def __init__(self, im1PairsList, im2PairsList, flowList, batch_size=6,
                     crop_size=[256, 448], shuffle=True, isTrain=False) -> None:
            self.im1PairsList = im1PairsList
            self.im2PairsList = im2PairsList
            self.flowList = flowList
            self.batch_size = batch_size
            self.crop_size = crop_size
            self.shuffle = shuffle
            self.isTrain = isTrain
            self.on_epoch_end()

        def __len__(self):
            return int(np.floor(len(self.im1PairsList) / self.batch_size))

        def __getitem__(self, index):
            indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
            img1PairsTemp = [self.im1PairsList[k] for k in indexes]
            img2PairsTemp = [self.im2PairsList[k] for k in indexes]
            flowPairsTemp = [self.flowList[k] for k in indexes]
            imPair, flow = self._data_generation(img1PairsTemp, img2PairsTemp, flowPairsTemp)
            return imPair, flow

        def normalizeImages(self, img):
            return np.asarray(img / 255., dtype=np.float32)

        def tf_image_crop(self, img_concat):
            # RGB + RGB + UV = 8 channels
            im_cropped = tf.image.random_crop(img_concat, [self.crop_size[0], self.crop_size[1], 8])
            im1 = im_cropped[:, :, :3]
            im2 = im_cropped[:, :, 3:6]
            flo = im_cropped[:, :, 6:]
            return im1, im2, flo

        def _data_generation(self, img1PairsTemp, img2PairsTemp, flowPairsTemp):
            imgPair = []
            flowPair = []
            for img1Path, img2Path, flowPath in zip(img1PairsTemp, img2PairsTemp, flowPairsTemp):
                im1 = cv2.imread(img1Path)
                im2 = cv2.imread(img2Path)
                flo = read_flow(flowPath)
                norm_im1 = self.normalizeImages(im1)
                norm_im2 = self.normalizeImages(im2)
                norm_im1_tf = tf.convert_to_tensor(norm_im1, dtype=tf.float32)
                norm_im2_tf = tf.convert_to_tensor(norm_im2, dtype=tf.float32)
                im_concat = tf.concat([norm_im1_tf, norm_im2_tf, flo], axis=2)

                if self.isTrain:
                    im1, im2, flo = self.tf_image_crop(im_concat)
                    imgconc = tf.concat([im1, im2], axis=2)
                    imgPair.append(np.expand_dims(imgconc, axis=0))
                    flowPair.append(np.expand_dims(flo, axis=0))
                else:
                    imgconc = tf.concat([norm_im1_tf, norm_im2_tf], axis=2)
                    imgPair.append(np.expand_dims(imgconc, axis=0))
                    flowPair.append(np.expand_dims(flo, axis=0))

            imgPair = np.concatenate(imgPair, axis=0)
            flowPair = np.concatenate(flowPair, axis=0)
            return tf.convert_to_tensor(imgPair, dtype=tf.float32), tf.convert_to_tensor(flowPair, dtype=tf.float32)

        def on_epoch_end(self):
            self.indexes = np.arange(len(self.im1PairsList))
            if self.shuffle:
                np.random.shuffle(self.indexes)

Training with only FlyingChairs dataset.

As the optical_flow folder of FlyingThings3D is too large to download, I am trying to train PWCNet-tf2 with only the FlyingChairs dataset. Do you have any suggestions for parameter settings? Will it hurt the performance in a big way?

Thanks for your help.
