
Comments (14)

MultiPath commented on July 17, 2024

Thanks for sharing the profiling. So, in your view, why is the pytorch binding slower than the original implementation? You mentioned that fullyfusedmlp is not much faster than native pytorch code? The author said it was significantly faster, though: NVlabs/tiny-cuda-nn#14 (comment)

from torch-ngp.

ashawkey commented on July 17, 2024

Detailed benchmarking is needed to see exactly which part is slower, but unfortunately I haven't found a good way to start.

For the NeRF experiments, the original implementation uses many techniques to speed up ray marching (e.g., pre-calculating a multi-level density grid), but I haven't managed to replicate them yet (the cuda_raymarching flag).
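To illustrate the density-grid idea, here is a minimal single-level sketch in Python: a cell is marked occupied if the density at its center exceeds a threshold, and samples that land in empty cells are dropped before they would reach the network. The names (build_occupancy_grid, march, density_fn) are hypothetical, the scene is assumed to live in the unit cube, and instant-ngp's actual grid is multi-level and refreshed during training, which this sketch does not model.

```python
def build_occupancy_grid(density_fn, res, threshold):
    """Return the set of occupied (i, j, k) cells of a res^3 grid over [0, 1)^3.

    Assumption: a cell counts as occupied if density_fn at its center
    exceeds `threshold` (a simplification of instant-ngp's density grid).
    """
    occupied = set()
    for i in range(res):
        for j in range(res):
            for k in range(res):
                center = ((i + 0.5) / res, (j + 0.5) / res, (k + 0.5) / res)
                if density_fn(center) > threshold:
                    occupied.add((i, j, k))
    return occupied


def march(origin, direction, n_steps, occupied, res):
    """Return the uniform ray samples that fall inside occupied cells.

    Only these samples would be sent to the (expensive) network.
    """
    kept = []
    for s in range(n_steps):
        t = (s + 0.5) / n_steps  # uniform samples with t in (0, 1)
        p = tuple(o + t * d for o, d in zip(origin, direction))
        if not all(0.0 <= c < 1.0 for c in p):
            continue  # outside the unit cube
        cell = tuple(min(int(c * res), res - 1) for c in p)
        if cell in occupied:
            kept.append(p)
    return kept
```

For example, with a ball of radius 0.25 centered in an 8^3 grid, a ray along x through the center keeps only the samples inside the four occupied cells it crosses, i.e. 64 of 128 uniform samples.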

For FFMLP, I only changed the data interface of the original implementation. I'm not sure whether its performance is decent, since pytorch also uses cutlass for fp16 matmuls, and the original implementation doesn't provide a comparison with pytorch. Also, since the most time-consuming part is the HashGrid encoder, I haven't paid much attention to the MLP.

Hope these can be helpful, and any advice is welcome ;)


ashawkey commented on July 17, 2024

For the hashgrid encoder, I tried replacing the original implementation in instant-ngp with mine, and the speed is comparable. So the problem may lie in pytorch calling these kernels inefficiently, but I still have no idea why...

For anyone who also wants to try: insert the following into grid.h and comment out the original kernel_grid function.

template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    // dense linear index (dimension 0 fastest) while the level still fits in the hashmap
    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= (resolution + 1);
    }

    // otherwise fall back to tcnn's spatial hash
    if (stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const uint32_t B,                           // number of input points
    const uint32_t num_grid_features,           // unused here
    const uint32_t* offsets,                    // per-level start offsets into grid (levels + 1 entries)
    const uint32_t H,                           // base resolution
    const float S,                              // log2 of the per-level scale
    const float quantize_threshold,             // unused here
    float max_level,                            // unused here
    const float* __restrict__ max_level_gpu,    // unused here
    const InterpolationType interpolation_type, // unused here (always multilinear)
    const GridType grid_type,                   // unused here
    const scalar_t* __restrict__ grid,
    const float* __restrict__ inputs,           // [D, B] layout
    scalar_t* __restrict__ outputs,             // [levels, C, B] layout
    float* __restrict__ dy_dx                   // input gradients: not implemented
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b >= B) return;

    const uint32_t level = blockIdx.y; // one row of blocks per level

    // locate this level's slice of the grid and of the output
    grid += offsets[level] * C;
    outputs += level * B * C;

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceilf(scale) + 1;

    // split each coordinate into an integer cell (pos_grid) and a fractional offset (pos)
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d * B + b] * scale + 0.5f;
        pos_grid[d] = (uint32_t)floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }

    // multilinear interpolation over the 2^D cell corners
    scalar_t results[C] = {0}; // accumulated in registers (fast)

    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1.0f;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(0, hashmap_size, resolution, pos_grid_local);

        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += (scalar_t)(w * (float)grid[index + ch]);
        }
    }

    // write to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[B * ch + b] = results[ch];
    }

    // input gradients (dy_dx) are not implemented in this replacement
}
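To check this kernel's output on the CPU, a pure-Python reference of one level of the lookup may help. It is a sketch that mirrors get_grid_index and kernel_grid above for a single point. fast_hash here is an assumed stand-in (the XOR-of-primes spatial hash described in the Instant-NGP paper, for up to 3 dimensions), not a copy of tcnn's code; the reference also computes in float64 with unbounded Python integers, so results can differ from the CUDA float32/half output in the last bits, and very large strides that would overflow uint32 on the GPU are not emulated.

```python
import math

# Assumed stand-in for tcnn's fast_hash<D>: XOR of per-dimension products with
# the primes from the Instant-NGP paper (supports D <= 3).
PRIMES = (1, 2654435761, 805459861)
MASK32 = 0xFFFFFFFF


def fast_hash(pos_grid):
    h = 0
    for d, p in enumerate(pos_grid):
        h ^= (p * PRIMES[d]) & MASK32  # emulate uint32 multiply wrap-around
    return h


def get_grid_index(ch, hashmap_size, resolution, pos_grid, C):
    """Mirror of get_grid_index<D, C>: dense index if it fits, else hash."""
    stride, index = 1, 0
    for d in range(len(pos_grid)):
        if stride > hashmap_size:
            break
        index += pos_grid[d] * stride
        stride *= resolution + 1
    if stride > hashmap_size:
        index = fast_hash(pos_grid)
    return (index % hashmap_size) * C + ch


def encode_level(x, level, grid, offsets, H, S, C):
    """Reference of kernel_grid for one point x (D floats) at one level."""
    D = len(x)
    hashmap_size = offsets[level + 1] - offsets[level]
    base = offsets[level] * C  # same as `grid += offsets[level] * C`
    scale = 2.0 ** (level * S) * H - 1.0
    resolution = math.ceil(scale) + 1

    pos, pos_grid = [], []
    for d in range(D):
        p = x[d] * scale + 0.5
        g = math.floor(p)
        pos.append(p - g)
        pos_grid.append(g)

    results = [0.0] * C
    for idx in range(1 << D):  # visit the 2^D corners of the cell
        w, local = 1.0, []
        for d in range(D):
            if idx & (1 << d):
                w *= pos[d]
                local.append(pos_grid[d] + 1)
            else:
                w *= 1 - pos[d]
                local.append(pos_grid[d])
        index = get_grid_index(0, hashmap_size, resolution, local, C)
        for ch in range(C):
            results[ch] += w * grid[base + index + ch]
    return results
```

Two quick sanity checks: on a dense level whose stored features equal the corner's x index, interpolation reproduces the continuous coordinate x * scale + 0.5; on a hashed level filled with ones, the weights sum to 1.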


MultiPath commented on July 17, 2024

Did you ever compare the speed of your FFMLP and pytorch native MLP (fp16) solely instead of inside NeRF? Say you prepare a batch of features and input both MLPs.
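One way to run such an isolated comparison is a small timing helper like the sketch below (hypothetical names, standard library only). It warms up first and reports the median. For GPU code, `sync` should be something like torch.cuda.synchronize, because CUDA kernel launches are asynchronous and timing without synchronization mostly measures launch overhead.

```python
import statistics
import time


def benchmark(fn, *, warmup=3, iters=20, sync=None):
    """Median wall-clock time of fn() in milliseconds.

    `sync`, if given, must block until queued device work has finished
    (e.g. torch.cuda.synchronize when timing CUDA kernels).
    """
    for _ in range(warmup):  # exclude one-time costs (allocator, library init, ...)
        fn()
    if sync is not None:
        sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for the GPU before stopping the clock
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)


def compare(fns, **kwargs):
    """Time several callables on the same workload; returns {name: ms}."""
    return {name: benchmark(fn, **kwargs) for name, fn in fns.items()}
```

Usage might look like compare({"ffmlp": lambda: net0(x), "torch": lambda: net1(x)}, sync=torch.cuda.synchronize), where net0, net1 and x are placeholders for the two MLPs and a shared fp16 input batch. Timing backward this way, e.g. lambda: net0(x).sum().backward(), also includes the forward pass.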


ashawkey commented on July 17, 2024

Yes, some basic testing was done here.


ashawkey commented on July 17, 2024

Thanks to the Nsight Systems profiler, I'm sure the performance of the hashgrid encoder and ffmlp kernels is OK.
However, the current pytorch rendering function is really bad; some points that affect speed are listed here:

  • instant-ngp uses a dynamic resolution, check m_dynamic_res.
  • instant-ngp's rendering function only takes ~50 steps before all rays die (check trace, generate_next_nerf_network_inputs), while the current pytorch implementation forces 128+128 steps.
  • the cat in forward (h = torch.cat([d, geo_feat], dim=-1)) makes an unnecessary copy.
  • although upsample_steps > 0 enhances visual quality, it nearly doubles the time.

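The copy made by torch.cat in the color branch can in principle be avoided by splitting the first layer's weight by columns, since W [d; geo_feat] = W_d d + W_g geo_feat. The sketch below checks that identity with plain Python lists (hypothetical helper names); whether two smaller matmuls are actually faster than one copy plus one matmul on the GPU is not established here and would need to be measured.

```python
def matvec(W, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def split_columns(W, k):
    """Split W into [W_d | W_g] after the first k columns."""
    return [row[:k] for row in W], [row[k:] for row in W]


def layer_with_cat(W, d, g):
    return matvec(W, d + g)  # list concatenation copies, like torch.cat


def layer_without_cat(W_d, W_g, d, g):
    # same result without materializing the concatenated input
    return [a + b for a, b in zip(matvec(W_d, d), matvec(W_g, g))]
```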

niujinshuchong commented on July 17, 2024

@ashawkey Thanks for sharing the great repo. I ran the testing script and found that ffmlp is faster in the forward pass but takes much longer in the backward pass (3.35 vs 17.37), while the pytorch mlp takes a similar time in both (3.19 vs 4.84). Do you have any idea why this is the case?


ashawkey commented on July 17, 2024

@niujinshuchong What are your testing parameters? I actually got the opposite result.
The output of this testing script on my machine (time0 = ffmlp, time1 = pytorch mlp):

time1 (fp32 train) = 11.47059154510498
time1 (fp32 back) = 18.597087860107422
time0 (forward) = 7.910304069519043
time0 (backward) = 5.978400230407715
time1 (forward) = 6.140927791595459
time1 (backward) = 7.3114237785339355
time1 (fp32 infer) = 8.96940803527832
time0 (infer) = 3.3681600093841553
time1 (infer) = 4.70147180557251


niujinshuchong commented on July 17, 2024

@ashawkey I didn't change anything; I ran python testing/test_ffmlp.py. Here are the results:

time1 (fp32 train) = 5.6842241287231445
time1 (fp32 back) = 10.439680099487305
time0 (forward) = 2.60915207862854
time0 (backward) = 16.638975143432617
time1 (forward) = 3.442336082458496
time1 (backward) = 5.590688228607178
time1 (fp32 infer) = 5.916319847106934
time0 (infer) = 1.1610560417175293
time1 (infer) = 3.0545918941497803

There is a large difference between the forward and backward times of ffmlp.


ashawkey commented on July 17, 2024

@niujinshuchong This is quite surprising... Could you provide more details about your environment, like the GPU and CUDA version? I'm testing on a TITAN RTX with CUDA 11.3. Also, have you tried the --tcnn flag? Is its speed similar to --ff?


niujinshuchong commented on July 17, 2024

@ashawkey I am using an RTX 3090, also with CUDA 11.3. I added tcnn to testing/test_ffmlp.py with the following:

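import tinycudann as tcnn

# INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM and NUM_LAYERS are assumed to be defined earlier in the test script
net2 = tcnn.Network(n_input_dims=INPUT_DIM, n_output_dims=OUTPUT_DIM, network_config={
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": HIDDEN_DIM,
    "n_hidden_layers": NUM_LAYERS,
})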

and here is the output

time1 (fp32 train) = 5.649184226989746
time1 (fp32 back) = 9.604096412658691
time0 (forward) = 2.3817920684814453
time0 (backward) = 16.32137680053711
time1 (forward) = 3.3835198879241943
time1 (backward) = 4.825088024139404
time2 (forward) = 1.7377279996871948
time2 (backward) = 4.14412784576416
time1 (fp32 infer) = 5.628928184509277
time0 (infer) = 1.159168004989624
time1 (infer) = 3.0504961013793945
time2 (infer) = 0.8120319843292236


ashawkey commented on July 17, 2024

@niujinshuchong It seems --tcnn always has the best performance, so maybe you could use it for now.
I still cannot reproduce the strange backward time of FFMLP on my machine, but I'll try on other machines later.


niujinshuchong commented on July 17, 2024

@ashawkey Thank you very much!


ashawkey commented on July 17, 2024

With 3be8f25, I can close this now.

