Comments (9)

MikeInnes commented on August 18, 2024

Hey @osofr, I will take a look at this today and see what I can do. The egregiousness of the performance hit is actually a good thing here, because it suggests that something very simple has gone wrong; most likely a kernel just slipped through the cracks.

So just to clarify: the slowdown is caused by a statement of the form x[i,:] = ..., right, while the forward pass x[i,:] is fine? Is this call falling back to scalar indexing? (It will error if you call CuArrays.allowscalar(false) beforehand.)
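A minimal sketch of that check, assuming CuArrays is installed and a CUDA device is available. The array sizes and the printed messages are illustrative only, and the `catch` branch reports any error, not only a scalar-indexing one:

```julia
using CuArrays

# Turn silent scalar-indexing fallbacks on GPU arrays into errors.
CuArrays.allowscalar(false)

xg   = cu(rand(Float32, 2, 100))   # 2x100 matrix on the GPU
repl = cu(zeros(Float32, 100))     # replacement row on the GPU

try
    xg[1, :] = repl                # the row assignment under suspicion
    println("row assignment ran without raising an error")
catch err
    # With allowscalar(false), a scalar fallback is expected to land here.
    println("row assignment raised an error: ", err)
end
```

If the assignment raises an error here but runs (slowly) without `allowscalar(false)`, that points to a missing GPU kernel rather than a problem in the model code.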

from cuarrays.jl.

osofr commented on August 18, 2024

Thank you so much @MikeInnes! Yes, this is exactly right. The forward pass is fine and is much faster on GPU. It errors out when setting CuArrays.allowscalar(false), at least for the example below. I haven't checked whether it also errors out when calling back!, but I am fairly certain this is the case. Thank you so much for your help! I wish I knew how to fix this myself, but I will hopefully learn something from your fix.

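using Flux, CuArrays             # gpu is exported by Flux
xg = rand(2,100) |> gpu;         # 2x100 matrix on the GPU
repl = zeros(100) |> gpu;        # replacement row on the GPU
@time xg[1,:] = repl;            # the slow row assignment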


MikeInnes commented on August 18, 2024

Ok, so this is happening just because we don't have a setindex kernel on the GPU; pinging @SimonDanisch who can hopefully implement one in GPUArrays fairly quickly.


osofr commented on August 18, 2024

Ok, thanks!

@SimonDanisch, is there anything I can do to expedite this? Can this kernel be borrowed from somewhere? Am I the first person to try to overwrite a row/column in a GPU array?


MikeInnes commented on August 18, 2024

Can you try Pkg.update()? Looks like this was fixed on master and just not tagged yet.


osofr commented on August 18, 2024

Ok, I'm fairly certain I tried this yesterday, but I will do so right now. Just to confirm, I should update GPUArrays? It will take me a few minutes to spin up a GPU instance, but I will get back to you shortly.


MikeInnes commented on August 18, 2024

Yeah, GPUArrays has the fix. If updating doesn't help, it would be useful to narrow down what setindex! is being called with.


osofr commented on August 18, 2024

Woohoo, it works! Not as fast as CPU, but it will do. I will re-run my Flux training code, just to confirm that the performance has improved and will close the issue then. Just out of curiosity, would you mind pointing me to the GPUArrays commit that fixes this setindex issue? I'd like to understand it for future hacking. Thank you so much @MikeInnes!

julia> using Flux, CuArrays
julia> testx = rand(2,100);
julia> x = param(testx);
julia> idx = (1,:); # (1, Colon())
julia> l = Flux.getindex(x,idx...);
julia> l2 = sum(l);
julia> Flux.back!(l2);
julia> @time Flux.back!(l2);
  0.000011 seconds (9 allocations: 2.844 KiB)

julia> xg = param(testx) |> gpu;
julia> l = Flux.getindex(xg,idx...);
julia> l2 = sum(l);
julia> Flux.back!(l2);
julia> @time Flux.back!(l2);
  0.000246 seconds (104 allocations: 2.844 KiB)


MikeInnes commented on August 18, 2024

I think you're just timing kernel launch overhead there, which is significant when the size of the array is small. Hopefully the times for the overall model are much better.
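One rough way to check this is to time the same row assignment at several array widths. The sketch below assumes that CuArrays and CUDAdrv are installed and that `CUDAdrv.synchronize()` is available in the installed version; the sizes are arbitrary. If the small sizes all take about the same time, a fixed per-launch cost is probably dominating:

```julia
using CuArrays, CUDAdrv

# Time one row assignment at several widths. GPU work is asynchronous, so
# synchronize before reading the clock.
for n in (100, 10_000, 1_000_000)
    xg   = cu(rand(Float32, 2, n))
    repl = cu(zeros(Float32, n))

    xg[1, :] = repl              # warm-up run so compilation is not timed
    CUDAdrv.synchronize()

    t = @elapsed begin
        xg[1, :] = repl
        CUDAdrv.synchronize()    # wait for the GPU before stopping the clock
    end
    println("n = ", n, ": ", t, " s")
end
```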

Here's the GPU code for setindex!. Hacking on GPU kernels in Julia is pretty easy and straightforward now, so I'd definitely recommend playing around with it, maybe via CUDAnative.
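For a feel of what such a kernel can look like, here is a minimal sketch of a hand-written row-assignment kernel. It is not the GPUArrays implementation: `setrow_kernel!` is a hypothetical name, the launch configuration is arbitrary, and the keyword form of `@cuda` assumes a CUDAnative version that supports it:

```julia
using CUDAnative, CuArrays

# Hypothetical kernel: copy the vector `src` into row `i` of the matrix `dest`.
# Each GPU thread handles one column index `j`.
function setrow_kernel!(dest, src, i)
    j = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if j <= length(src)          # guard threads past the end of the row
        @inbounds dest[i, j] = src[j]
    end
    return nothing               # GPU kernels must not return a value
end

xg   = cu(rand(Float32, 2, 100))
repl = cu(zeros(Float32, 100))

threads = 128
blocks  = cld(length(repl), threads)   # enough blocks to cover every column
@cuda blocks=blocks threads=threads setrow_kernel!(xg, repl, 1)
```

Copying the result back with `Array(xg)` and inspecting the first row is a simple way to confirm the kernel did what was intended.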

