
Comments (6)

roflmaostc commented on August 17, 2024

I think that might also explain the discrepancy in the initial post.
So all fine!

from torchkbnufft.

roflmaostc commented on August 17, 2024

Currently my ktraj shape is:

ktraj.shape

torch.Size([1, 2, 412160])
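As a quick sanity check on that shape: assuming the same radial trajectory as the code further down (N samples per spoke, ceil(pi * N / 2) spokes) and an image size of N = 512 (an assumption, consistent with the 512-point FFT mentioned in the reply below), the sample count works out exactly. The leading 1 is the batch dimension and the 2 holds (ky, kx).

```python
import math

N = 512                                # assumed image size
N_angles = math.ceil(math.pi * N / 2)  # spoke count, same formula as the later code
n_samples = N * N_angles               # N samples along each spoke
ktraj_shape = (1, 2, n_samples)        # (batch, [ky, kx], samples)
print(N_angles, ktraj_shape)           # 805 (1, 2, 412160)
```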

Repeating along the 0th (batch) dimension only increases the runtime, so I think what I'm doing looks OK?


mmuckley commented on August 17, 2024

Hello @roflmaostc, I think this is expected.

The first factor is that torch.fft.fft only does a 1D FFT, so it is doing a factor of N fewer FFTs to begin with.
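A minimal NumPy illustration of that count: a 2D N x N FFT is equivalent to N 1D FFTs along the rows followed by N 1D FFTs along the columns, so it does on the order of N times as much FFT work as a single length-N 1D FFT. N = 64 and the random test image are arbitrary choices for the sketch.

```python
import numpy as np

N = 64
rng = np.random.default_rng(0)
img = rng.standard_normal((N, N)) + 1j * rng.standard_normal((N, N))

# Row-column decomposition: N length-N FFTs along the rows, then N along the columns
row_col = np.fft.fft(np.fft.fft(img, axis=1), axis=0)
n_1d_ffts = 2 * N  # number of length-N 1D FFTs hidden inside one 2D FFT

print(np.allclose(row_col, np.fft.fft2(img)), n_1d_ffts)  # True 128
```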

The second factor is that the default grid size is a 2x oversampled grid, so it's not an apples-to-apples comparison: the NUFFT runs some pretty massive 1024 x 1024 2D FFTs, vs. the 512-point 1D FFT that you're using for the FFT. You could be more conservative and use the more standard 1.5x oversampling by setting grid_size to (768, 768).
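The grid sizes follow directly from the oversampling factor. The sketch below uses a crude n log n cost model (constants, memory traffic and batching are ignored, so only the ratios mean anything) to compare one 2D FFT on each grid against a single 512-point 1D FFT. The cost model and the helper names are assumptions for illustration; only grid_size itself comes from the reply above.

```python
import math

def oversampled_grid(im_size, factor):
    """Grid edge length for a given oversampling factor, e.g. 2.0 or 1.5."""
    return int(round(im_size * factor))

def relative_fft_cost(grid, im_size):
    """Crude n log n estimate: one grid x grid 2D FFT vs. one length-im_size 1D FFT."""
    return (grid**2 * math.log2(grid**2)) / (im_size * math.log2(im_size))

im_size = 512
for factor in (2.0, 1.5):
    grid = oversampled_grid(im_size, factor)
    ratio = relative_fft_cost(grid, im_size)
    print(f"{factor}x oversampling: {grid} x {grid} grid, ~{ratio:.0f}x one {im_size}-point 1D FFT")
```

In this model, 1.5x oversampling needs a little over half the FFT work of 2x.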

The third is that, because it is implemented at a high level in Python, our interpolation is quite a bit slower. There are several mitigations for this, such as broadcasting across sensitivity coils, but it will never be as fast as a compiled implementation. The advantage is that you never have to worry about compiling torchkbnufft; the disadvantage is speed.

The last item is that I've always had a bit of trouble squeezing performance out of the adjoint on the CPU; on the GPU, I've generally observed adjoint performance to be much closer to forward performance.

One possible mitigation: if your problem can be parametrized in terms of A'A, you can use the Toeplitz NUFFT for the forward-backward pair, which only uses FFTs and no interpolation.
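To illustrate why A'A needs no interpolation, here is a small NumPy check that uses an explicit (slow) 1D nonuniform DFT matrix in place of torchkbnufft's NUFFT; the sizes and random frequencies are arbitrary assumptions. For any set of sample frequencies, the entries of the normal operator depend only on the pixel offset p - q, i.e. A'A is a Toeplitz matrix, which is what allows it to be applied with FFTs alone.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, n_samples = 8, 30
k = rng.uniform(-np.pi, np.pi, n_samples)           # arbitrary non-Cartesian sample frequencies
A = np.exp(-1j * np.outer(k, np.arange(n_pixels)))  # explicit 1D nonuniform DFT, A[m, q] = exp(-i k_m q)
G = A.conj().T @ A                                  # normal operator A'A, shape (n_pixels, n_pixels)

# Every diagonal of G is constant, i.e. G[p, q] depends only on the offset p - q
is_toeplitz = all(
    np.allclose(np.diag(G, offset), np.diag(G, offset)[0])
    for offset in range(-(n_pixels - 1), n_pixels)
)
# G's first column is t[d] = sum_m exp(i k_m d) for d >= 0 (negative offsets follow by Hermitian symmetry)
t = np.exp(1j * np.outer(np.arange(n_pixels), k)).sum(axis=1)
print(is_toeplitz, np.allclose(G[:, 0], t))  # True True
```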


roflmaostc commented on August 17, 2024

Thanks for your detailed reply!

Yeah, that's right.

Your point about ToepNufft is interesting, and exactly what I'm looking for.
Is the kernel it returns applied as ifft(fft(pad(arr)) * kernel), where pad applies the zero-padding?

I'm asking because we would expect a cost of 4 (padding) x 2 (forward and back) = 8 FFTs.

But in my case, a naive application of ifft(fft(pad(arr)) * kernel) with precalculated kernels is still ~40 times faster.

import numpy as np
import torch
import torchkbnufft as tkbn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = "cpu"

N = 200
N_z = 300
N_angles = int(np.ceil(np.pi * N / 2))
voxels = (torch.zeros(N_z, 1, N, N, dtype=torch.float32) + 1j * 0).to(device).to(torch.complex64)

def toeplitz(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=5).to(device)
    
    f = lambda x: toep_ob(x, kernel)
    
    return f, kernel

def toeplitz2(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=8).to(device)
    

    kernel = kernel.reshape(kernel.shape[0], 1, kernel.shape[1], kernel.shape[2])
    def apply_kernel(x):
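        # Allocate the padded buffer with x's complex dtype (the default would be real float32)
        x_pad = torch.zeros(x.shape[0], 1, 2 * x.shape[2], 2 * x.shape[3], dtype=x.dtype, device=x.device)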
        x_pad[:, :, 0:x.shape[2], 0:x.shape[3]] = x
        
        res = torch.fft.ifft2(torch.fft.fft2(x_pad) * kernel)[:, :, 0:x.shape[2], 0:x.shape[3]]
        return res
    
    return apply_kernel


toeplitz_f, kernel = toeplitz(voxels, N_angles)
apply_kernel = toeplitz2(voxels, N_angles)



%%time
arr2 = toeplitz_f(voxels)
CPU times: user 20.8 ms, sys: 547 µs, total: 21.3 ms
Wall time: 20.7 ms

%%time
arr3 = apply_kernel(voxels)
CPU times: user 928 µs, sys: 0 ns, total: 928 µs
Wall time: 524 µs
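On the question of whether the kernel is applied as ifft(fft(pad(arr)) * kernel): under an exact-DFT model (no Kaiser-Bessel interpolation, unnormalized NumPy FFTs), yes. The sketch below builds the 2x-padded circulant kernel directly from the radial trajectory used above and checks that the padded FFT filter matches an explicit A'A on a tiny 6 x 6 image. The helper names are hypothetical, and torchkbnufft's calc_toeplitz_kernel computes its kernel differently and may use a different FFT normalization, so treat this as an illustration of the structure rather than a drop-in replacement. Each padded FFT has 4x as many pixels, and two are needed, which is where the 4 x 2 = 8 estimate comes from.

```python
import numpy as np

def radial_traj(n, n_angles):
    """Radial trajectory built the same way as the thread's code, shape (2, n * n_angles)."""
    k_1d = np.linspace(-np.pi, np.pi, n + 1)[:-1].reshape(-1, 1)
    angles = np.linspace(0, np.pi, n_angles + 1)[:-1].reshape(1, -1)
    return np.stack((np.cos(angles) * k_1d, np.sin(angles) * k_1d)).reshape(2, -1)

def explicit_nudft(ktraj, n):
    """Slow reference operator A[m, (y, x)] = exp(-i (ky_m * y + kx_m * x)), shape (M, n * n)."""
    y, x = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    phase = np.outer(ktraj[0], y.ravel()) + np.outer(ktraj[1], x.ravel())
    return np.exp(-1j * phase)

def toeplitz_kernel_2d(ktraj, n):
    """FFT of the 2n x 2n circulant embedding of A'A (exact DFT, no interpolation)."""
    d = np.arange(-(n - 1), n)                    # all pixel offsets along one axis
    dy, dx = np.meshgrid(d, d, indexing="ij")
    phase = dy[..., None] * ktraj[0] + dx[..., None] * ktraj[1]
    t = np.exp(1j * phase).sum(axis=-1)           # t[dy, dx] = sum_m exp(i k_m . d)
    c = np.zeros((2 * n, 2 * n), dtype=complex)
    c[dy % (2 * n), dx % (2 * n)] = t             # negative offsets wrap to the end
    return np.fft.fft2(c)

def apply_toeplitz_2d(img, kernel):
    """ifft(fft(pad(img)) * kernel), cropped back to the original n x n image."""
    n = img.shape[-1]
    img_pad = np.zeros((2 * n, 2 * n), dtype=complex)
    img_pad[:n, :n] = img                         # zero-pad to 4x as many pixels
    return np.fft.ifft2(np.fft.fft2(img_pad) * kernel)[:n, :n]

n = 6
ktraj = radial_traj(n, int(np.ceil(np.pi * n / 2)))
A = explicit_nudft(ktraj, n)
rng = np.random.default_rng(0)
img = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

reference = (A.conj().T @ (A @ img.ravel())).reshape(n, n)
fast = apply_toeplitz_2d(img, toeplitz_kernel_2d(ktraj, n))
print(np.allclose(fast, reference))  # True
```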


mmuckley commented on August 17, 2024

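Are you running on the GPU? If so, you have to call torch.cuda.synchronize() before stopping the timer, because CUDA kernels are launched asynchronously. On the CPU, I get almost exactly the same times for both versions with this code.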

import numpy as np
import torch
import torch.nn.functional as F
import torchkbnufft as tkbn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = "cpu"

N = 200
N_z = 300
N_angles = int(np.ceil(np.pi * N / 2))
voxels = (torch.zeros(N_z, 1, N, N, dtype=torch.float32) + 1j * 0).to(device).to(torch.complex64)

def toeplitz(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=5).to(device)
    
    f = lambda x: toep_ob(x, kernel)
    
    return f, kernel

def toeplitz2(arr, N_angles):
    N = arr.shape[-1]
    k_1d = np.reshape(np.linspace(-np.pi, np.pi, N + 1)[:-1], (-1, 1))
    angles = np.reshape(np.linspace(0, np.pi, N_angles + 1)[:-1], (1, -1))
    k_y = np.cos(angles) * k_1d
    k_x = np.sin(angles) * k_1d
    ks = np.reshape(np.stack((k_y, k_x)), (2, -1))
    ktraj = torch.tensor(ks, dtype=torch.float32).to(device)
    
    ktraj = ktraj.repeat(1, 1, 1)
    
    toep_ob = tkbn.ToepNufft().to(device)
    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size=(N, N), numpoints=8).to(device)

    kernel = kernel.reshape(kernel.shape[0], 1, kernel.shape[1], kernel.shape[2])
    def apply_kernel(x):
        im_size = torch.tensor(x.shape[2:])

        grid_size = torch.tensor(
            kernel.shape[-len(kernel.shape[2:]) :], dtype=torch.long, device=kernel.device
        )
        pad_sizes = []
        for (gd, im) in zip(grid_size.flip((0,)), im_size.flip((0,))):
            pad_sizes.append(0)
            pad_sizes.append(int(gd - im))
        x_pad = F.pad(x, pad_sizes)
        # Forward FFT, multiply by the Toeplitz kernel, inverse FFT, then crop to the image size
        res = torch.fft.ifftn(torch.fft.fftn(x_pad, dim=[-2, -1], norm="ortho") * kernel, dim=[-2, -1], norm="ortho")[:, :, :im_size[-2], :im_size[-1]]
        return res
    
    return apply_kernel


toeplitz_f, kernel = toeplitz(voxels, N_angles)
apply_kernel = toeplitz2(voxels, N_angles)


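import time

def sync():
    # CUDA kernels run asynchronously; wait for them before reading the clock
    if device == 'cuda':
        torch.cuda.synchronize()

sync()
start = time.perf_counter()
arr2 = toeplitz_f(voxels)
sync()
end = time.perf_counter()
print(f"Toeplitz: {end-start}")

start = time.perf_counter()
arr3 = apply_kernel(voxels)
sync()
end = time.perf_counter()
print(f"simple pad: {end-start}")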
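Beyond synchronizing, a single timed call can still be misleading, because the first call may include one-time setup costs (for example, FFT plan creation) and individual runs vary. A small helper that warms up, synchronizes, and reports the median of several runs is sketched below; time_call is a hypothetical name, and the sync argument is where torch.cuda.synchronize would be passed when timing CUDA tensors.

```python
import statistics
import time

def time_call(fn, *args, sync=None, warmup=2, repeats=10):
    """Median wall time of fn(*args) in seconds.

    sync: optional zero-argument callable that blocks until queued work has finished,
    e.g. torch.cuda.synchronize when timing CUDA tensors.
    """
    for _ in range(warmup):  # exclude one-time setup costs from the measurement
        fn(*args)
    if sync is not None:
        sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # wait for asynchronous work before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

For the code above that would be, e.g., time_call(toeplitz_f, voxels, sync=torch.cuda.synchronize if device == 'cuda' else None).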


roflmaostc commented on August 17, 2024

You're right, torch.cuda.synchronize fixed it.

Then I get ~20ms and ~27ms. A little overhead seems ok :)

Thanks for helping me!
