Learning Gradient Fields for Shape Generation

This repository contains a PyTorch implementation of the paper:

Learning Gradient Fields for Shape Generation [Project page] [Arxiv] [Short-video] [Long-video]

Ruojin Cai*, Guandao Yang*, Hadar Averbuch-Elor, Zekun Hao, Serge Belongie, Noah Snavely, Bharath Hariharan (* Equal contribution)

ECCV 2020 (Spotlight)

Introduction

In this work, we propose a novel technique to generate shapes from point cloud data. A point cloud can be viewed as samples from a distribution of 3D points whose density is concentrated near the surface of the shape. Point cloud generation thus amounts to moving randomly sampled points to high-density areas. We generate point clouds by performing stochastic gradient ascent on an unnormalized probability density, thereby moving sampled points toward the high-likelihood regions. Our model directly predicts the gradient of the log density field and can be trained with a simple objective adapted from score-based generative models. We show that our method can reach state-of-the-art performance for point cloud auto-encoding and generation, while also allowing for extraction of a high-quality implicit surface.
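
The sampling idea described above can be illustrated with a small toy sketch. This is not the repository's implementation: the learned gradient network is replaced by a hypothetical analytic function `toy_score` for a 2D density concentrated near the unit circle, and the noise schedule `sigmas`, `steps_per_sigma`, and `step_scale` are illustrative assumptions rather than values from the paper. The sketch only shows how noisy gradient ascent on a log density moves random points toward the high-density region.

```python
import math
import random


def toy_score(x, y, sigma):
    """Gradient of the log density of a toy 2D distribution,
    p(x, y) proportional to exp(-(r - 1)^2 / (2 * sigma^2)) with r = |(x, y)|.
    This stands in for the learned gradient field (an assumption for illustration)."""
    r = math.hypot(x, y)
    if r == 0.0:
        return 0.0, 0.0  # direction is undefined at the origin; the noise term moves the point off it
    coeff = -(r - 1.0) / (sigma ** 2 * r)
    return coeff * x, coeff * y


def langevin_sample(n_points=256, sigmas=(1.0, 0.5, 0.25, 0.1, 0.05),
                    steps_per_sigma=20, step_scale=0.2, seed=0):
    """Move randomly initialized points toward high-density regions using
    noisy gradient ascent (Langevin-style updates) over decreasing noise levels."""
    rng = random.Random(seed)
    # Start from points drawn from a standard Gaussian prior.
    points = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n_points)]
    for sigma in sigmas:
        eps = step_scale * sigma ** 2  # step size shrinks with the noise level
        noise_std = math.sqrt(eps)
        for _ in range(steps_per_sigma):
            moved = []
            for x, y in points:
                gx, gy = toy_score(x, y, sigma)
                moved.append((x + 0.5 * eps * gx + noise_std * rng.gauss(0.0, 1.0),
                              y + 0.5 * eps * gy + noise_std * rng.gauss(0.0, 1.0)))
            points = moved
    return points


if __name__ == "__main__":
    samples = langevin_sample()
    err = sum(abs(math.hypot(x, y) - 1.0) for x, y in samples) / len(samples)
    print(f"mean distance to the unit circle: {err:.3f}")
```

In the actual method the gradient comes from a trained network conditioned on a shape latent code, so the toy function above would be replaced by a model forward pass.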

Dependencies

# Create conda environment with torch 1.2.0 and CUDA 10.0
conda env create -f environment.yml
conda activate ShapeGF

# Compile the evaluation metrics
cd evaluation/pytorch_structural_losses/
make clean
make all

Dataset

Please follow the instructions from PointFlow to set up the dataset: link.

Pretrained Model

Pretrained models will be available in the following Google Drive folder: link. To use them, download the pretrained folder and put it under the project root directory.

Testing the pretrained auto-encoding model:

The following commands test the performance of the pretrained models in the point cloud auto-encoding task. The commands output the Chamfer distance (CD) and earth mover's distance (EMD) on the test/validation sets.

# Usage:
# python test.py <config> --pretrained <checkpoint_filename>

python test.py configs/recon/airplane/airplane_recon_add.yaml \
    --pretrained pretrained/recon/airplane_recon_add.pt
python test.py configs/recon/car/car_recon_add.yaml \
    --pretrained pretrained/recon/car_recon_add.pt
python test.py configs/recon/chair/chair_recon_add.yaml \
    --pretrained pretrained/recon/chair_recon_add.pt

The pretrained model's auto-encoding performance is as follows:

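| Dataset  | Metric     | Ours  | Oracle |
|----------|------------|-------|--------|
| Airplane | CD (x1e4)  | 0.966 | 0.837  |
| Airplane | EMD (x1e2) | 2.632 | 2.062  |
| Chair    | CD (x1e4)  | 5.660 | 3.201  |
| Chair    | EMD (x1e2) | 4.976 | 3.297  |
| Car      | CD (x1e4)  | 5.306 | 3.904  |
| Car      | EMD (x1e2) | 4.380 | 3.251  |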
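
For readers unfamiliar with the CD metric, here is a minimal pure-Python sketch of one common definition of the Chamfer distance (mean squared nearest-neighbor distance, summed over both directions). The function names are hypothetical, and the repository's compiled metrics under `evaluation/pytorch_structural_losses/` may use a different normalization, so absolute numbers from this sketch should not be compared with the table. EMD requires solving an optimal matching and is not shown.

```python
def squared_distance(p, q):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def chamfer_distance(xs, ys):
    """Symmetric Chamfer distance between two non-empty point sets.

    For each point in one set, take the squared distance to its nearest
    neighbor in the other set; average per set, then add both directions.
    This brute-force version is O(len(xs) * len(ys)).
    """
    if not xs or not ys:
        raise ValueError("both point sets must be non-empty")
    x_to_y = sum(min(squared_distance(x, y) for y in ys) for x in xs) / len(xs)
    y_to_x = sum(min(squared_distance(x, y) for x in xs) for y in ys) / len(ys)
    return x_to_y + y_to_x


if __name__ == "__main__":
    a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
    b = [(0.0, 0.0, 0.0)]
    print(chamfer_distance(a, b))
```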

Testing the pretrained generation model:

The following commands test the performance of the pretrained models in the point cloud generation task. The commands output the Jensen-Shannon divergence (JSD), minimum matching distance (MMD-CD/EMD), coverage (COV-CD/EMD), and 1-nearest-neighbor accuracy (1NN-CD/EMD).

# Usage:
# python test.py <config> --pretrained <checkpoint_filename>

python test.py configs/gen/airplane_gen_add.yaml \
    --pretrained pretrained/gen/airplane_gen_add.pt
python test.py configs/gen/car_gen_add.yaml \
    --pretrained pretrained/gen/car_gen_add.pt
python test.py configs/gen/chair_gen_add.yaml \
    --pretrained pretrained/gen/chair_gen_add.pt
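
As a rough illustration of two of the generation metrics, the sketch below computes coverage (COV) and minimum matching distance (MMD) from a precomputed matrix of pairwise distances between generated and reference shapes, following the definitions commonly used in the point cloud generation literature. The function name and the input layout are assumptions for illustration and are not taken from the repository's evaluation code.

```python
def coverage_and_mmd(dist):
    """Compute (COV, MMD) from a distance matrix.

    dist[i][j] is the distance (for example CD or EMD) between generated
    shape i and reference shape j.

    COV: fraction of reference shapes that are the nearest neighbor of at
         least one generated shape (higher is better).
    MMD: for each reference shape, the distance to its closest generated
         shape, averaged over the reference set (lower is better).
    """
    if not dist or not dist[0]:
        raise ValueError("distance matrix must be non-empty")
    n_gen, n_ref = len(dist), len(dist[0])
    # Index of the nearest reference shape for every generated shape.
    matched = {min(range(n_ref), key=lambda j: row[j]) for row in dist}
    cov = len(matched) / n_ref
    mmd = sum(min(dist[i][j] for i in range(n_gen)) for j in range(n_ref)) / n_ref
    return cov, mmd


if __name__ == "__main__":
    # Two generated shapes, three reference shapes (made-up distances).
    toy = [[0.1, 0.9, 0.8],
           [0.2, 0.7, 0.3]]
    print(coverage_and_mmd(toy))
```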

Training

Single GPU Training

# Usage:
python train.py <config>

Multi GPU Training

Our code also supports single-node multi-GPU training using PyTorch's DistributedDataParallel. The script runs on all GPUs visible to the process (for example, those exposed through CUDA_VISIBLE_DEVICES). The usage and examples are as follows:

# Usage
python train_multi_gpus.py <config> 

# To specify the total batch size, use --batch_size
python train_multi_gpus.py <config> --batch_size <#gpu x batch_size/GPU>
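
The value passed to `--batch_size` is the total across GPUs, i.e. the number of GPUs times the per-GPU batch size. A small sketch of that arithmetic follows; the helper names and the example numbers (4 GPUs, 16 samples per GPU) are hypothetical and only illustrate how the flag value is formed.

```python
def total_batch_size(num_gpus, per_gpu_batch_size):
    """Total batch size to pass to --batch_size: #GPUs x batch size per GPU."""
    if num_gpus < 1 or per_gpu_batch_size < 1:
        raise ValueError("num_gpus and per_gpu_batch_size must be positive")
    return num_gpus * per_gpu_batch_size


def multi_gpu_command(config, num_gpus, per_gpu_batch_size):
    """Build the training command string for a given config path."""
    total = total_batch_size(num_gpus, per_gpu_batch_size)
    return f"python train_multi_gpus.py {config} --batch_size {total}"


if __name__ == "__main__":
    # Example numbers are illustrative, not recommended settings.
    print(multi_gpu_command("configs/recon/shapenet/shapenet_recon.yaml", 4, 16))
```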

Stage-1: Auto-encoding

In this stage, we create a conditional generator that models the distribution of 3D points conditioned on the latent vector. The commands used to train our auto-encoding model for a single shape, a single ShapeNet category, and the whole of ShapeNet are:

# Single shape
python train.py configs/recon/single_shapes/dress.yaml  # the dress in the teaser
python train.py configs/recon/single_shapes/torus.yaml  # the torus in the teaser

# Single category
python train.py configs/recon/airplane/airplane_recon_add.yaml  # airplane
python train.py configs/recon/chair/chair_recon_add.yaml        # chair
python train.py configs/recon/car/car_recon_add.yaml            # car

# Whole ShapeNet
python train_multi_gpus.py configs/recon/shapenet/shapenet_recon.yaml  # ShapeNet

Stage-2: Generation

In the second stage, we train an l-GAN to model the distribution of shapes, which are captured by the latent vectors of the auto-encoder described in the first stage. The commands used to train the l-GAN for a single ShapeNet category using the default pretrained model (in the <root>/pretrained directory) are:

python train.py configs/gen/airplane_gen_add.yaml  # airplane
python train.py configs/gen/chair_gen_add.yaml     # chair
python train.py configs/gen/car_gen_add.yaml       # car 

Cite

Please cite our work if you find it useful:

@inproceedings{ShapeGF,
 title={Learning Gradient Fields for Shape Generation},
 author={Cai, Ruojin and Yang, Guandao and Averbuch-Elor, Hadar and Hao, Zekun and Belongie, Serge and Snavely, Noah and Hariharan, Bharath},
 booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
 year={2020}
}

Acknowledgment

This work was supported in part by grants from Magic Leap and Facebook AI, and the Zuckerman STEM leadership program.
