dqtorch

Introduction

dqtorch is a PyTorch library for fast batched (dual) quaternion operations. The speedup over native PyTorch implementations (e.g. PyTorch3D) comes from highly optimized CUDA extensions.

Performance

  • batch size = 4096*128 (for the use case of training deformable NeRF, we typically have 4096 rays with 128 points sampled per ray).
  • tested on a single RTX3080 GPU.
                  quaternion_mul   quaternion_conjugate
native PyTorch    561.3 us         66.5 us
dqtorch           33.0 us          22.5 us

Key Features

  • fast CUDA implementation of quaternion operations.
  • basic operations for SO(3) and SE(3) transformations using (dual) quaternions.
  • supports torch.half, torch.float, and torch.double tensors.
  • supports second-order gradients (gradient of a gradient).

To do:

  • dual quaternion blend skinning.

Get Started

Requirements

Tested with PyTorch 1.12, CUDA 11.1, and gcc 6.3.0.

To do: check if it compiles with:

  • PyTorch 2.0
  • CUDA 11.7

Install

python setup.py install

Test

python examples.py

Tutorial

Basic quaternion operations

import dqtorch
import torch

n_pts = 4096*128

# get a normalized quaternion from a batch of random axis angles
qr1 = dqtorch.axis_angle_to_quaternion(torch.randn(n_pts, 3).cuda()) 
qr2 = dqtorch.axis_angle_to_quaternion(torch.randn(n_pts, 3).cuda())

# quaternion multiplication
qr3 = dqtorch.quaternion_mul(qr1, qr2)
# if the number of channels is 3, we assume the quaternion real part is 0
qr4 = dqtorch.quaternion_mul(qr1, qr2[..., 1:])

# apply rotation to 3d points
p1 = torch.randn(n_pts, 3).cuda()
p2 = dqtorch.quaternion_apply(qr1, p1)

# quaternion conjugate
inv_qr1 = dqtorch.quaternion_conjugate(qr1)
# apply inverse transform to p2, and we should get back p1
p1_by_inv = dqtorch.quaternion_apply(inv_qr1, p2)
print((p1_by_inv-p1).abs().max()) # should be close to 0
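
For readers who want to see the math behind these calls, here is a minimal plain-Python sketch for a single quaternion. It is not dqtorch's CUDA implementation. It assumes the real-part-first (w, x, y, z) layout suggested by the slicing above, and it assumes the axis-angle input is the rotation axis scaled by the angle in radians (the PyTorch3D convention). The function names mirror the tutorial only for readability.

```python
import math

def axis_angle_to_quaternion(v):
    """Axis-angle vector (axis * angle) -> unit quaternion (w, x, y, z). Assumed convention."""
    angle = math.sqrt(sum(c * c for c in v))
    if angle < 1e-12:
        return (1.0, 0.0, 0.0, 0.0)  # no rotation
    s = math.sin(angle / 2) / angle
    return (math.cos(angle / 2), v[0] * s, v[1] * s, v[2] * s)

def quaternion_mul(a, b):
    """Hamilton product; a 3-tuple is treated as a quaternion with real part 0."""
    if len(a) == 3:
        a = (0.0, *a)
    if len(b) == 3:
        b = (0.0, *b)
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )

def quaternion_conjugate(q):
    """Conjugate; for a unit quaternion this is also the inverse rotation."""
    w, x, y, z = q
    return (w, -x, -y, -z)

def quaternion_apply(q, p):
    """Rotate 3D point p by unit quaternion q: vector part of q * p * conj(q)."""
    return quaternion_mul(quaternion_mul(q, p), quaternion_conjugate(q))[1:]

# 90 degree rotation about the z axis maps the x axis onto the y axis
q = axis_angle_to_quaternion((0.0, 0.0, math.pi / 2))
p = (1.0, 0.0, 0.0)
rotated = quaternion_apply(q, p)
restored = quaternion_apply(quaternion_conjugate(q), rotated)
```

A scalar reference like this can be handy for spot-checking a few rows of a batched result on the CPU.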

SE(3) transformation by quaternion + translation

# se3 representation by quaternion + translation
t1 = torch.randn(n_pts, 3).cuda() # create random translations
# apply se3 transformation to points
p2 = dqtorch.quaternion_translation_apply(qr1, t1, p1)
# inverse of se3 transformation
qr1_inv, t1_inv = dqtorch.quaternion_translation_inverse(qr1, t1)
# compose two se3 transformation
qr3, t3 = dqtorch.quaternion_translation_compose((qr1_inv, t1_inv), (qr1, t1))
print((qr3[..., 0]-1).abs().max(), qr3[..., 1:].abs().max(), t3.abs().max()) # should be close to 0
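
The identities checked above can be written out for a single transform in plain Python. This is a hedged sketch, not dqtorch's code: the helper names are hypothetical, quaternions are (w, x, y, z), and the composition order (the second argument is applied first) is an assumption. The inverse-then-forward check gives the identity under either order.

```python
import math

def qmul(a, b):
    """Hamilton product of two (w, x, y, z) quaternions."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )

def qconj(q):
    return (q[0], -q[1], -q[2], -q[3])

def rotate(q, p):
    """Rotate 3D point p by unit quaternion q."""
    return qmul(qmul(q, (0.0, *p)), qconj(q))[1:]

def qt_apply(q, t, p):
    """SE(3) action: rotate p by q, then translate by t."""
    return tuple(r + c for r, c in zip(rotate(q, p), t))

def qt_inverse(q, t):
    """Inverse transform: (conj(q), -R^-1 t)."""
    q_inv = qconj(q)
    return q_inv, tuple(-c for c in rotate(q_inv, t))

def qt_compose(a, b):
    """Compose so that b is applied first, then a (assumed order)."""
    (qa, ta), (qb, tb) = a, b
    return qmul(qa, qb), tuple(r + c for r, c in zip(rotate(qa, tb), ta))

s = math.sqrt(0.5)
q, t = (s, 0.0, 0.0, s), (1.0, 2.0, 3.0)   # 90 degrees about z, then shift
p = (1.0, 0.0, 0.0)
moved = qt_apply(q, t, p)
q_inv, t_inv = qt_inverse(q, t)
back = qt_apply(q_inv, t_inv, moved)
q_id, t_id = qt_compose((q_inv, t_inv), (q, t))
```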

SE(3) transformation by dual quaternion

# se3 representation by dual quaternions, which is stored as a tuple of two tensors
dq1 = dqtorch.quaternion_translation_to_dual_quaternion(qr1, t1)
dq1_inv = dqtorch.quaternion_translation_to_dual_quaternion(qr1_inv, t1_inv)
# compose two se3 transformation
dq3 = dqtorch.dual_quaternion_mul(dq1_inv, dq1)
print((dq3[0][..., 0]-1).abs().max(), dq3[0][..., 1:].abs().max(), dq3[1].abs().max()) # should be close to 0
# apply se3 transformation to points
p2 = dqtorch.dual_quaternion_apply(dq1, p1)
# dual quaternion inverse
dq3 = dqtorch.dual_quaternion_inverse(dq1)
print((dq3[0]-dq1_inv[0]).abs().max(), (dq3[1]-dq1_inv[1]).abs().max()) # should be close to 0
# convert from dual quaternion to quaternion translation
qr3, t3 = dqtorch.dual_quaternion_to_quaternion_translation(dq1)
print((qr3-qr1).abs().max(), (t3-t1).abs().max()) # should be close to 0
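
As a rough reference for what the calls above compute, the following plain-Python sketch uses the standard unit dual quaternion construction: the real part is the rotation qr and the dual part is qd = 0.5 * t * qr, with t treated as a quaternion whose real part is 0. Whether dqtorch uses exactly this layout internally is an assumption. The helper names are hypothetical and the code handles one transform rather than a batch.

```python
import math

def qmul(a, b):
    """Hamilton product of two (w, x, y, z) quaternions."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )

def qconj(q):
    return (q[0], -q[1], -q[2], -q[3])

def qt_to_dq(q, t):
    """(rotation, translation) -> (real part, dual part)."""
    return q, tuple(0.5 * c for c in qmul((0.0, *t), q))

def dq_mul(a, b):
    """Dual quaternion product: (ar*br, ar*bd + ad*br)."""
    real = qmul(a[0], b[0])
    dual = tuple(x + y for x, y in zip(qmul(a[0], b[1]), qmul(a[1], b[0])))
    return real, dual

def dq_inverse(dq):
    """Inverse of a unit dual quaternion: conjugate both parts."""
    return qconj(dq[0]), qconj(dq[1])

def dq_to_qt(dq):
    """Recover the translation as the vector part of 2 * qd * conj(qr)."""
    qr, qd = dq
    return qr, tuple(2.0 * c for c in qmul(qd, qconj(qr))[1:])

def dq_apply(dq, p):
    """Rotate p by the real part, then add the recovered translation."""
    q, t = dq_to_qt(dq)
    rotated = qmul(qmul(q, (0.0, *p)), qconj(q))[1:]
    return tuple(r + c for r, c in zip(rotated, t))

s = math.sqrt(0.5)
q, t = (s, 0.0, 0.0, s), (1.0, 2.0, 3.0)   # 90 degrees about z, then shift
dq = qt_to_dq(q, t)
identity = dq_mul(dq_inverse(dq), dq)      # expect real (1,0,0,0), dual all zeros
q_back, t_back = dq_to_qt(dq)
moved = dq_apply(dq, (1.0, 0.0, 0.0))
```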

Related Projects

dqtorch has been used in our research of deformable NeRF:

dqtorch's People

Contributors

mightychaos


dqtorch's Issues

Potential Compilation Error when building with PyTorch 2.0.0+

If you see the following error during python setup.py install:

#error C++17 or later compatible compiler is required to use PyTorch.

it is caused by building against PyTorch >= 2.0.0, similar to this post.

I fixed it by changing the top lines of setup.py to:

import os  # required for os.path and os.name below
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from pkg_resources import DistributionNotFound, get_distribution, parse_version
import torch
_src_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dqtorch')

use_new_std = parse_version(torch.__version__) >= parse_version('2.0.0')

nvcc_flags = [
    '-O3', '-std=c++14' if not use_new_std else '-std=c++17',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14' if not use_new_std else '-std=c++17']
elif os.name == "nt":
...
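
If pkg_resources is not available in your environment, the same version check can be sketched with the standard library alone. The helper name cxx_std_flag is hypothetical and not part of dqtorch or PyTorch. It looks only at the major version, on the assumption that any 2.x or later build needs C++17, as the error message above suggests.

```python
import re

def cxx_std_flag(torch_version: str) -> str:
    """Pick the C++ standard flag from a version string such as '2.0.0+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {torch_version!r}")
    major = int(match.group(1))
    return "-std=c++17" if major >= 2 else "-std=c++14"

# In setup.py this would be called as cxx_std_flag(torch.__version__)
flags = [cxx_std_flag(v) for v in ("1.12.1", "2.0.0+cu117", "2.1.0a0+git1234")]
```

The returned flag can then replace the inline conditional in both nvcc_flags and c_flags.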
