
dhyuan99 / VecKM


Official GitHub repo for VecKM, a very efficient and descriptive local geometry encoder / point tokenizer / patch embedder. ICML 2024.

Home Page: http://arxiv.org/abs/2404.01568

License: MIT License

Languages: Python 97.31%, C++ 1.09%, CUDA 1.14%, C 0.46%

Topics: efficient, geometry, icml, point-cloud, real-time, patch-encoder, point-embedding, point-tokenizer, modelnet40, normal-estimation, point-cloud-library, point-cloud-processing, shapenet

VecKM's Introduction

VecKM: A Linear Time and Space Local Point Cloud Geometry Encoder

Dehao Yuan, Cornelia Fermüller, Tahseen Rabbani, Furong Huang, Yiannis Aloimonos

ICML 2024      [arXiv]

Highlighted Features

Usage

ℹ️ This section is illustrated in examples/example.ipynb.

⚠️ VecKM is sensitive to scaling. Please scale your data so that each local point cloud lies within a UNIT BALL of radius 1.

⚠️ For example, if you have a point cloud pts and you want to consider the local geometry within radius 0.1, apply pts *= 10 so that you are now considering the local geometry within radius 1. A hedged rescaling sketch is given right after these notes.

⚠️ If your x, y, z coordinates do not have the same scale, make sure to rescale them so that they do.

⚠️ VecKM is not rotation invariant. If the local point cloud is rotated, the encoding can be very different.
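
Following the notes above, here is a minimal rescaling sketch. The tensor pts and the value of radius are illustrative placeholders, not part of the repository:

import torch

# Hypothetical input: pts is an (n, 3) or (b, n, 3) tensor, and `radius` is the
# neighborhood radius you care about in the original coordinate units.
pts = torch.rand(1000, 3)
radius = 0.1

# Rescale so that the neighborhood of interest becomes a unit ball (radius 1),
# which is the scale the default parameters (alpha=6, beta=1.8) expect.
pts = pts / radius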

It is very simple to incorporate VecKM into your own code. Suppose your input point cloud pts has shape (n,3) or (b,n,3); then the following code gives you the VecKM local geometry encoding with output shape (n,d) or (b,n,d). PyTorch >= 1.13.0 is recommended since it has better support for complex tensors, but lower versions should also work.

pip install scipy
pip install complexPyTorch
import torch
import torch.nn as nn
import numpy as np
from scipy.stats import norm

def strict_standard_normal(d):
    # Generates outcomes very similar to torch.randn(d), but the values are
    # exact standard-normal quantiles; only their order is randomized.
    y = np.linspace(0, 1, d+2)
    x = norm.ppf(y)[1:-1]
    np.random.shuffle(x)
    x = torch.tensor(x).float()
    return x

class VecKM(nn.Module):
    def __init__(self, d=256, alpha=6, beta=1.8, p=4096):
        """ I tested empirically, here are some general suggestions for selecting parameters d and p:
        (alpha=6, beta=1.8) works for the data scale that your neighbordhood radius = 1.
        Please ensure your point cloud is appropriately scaled!
        d = 256, p = 4096 is for point cloud size ~20k. Runtime is about 28ms.
        d = 128, p = 8192 is for point cloud size ~50k. Runtime is about 76ms.
        For larger point cloud size, please enlarge p, but if that costs too much, please reduce d.
        A general empirical phenomenon is (d*p) is postively correlated with the encoding quality.

        For the selection of parameter alpha and beta, please see the github section below.
        """
        super().__init__()
        self.alpha, self.beta, self.d, self.p = alpha, beta, d, p
        self.sqrt_d = d ** 0.5

        self.A = torch.stack(
            [strict_standard_normal(d) for _ in range(3)], 
            dim=0
        ) * alpha
        self.A = nn.Parameter(self.A, False)                                    # (3, d)

        self.B = torch.stack(
            [strict_standard_normal(p) for _ in range(3)], 
            dim=0
        ) * beta
        self.B = nn.Parameter(self.B, False)                                    # (3, p)

    def forward(self, pts):
        """ Compute the dense local geometry encodings of the given point cloud.
        Args:
            pts: (bs, n, 3) or (n, 3) tensor, the input point cloud.

        Returns:
            G: (bs, n, d) or (n, d) tensor
               the dense local geometry encodings. 
               note: it is complex valued. 
        """
        pA = pts @ self.A                                                       # Real(..., n, d)
        pB = pts @ self.B                                                       # Real(..., n, p)
        eA = torch.concatenate((torch.cos(pA), torch.sin(pA)), dim=-1)          # Real(..., n, 2d)
        eB = torch.concatenate((torch.cos(pB), torch.sin(pB)), dim=-1)          # Real(..., n, 2p)
        G = torch.matmul(
            eB,                                                                 # Real(..., n, 2p)
            eB.transpose(-1,-2) @ eA                                            # Real(..., 2p, 2d)
        )                                                                       # Real(..., n, 2d)
        G = torch.complex(
            G[..., :self.d], G[..., self.d:]
        ) / torch.complex(
            eA[..., :self.d], eA[..., self.d:]
        )                                                                       # Complex(..., n, d)
        G = G / torch.norm(G, dim=-1, keepdim=True) * self.sqrt_d
        return G

vkm = VecKM()
pts = torch.rand((10,1000,3))
print(vkm(pts).shape) # it will be Complex(10,1000,256)
pts = torch.rand((1000,3))
print(vkm(pts).shape) # it will be Complex(1000, 256)

from complexPyTorch.complexLayers import ComplexLinear, ComplexReLU
# You may want to apply a two-layer feature transform to the encoding.
feat_trans = nn.Sequential(
    ComplexLinear(256, 128),
    ComplexReLU(),
    ComplexLinear(128, 128)
)
G = feat_trans(vkm(pts))
G = G.real**2 + G.imag**2 # it will be Real(10, 1000, 128) or Real(1000, 128).
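
After taking the squared magnitude, the encoding is real-valued and can be fed into any standard real-valued network. A minimal continuation sketch (the head, its layer sizes, and num_classes below are illustrative assumptions, not part of the repository):

# Hypothetical per-point head on top of the real-valued G from the lines above.
# The hidden size 128 matches the feature transform; num_classes is an assumption.
num_classes = 40
head = nn.Sequential(
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, num_classes)
)
logits = head(G)  # (..., num_classes), per-point predictions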

ℹ️ See [Suggestions for picking $\alpha$ and $\beta$] below for how to tune the alpha and beta parameters.

ℹ️ See [Suggestions for picking $d$ and $p$] below for how to tune the d and p parameters.

ℹ️ Feel free to contact me if you are unsure! I will try to respond within 1 day.

Suggestions for picking $\alpha$ and $\beta$

There are two parameters, alpha and beta, in the VecKM encoding. They control the resolution and the receptive field of VecKM, respectively. A higher alpha produces a more detailed encoding of the local geometry, while a smaller alpha produces a more abstract encoding. A higher beta results in a smaller receptive field. The figure below gives a rough intuition.

  • You can slightly increase alpha if you have a relatively dense point cloud and want high-frequency details.
  • You can slightly decrease alpha if you want to smooth out the high-frequency details and only keep the low-frequency components.
  • beta is closely related to the neighborhood radius; the table below gives the correspondence. For example, if you want to extract the local geometry encoding with radius 0.3, select beta = 6 (see also the helper sketch after the table).
beta     1     2     3     4     5     6     7     8     9    10
radius   1.800 0.900 0.600 0.450 0.360 0.300 0.257 0.225 0.200 0.180

beta     11    12    13    14    15    16    17    18    19    20
radius   0.163 0.150 0.138 0.129 0.120 0.113 0.106 0.100 0.095 0.090

beta     21    22    23    24    25    26    27    28    29    30
radius   0.086 0.082 0.078 0.075 0.072 0.069 0.067 0.065 0.062 0.060
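
The table values follow radius ≈ 1.8 / beta, so picking beta from a target radius can be scripted. The helper below is an assumption derived from the table, not an API of the repository:

def beta_for_radius(radius):
    # Assumption based on the table above, where radius ≈ 1.8 / beta
    # (e.g. beta = 6 corresponds to radius 0.300).
    return 1.8 / radius

print(beta_for_radius(0.3))   # ~6
print(beta_for_radius(0.09))  # ~20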

Suggestions for picking $d$ and $p$

We find empirically that $d\times p$ is strongly correlated with the encoding quality. Here are several tips:

  • A larger local neighborhood requires a larger $d$.
  • A larger point cloud size requires a larger $p$.

Several examples:

  • d = 256, p = 4096 is for point cloud size ~100k. Runtime is about 80ms.
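
These settings can be passed directly to the VecKM constructor defined above. The values below simply echo the suggestions in this README and should be re-tuned for your own data; the quoted runtimes are rough, hardware-dependent figures:

vkm_small = VecKM(d=256, p=4096)  # one of the suggested settings above
vkm_large = VecKM(d=128, p=8192)  # suggested for larger point clouds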

Experiments

Check out the applications of VecKM to normal estimation, classification, and part segmentation. The overall architecture change is shown in the figure below.

Citation

If you find it helpful, please consider citing our paper:

@misc{yuan2024linear,
      title={A Linear Time and Space Local Point Cloud Geometry Encoder via Vectorized Kernel Mixture (VecKM)}, 
      author={Dehao Yuan and Cornelia Fermüller and Tahseen Rabbani and Furong Huang and Yiannis Aloimonos},
      year={2024},
      eprint={2404.01568},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

VecKM's People

Contributors: dhyuan99


VecKM's Issues

Runtime Error for the 'pts = torch.rand((10,1000,3))' case

RuntimeError                              Traceback (most recent call last)
      vkm = VecKM()
      pts = torch.rand((10,1000,3))
----> print(vkm(pts).shape) # it will be Complex(10,1000,256)

in forward(self, pts)
          eB.transpose(-1,-2) @ eA                # Real(..., 2p, 2d)
      )                                           # Real(..., n, 2d)
----> G = torch.complex(
          G[:,:self.d], G[:,self.d:]
      ) / torch.complex(

RuntimeError: The size of tensor a (256) must match the size of tensor b (1744) at non-singleton dimension 1
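
The traceback shows that version of forward() slicing dimension 1 (G[:, :self.d]), which is the point axis for batched (b, n, 2d) input; the listing earlier in this README slices the last axis with an ellipsis, which handles both shapes. A minimal self-contained sketch of the difference, with illustrative tensor sizes:

import torch

d = 4
G = torch.rand(2, 10, 2 * d)  # (b, n, 2d): real and imaginary halves concatenated on the last axis

ok = torch.complex(G[..., :d], G[..., d:])   # slices the last axis; works for (n, 2d) and (b, n, 2d)
# bad = torch.complex(G[:, :d], G[:, d:])    # slices the point axis for batched input -> size mismatch error
print(ok.shape)  # torch.Size([2, 10, 4])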
