
torch-cubic-spline-grids's Introduction

torch-cubic-spline-grids


Cubic spline interpolation on multidimensional grids in PyTorch.

The primary goal of this package is to provide learnable, continuous parametrisations of 1-4D spaces.


Overview

torch_cubic_spline_grids provides a set of PyTorch components called grids.

Grids are defined by

  • their dimensionality (1d, 2d, 3d, 4d...)
  • the number of points covering each dimension (resolution)
  • the number of values stored on each grid point (n_channels)
  • how we interpolate between values on grid points

All grids in this package consist of uniformly spaced points covering the full extent of each dimension.
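Here is a minimal NumPy sketch of what "uniformly spaced points covering the full extent" can look like. It assumes that, in the [0, 1] grid coordinate system described under First steps, the first and last points along a dimension with n points sit exactly at 0 and 1. That is our reading and is not confirmed by the package. grid_point_positions is a hypothetical helper.

```python
import numpy as np


def grid_point_positions(resolution):
    """Coordinates in [0, 1] of the grid points along each dimension.

    Assumes the first and last points sit exactly at 0 and 1, so neighbouring
    points along a dimension with n points are 1 / (n - 1) apart.
    """
    return [np.linspace(0.0, 1.0, n) for n in resolution]


positions = grid_point_positions((5, 3))  # resolution (h, w) as in First steps
# positions[0] -> [0.0, 0.25, 0.5, 0.75, 1.0] along h
# positions[1] -> [0.0, 0.5, 1.0] along w
```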

First steps

Let's make a simple 2D grid with one value on each grid point.

import torch
from torch_cubic_spline_grids import CubicBSplineGrid2d

grid = CubicBSplineGrid2d(resolution=(5, 3), n_channels=1)
  • grid.ndim is 2
  • grid.resolution is (5, 3) (or (h, w))
  • grid.n_channels is 1
  • grid.data.shape is (1, 5, 3) (or (c, h, w))

In words, the grid extends over two dimensions (h, w) with 5 points in h and 3 points in w. There is one value stored at each point on the 2D grid. The grid data is stored in a tensor of shape (c, *grid_resolution).

We can obtain the value (interpolant) at any continuous point on the grid. The grid coordinate system spans the interval [0, 1] along each grid dimension. The interpolant is obtained by applying cubic spline interpolation sequentially along each dimension of the grid.
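Grid coordinates live in [0, 1], so data recorded in other units must be rescaled before it is passed to a grid. Below is a minimal sketch that assumes a simple per-dimension linear (min-max) mapping from a known bounding box. The helper name to_grid_coords and the [-π, π] box are illustrative only. The expression (x - lower) / (upper - lower) carries over unchanged to PyTorch tensors.

```python
import numpy as np


def to_grid_coords(x, lower, upper):
    """Linearly rescale points from the box [lower, upper] to [0, 1] per dimension."""
    x = np.asarray(x, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    return (x - lower) / (upper - lower)


# Hypothetical 2D points measured in [-pi, pi] along both dimensions.
points = np.array([[-np.pi, -np.pi], [0.0, np.pi / 2], [np.pi, np.pi]])
u = to_grid_coords(points, lower=[-np.pi, -np.pi], upper=[np.pi, np.pi])
# u -> [[0.0, 0.0], [0.5, 0.75], [1.0, 1.0]]
```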

coords = torch.rand(size=(10, 2))  # values in [0, 1]
interpolants = grid(coords)
  • interpolants.shape is (10, 1)
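The sequential evaluation described above can be sketched in NumPy for a 2D grid. This illustrates the idea under stated assumptions and is not the package's implementation. It assumes that grid point i sits at i / (n - 1) and uses the standard uniform cubic B-spline basis matrix. At the edges it simply clamps neighbour indices, which may differ from how torch_cubic_spline_grids handles boundaries. The names locate, weights and interpolate_2d are hypothetical.

```python
import numpy as np

# Uniform cubic B-spline basis matrix: rows multiply [1, t, t^2, t^3],
# columns weight the 4 neighbouring control points P0..P3.
B_SPLINE = np.array([
    [1.0, 4.0, 1.0, 0.0],
    [-3.0, 0.0, 3.0, 0.0],
    [3.0, -6.0, 3.0, 0.0],
    [-1.0, 3.0, -3.0, 1.0],
]) / 6.0


def locate(u, n):
    """Return neighbour indices (b, 4) and local position t (b,) for u in [0, 1]."""
    x = np.clip(u, 0.0, 1.0) * (n - 1)  # assumes grid point i sits at i / (n - 1)
    i = np.clip(np.floor(x).astype(int), 0, n - 2)  # start of the segment holding x
    t = x - i
    # Neighbours i-1 .. i+2, clamped at the edges (a simplifying assumption).
    idx = np.clip(i[:, None] + np.arange(-1, 3), 0, n - 1)
    return idx, t


def weights(t):
    """Weights on the 4 neighbouring control points, shape (b, 4)."""
    powers = np.stack([np.ones_like(t), t, t**2, t**3], axis=-1)
    return powers @ B_SPLINE


def interpolate_2d(data, coords):
    """Evaluate a (c, h, w) grid at (b, 2) coordinates, returning (b, c)."""
    _, h, w = data.shape
    idx_h, t_h = locate(coords[:, 0], h)
    idx_w, t_w = locate(coords[:, 1], w)
    # Gather a 4 x 4 patch of control points per sample: (c, b, 4, 4).
    patch = data[:, idx_h[:, :, None], idx_w[:, None, :]]
    # Interpolate along w first, then along h.
    along_w = np.einsum("cbhw,bw->cbh", patch, weights(t_w))
    values = np.einsum("cbh,bh->cb", along_w, weights(t_h))
    return values.T


grid = np.full((1, 5, 3), 2.0)  # c=1, h=5, w=3, every stored value set to 2
coords = np.random.default_rng(0).random((10, 2))  # values in [0, 1]
interpolants = interpolate_2d(grid, coords)  # shape (10, 1), all close to 2
```

Because the B-spline weights always sum to 1, a constant grid returns that constant everywhere. Away from the clamped edges, a grid whose values increase linearly is reproduced exactly.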

Optimisation

Values at each grid point can be optimised by minimising a loss function associated with grid interpolants. In this way the continuous space of the grid can be made to more accurately model a 1-4D space.

The image above shows the values of 6 control points on a 1D grid being optimised such that interpolating between them with cubic B-spline interpolation approximates a single oscillation of a sine wave.

Notebooks are available for this 1D example and a similar 2D example.
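The optimisation loop can also be illustrated without PyTorch. The following sketch is a stand-in and not the notebooks' code. It fits 6 control points of a 1D cubic B-spline grid to one oscillation of a sine wave using plain gradient descent in NumPy. The gradient of the mean squared error is written out by hand, whereas PyTorch would use autograd and an optimiser such as Adam. Because the interpolant is linear in the control points, each sample's spline weights can be collected into a matrix. The clamped edge handling and the helper name design_matrix are assumptions.

```python
import numpy as np

# Uniform cubic B-spline basis matrix (rows multiply [1, t, t^2, t^3]).
B_SPLINE = np.array([
    [1.0, 4.0, 1.0, 0.0],
    [-3.0, 0.0, 3.0, 0.0],
    [3.0, -6.0, 3.0, 0.0],
    [-1.0, 3.0, -3.0, 1.0],
]) / 6.0


def design_matrix(u, n):
    """Row j holds the spline weights of sample u[j] on the n control points."""
    x = np.clip(u, 0.0, 1.0) * (n - 1)
    i = np.clip(np.floor(x).astype(int), 0, n - 2)
    t = x - i
    idx = np.clip(i[:, None] + np.arange(-1, 3), 0, n - 1)  # clamped neighbours
    w = np.stack([np.ones_like(t), t, t**2, t**3], axis=-1) @ B_SPLINE
    A = np.zeros((len(u), n))
    # np.add.at accumulates weights where clamped indices repeat at the edges.
    np.add.at(A, (np.arange(len(u))[:, None], idx), w)
    return A


u = np.linspace(0.0, 1.0, 200)    # sample positions in grid coordinates
target = np.sin(2 * np.pi * u)    # one oscillation of a sine wave
A = design_matrix(u, n=6)         # 6 control points
control_points = np.zeros(6)      # the learnable values on the grid
learning_rate = 1.0
losses = []
for step in range(2000):
    residual = A @ control_points - target    # prediction minus observation
    losses.append(np.mean(residual**2))       # mean squared error
    gradient = 2.0 * A.T @ residual / len(u)  # d(loss) / d(control_points)
    control_points -= learning_rate * gradient
```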

Types of grids

torch_cubic_spline_grids provides grids which can be interpolated with cubic B-spline interpolation or cubic Catmull-Rom spline interpolation.

spline                continuity   interpolating?
cubic B-spline        C2           No
Catmull-Rom spline    C1           Yes

If you need the resulting curve to pass through the data on the grid, you should use the cubic Catmull-Rom spline grids:

  • CubicCatmullRomGrid1d
  • CubicCatmullRomGrid2d
  • CubicCatmullRomGrid3d
  • CubicCatmullRomGrid4d

If you require continuous second derivatives then the cubic B-spline grids are more suitable.

  • CubicBSplineGrid1d
  • CubicBSplineGrid2d
  • CubicBSplineGrid3d
  • CubicBSplineGrid4d
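To see the difference between the two spline types, the hypothetical helper below evaluates a 1D grid at its own grid points using each basis matrix. It uses the standard uniform cubic B-spline and Catmull-Rom matrices and clamps neighbours at the edges, which is an assumption. Catmull-Rom returns the stored values exactly. The B-spline instead returns a weighted average of neighbouring values, (p[k-1] + 4 p[k] + p[k+1]) / 6 at interior points.

```python
import numpy as np

# Basis matrices: rows multiply [1, t, t^2, t^3], columns weight P0..P3.
B_SPLINE = np.array([[1, 4, 1, 0], [-3, 0, 3, 0], [3, -6, 3, 0], [-1, 3, -3, 1]]) / 6.0
CATMULL_ROM = np.array([[0, 2, 0, 0], [-1, 0, 1, 0], [2, -5, 4, -1], [-1, 3, -3, 1]]) / 2.0


def interpolate_1d(values, u, matrix):
    """Evaluate a 1D grid of control values at coordinates u in [0, 1]."""
    n = len(values)
    x = np.clip(u, 0.0, 1.0) * (n - 1)  # assumes grid point i sits at i / (n - 1)
    i = np.clip(np.floor(x).astype(int), 0, n - 2)
    t = x - i
    idx = np.clip(i[:, None] + np.arange(-1, 3), 0, n - 1)  # clamped neighbours
    w = np.stack([np.ones_like(t), t, t**2, t**3], axis=-1) @ matrix
    return np.sum(w * values[idx], axis=-1)


values = np.array([0.0, 1.0, 0.0, -1.0, 0.0])
knots = np.linspace(0.0, 1.0, len(values))  # the grid points themselves
catmull_rom = interpolate_1d(values, knots, CATMULL_ROM)  # equals values
b_spline = interpolate_1d(values, knots, B_SPLINE)  # about [0.167, 0.667, 0, -0.667, -0.167]
```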

Regularisation

The number of points in each dimension should be chosen so that interpolation on the grid can approximate the underlying phenomenon being modelled without overfitting. A low-resolution grid acts as a regulariser by smoothing the model.

Installation

torch_cubic_spline_grids is available on PyPI:

pip install torch-cubic-spline-grids

Related work

This is a PyTorch implementation of the way Warp models continuous deformation fields and locally variable optical parameters in cryo-EM images. The approach is described in Dimitry Tegunov's paper:

Many methods in Warp are based on a continuous parametrization of 1- to 3-dimensional spaces. This parameterization is achieved by spline interpolation between points on a coarse, uniform grid, which is computationally efficient. A grid extends over the entirety of each dimension that needs to be modeled. The grid resolution is defined by the number of control points in each dimension and is scaled according to physical constraints (for example, the number of frames or pixels) and available signal. The latter provides regularization to prevent overfitting of sparse data with too many parameters. When a parameter described by the grid is retrieved for a point in space (and time), for example for a particle (frame), B-spline interpolation is performed at that point on the grid. To fit a grid’s parameters, in general, a cost function associated with the interpolants at specific positions on the grid is optimized.


For a fantastic introduction to splines I recommend Freya Holmer's YouTube video, The Continuity of Splines.

torch-cubic-spline-grids's People

Contributors

alisterburt, dependabot[bot]


torch-cubic-spline-grids's Issues

Problem with using CubicBSplineGrid2d to approximate RectBivariateSpline

  • torch-cubic-b-spline-grid version: 0.0.8
  • Python version: 3.8.10
  • Operating System: linux

Description

I want to use CubicBSplineGrid2d to approximate the results of scipy's RectBivariateSpline, and then use CubicBSplineGrid2d when building a torch model.
However, this was not successful (the following image shows the result) 😭. Is it because CubicBSplineGrid2d cannot handle such steep images? Can you help me figure out the reason? 🌸
It seems that the fit only works in a small part of the image.


What I Did

Here is the code and data:
data: https://drive.google.com/file/d/18F2XR2HaIXrYxY-ibmOr9y3KMEt2dHW-/view?usp=sharing

import torch
import numpy as np
from torch_cubic_spline_grids import CubicCatmullRomGrid2d, CubicBSplineGrid2d
import matplotlib.pyplot as plt
from scipy import interpolate

def get_inter(uvp):
    x = np.concatenate([xyt[:16, 1]-2*np.pi, xyt[:16, 1], xyt[:16, 1]+2*np.pi])
    y = np.concatenate([xyt[:16, 1]-2*np.pi, xyt[:16, 1], xyt[:16, 1]+2*np.pi])
    uvp = uvp.reshape(16,16)
    uvp = np.concatenate([uvp,uvp,uvp],axis=0)
    uvp = np.concatenate([uvp,uvp,uvp],axis=1)
    inter = interpolate.RectBivariateSpline(x,y,uvp,kx=4,ky=4)
    return inter

def make_observations(n, inter):
    x = (np.random.random((n,2))*2-1)*np.pi
    y = inter(x[:,0], x[:,1], grid=False)
    x = torch.from_numpy(x)
    y = torch.from_numpy(y)
    return x, y

npz = np.load('test.npz')
xyt = npz['xyt']
uvp = npz['uvp']

inter = get_inter(uvp)

grid_2d = CubicBSplineGrid2d(resolution=(512,512), n_channels=1)
optimiser = torch.optim.Adam(grid_2d.parameters(), lr=0.001)
for i in range(1000):
    # make (noisy) observations of the data we want to model
    # what does the model predict for our observations?
    item_x,item_y = make_observations(1024, inter)
    prediction = grid_2d(item_x).squeeze()

    # zero gradients and calculate loss between observations and model prediction
    optimiser.zero_grad()
    loss = torch.sqrt(torch.mean((prediction - item_y)**2)) / torch.sqrt(torch.mean(item_y**2))
    # backpropagate loss and update values at points on grid
    loss.backward()
    optimiser.step()
    print(loss.item())


d = 2*np.pi/512
x = np.linspace(-np.pi, np.pi-d, 512)
y = np.linspace(-np.pi, np.pi-d, 512)
X, Y = np.meshgrid(x, y, indexing='ij')
XY = np.vstack([X.flatten(),Y.flatten()]).T

plt.figure(figsize=(15,5))
plt.subplot(131)
plt.imshow(uvp.reshape(16,16))  # uvp is a NumPy array, so .detach() does not apply
plt.title('ground truth')

plt.subplot(132)
plt.imshow(inter(x,y))
plt.title('BivariateSpline')

plt.subplot(133)
plt.imshow(grid_2d(torch.from_numpy(XY)).reshape(512,512).detach().numpy())  # the grid expects a tensor
plt.title('prediction')

How to select device?

  • torch-cubic-b-spline-grid version: 0.0.9
  • Python version: 3.12.1
  • Operating System: Debian

Perhaps this is due to my unfamiliarity with nn.Module, but how do you select which device to use when setting up the spline grid object? I didn't see any reference to this in the examples.

>>> import torch as t
>>> from torch_cubic_spline_grids import CubicBSplineGrid3d
>>> m = CubicBSplineGrid3d(resolution=(6, 10, 10))
>>> %time m(t.rand(50, 3, device='cuda'))

...
File ~/resources/torch-cubic-spline-grids/src/torch_cubic_spline_grids/interpolate_grids.py:129, in interpolate_grid_3d(grid, u, matrix)
    127 idx_h = einops.repeat(idx_h, 'b h -> b d h w', d=4, w=4)
    128 idx_w = einops.repeat(idx_w, 'b w -> b d h w', d=4, h=4)
--> 129 control_points = grid[:, idx_d, idx_h, idx_w]  # (c, b, 4, 4, 4)
    130 control_points = einops.rearrange(control_points, 'c b d h w -> b c d h w')
    131 t = einops.rearrange([t_d, t_h, t_w], 'dhw b -> b dhw')

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)

bug in optimise_2d_grid_model.ipynb

  • torch-cubic-b-spline-grid version: 0.0.8
  • Python version: 3.8.10
  • Operating System: linux

Description

There is a problem with the code that displays the predicted image and the original image in optimise_2d_grid_model.ipynb: the "prediction" panel shows values calculated directly from the formula rather than the grid's predictions.

n = 512
x = torch.tensor(np.linspace([0, 0], [1, 1], n))
xx, yy = torch.meshgrid(x[:, 0], x[:, 1], indexing='xy')
y = torch.sin(xx * 2 * torch.pi) + np.sin(yy * 4 * np.pi)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(y, cmap='gray_r', label='prediction')
ax[1].imshow(target_image, cmap='gray_r', label='ground truth')

The correct display code should therefore be:

n = 512
x = torch.tensor(np.linspace([0, 0], [1, 1], n))
xx, yy = torch.meshgrid(x[:, 0], x[:, 1], indexing='xy')
x = torch.vstack([xx.flatten(), yy.flatten()]).T
y = grid_2d(x).detach().numpy().reshape(512,512)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(y, cmap='gray_r', label='prediction')
ax[1].imshow(target_image, cmap='gray_r', label='ground truth')


range of [0, 1]

I noticed that in the implementation of cubic spline interpolation, the grid coordinate system is confined to the range of [0, 1]. I am keen to understand the reasoning behind this design choice.

Could you please shed some light on why the coordinate system was chosen to be within this range? Is this due to specific mathematical or computational considerations, such as enhancing precision or simplifying calculations? Or does this range better suit certain types of applications?
