Comments (7)

DhairyaLGandhi commented on July 26, 2024

Likely due to FillArrays.Ones, which may come from the push! adjoint.

I would get rid of the loops to construct the models etc. and see how that performs.
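
For illustration, here is one mutation-free shape that loop could take: recurse over a tuple of blocks and collect every intermediate output into a tuple instead of push!-ing into a Vector. This is only a sketch; collect_features is a made-up helper, not a Flux API.

using Flux

# Hypothetical helper (not a Flux API): run the blocks in sequence and gather
# each intermediate output into a tuple, avoiding any array mutation.
collect_features(::Tuple{}, x) = ()
function collect_features(blocks::Tuple, x)
    y = first(blocks)(x)
    return (y, collect_features(Base.tail(blocks), y)...)
end

# Usage with a tuple of blocks mirroring the encoders discussed below:
blocks = (BatchNorm(3), BatchNorm(3), BatchNorm(3))
feats = collect_features(blocks, randn(Float32, 10, 10, 3, 1))  # 3-tuple of outputs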

As an aside - I'd love an EfficientNet implementation in Metalhead as a PR ;)

pxl-th commented on July 26, 2024

Yes, I've changed the feature extraction part (the encoder) of the model to use map instead of loops, and now I can take gradients.
Although having support for push! in this case would have been nice as well.
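
Flux also ships a loop-free way to collect every intermediate output of a Chain, which covers this pattern without push!. A minimal sketch (not necessarily the exact rewrite used here), assuming the encoder is a Chain:

using Flux

encoder = Chain(BatchNorm(3), BatchNorm(3), BatchNorm(3))
x = randn(Float32, 10, 10, 3, 1)
# Flux.activations returns a tuple containing the output of every layer in the Chain.
features = Flux.activations(encoder, x)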

Here's a fun animation of the training dynamics on a selected image from a small dataset, if anyone is curious :)
[animation: output]

ToucheSir commented on July 26, 2024

Can you reduce this to an MWE? A MethodError shouldn't be too difficult to troubleshoot; we just need to trace the data coming into batchnorm.

pxl-th commented on July 26, 2024

Yeah, should've done that at the beginning :)
I figured that since both the encoder blocks and the decoder blocks in my code end with BatchNorm, I might as well construct the whole MWE from BatchNorm layers.

MWE:

using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function decode(decoder, features)
    features = features[end:-1:1]
    head, skips = features[1], features[2:end]

    x = head
    for (i, block) in enumerate(decoder)
        if i <= length(skips)
            x = cat(x, skips[i]; dims=3)
        end
        x = block(x)
    end
    x
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device

    encoder = Chain(BatchNorm(3), BatchNorm(3), BatchNorm(3)) |> device |> trainmode!
    decoder = Chain(BatchNorm(6), BatchNorm(9)) |> device |> trainmode!
    θ = params(encoder, decoder)

    gradient(θ) do
        features = encode(encoder, x)
        out = decode(decoder, features)
        sum(out)
    end
end
main()

Produces the same error:

ERROR: LoadError: MethodError: no method matching
  ∇batchnorm(::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 4}, ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}}, ::CUDA.CuArray{Float32, 1}, ::CUDA.CuArray{Float32, 1}, ::Float32; cache=nothing, alpha=1, beta=0, eps=1.0f-5, training=true)
Closest candidates are:
  ∇batchnorm(::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:81
  ∇batchnorm(::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T, 2}, ::CUDA.CuArray{T, N} where N, ::CUDA.CuArray{T, N} where N, ::Any; cache, eps, alpha, beta, training) where T<:Union{Float32, Float64} at /home/pxl-th/.julia/packages/NNlibCUDA/Oc2CZ/src/cudnn/batchnorm.jl:71
Stacktrace:
  [1] (::Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, Float32})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:17
  [2] (::Flux.CUDAint.var"#793#back#4"{Flux.CUDAint.var"#batchnorm_pullback#2"{Base.Iterators.Pairs{Symbol, Union{Nothing, Real}, NTuple{5, Symbol}, NamedTuple{(:cache, :alpha, :beta, :eps, :training), Tuple{Nothing, Int64, Int64, Float32, Bool}}}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 4}, CUDA.CuArray{Float32, 1}, CUDA.CuArray{Float32, 1}, Float32}})(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Flux.CUDAint ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:65
  [3] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:9 [inlined]
  [4] (::typeof((λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [5] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/cuda/cudnn.jl:6 [inlined]
  [6] (::typeof((λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [7] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:6 [inlined]
  [8] (::typeof((encode)))(Δ::Vector{Union{Nothing, CUDA.CuArray{Float32, 4}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:39 [inlined]
 [10] (::typeof((λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#90#91"{Zygote.Params, typeof((λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
 [12] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [13] main()
    @ Main ~/projects/Segmentation.jl/src/mwe.jl:38
 [14] top-level scope
    @ ~/projects/Segmentation.jl/src/mwe.jl:45
in expression starting at /home/pxl-th/projects/Segmentation.jl/src/mwe.jl:45

pxl-th commented on July 26, 2024

Indeed, it comes from the push!.
However, if you replace the identity activation with any other activation function (e.g. relu), the error disappears (see the sketch after the MWE below).
But in MBConv the last BatchNorm has no activation.

Here's an even smaller MWE:

using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device
    encoder = Chain(BatchNorm(3, identity), BatchNorm(3, identity)) |> device |> trainmode!
    θ = params(encoder)
    gradient(θ) do
        sum(reduce(+, encode(encoder, x)))
    end
end
main()
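
For comparison, a sketch of the working variant described above: the same encode function and gradient call, but with relu instead of identity. Per the observation above, this version does not hit the ∇batchnorm MethodError.

using Flux

# Reusing the encode function from the MWE above; only the activation changes.
x = randn(Float32, 10, 10, 3, 1) |> gpu
encoder = Chain(BatchNorm(3, relu), BatchNorm(3, relu)) |> gpu |> trainmode!
θ = params(encoder)
gradient(() -> sum(reduce(+, encode(encoder, x))), θ)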

Also, I'm not sure how you would get rid of the loops without unrolling them manually and without losing generality.
I want to be able to pass different encoders, which can have different feature extraction depths.
Having separate encoding and decoding stages makes things easier, similar to how it is done in the segmentation_models Python package.

Maybe, on the GPU side, we should "materialize" a FillArrays.Ones into a CuArray if we get one?
Especially since this already works fine on the CPU, it would make sense to have the same support on the GPU.
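
For what it's worth, a rough sketch of that materialization idea (materialize_grad is a made-up helper, and the wrapping shown in the comment is only illustrative, not how NNlibCUDA currently dispatches):

using CUDA, FillArrays

# Hypothetical helper: densify a lazy FillArrays gradient before handing it to
# a cuDNN kernel that only accepts CuArrays; pass every other gradient through.
materialize_grad(Δ::FillArrays.AbstractFill) = CuArray(collect(Δ))  # collect on the host, then upload
materialize_grad(Δ) = Δ

# e.g. the batchnorm pullback could wrap its incoming Δ (argument names guessed
# from the error above):
#   ∇batchnorm(g, b, x, materialize_grad(Δ), running_mean, running_var, momentum; kwargs...)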

pxl-th commented on July 26, 2024

A similar thing happens if you replace BatchNorm with Conv, and the error disappears if you specify a non-identity activation function.

MWE:

using Flux

function encode(encoder, x)
    features = typeof(x)[]
    for block in encoder
        x = block(x)
        push!(features, x)
    end
    features
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device

    encoder = Chain(
        Conv((3, 3), 3=>3, identity; pad=SamePad()),
        Conv((3, 3), 3=>3, identity; pad=SamePad()),
    ) |> device |> trainmode!
    θ = params(encoder)

    gradient(θ) do
        sum(reduce(+, encode(encoder, x)))
    end
end
main()

Error:

ERROR: LoadError: TaskFailedException

    nested task error: MethodError: no method matching
      gemm!(::Val{false}, ::Val{true}, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::CUDA.CuPtr{Float32}, ::Float32, ::Ptr{Float32})
    Closest candidates are:
      gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
      gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
      gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64}) at /home/pxl-th/.julia/packages/NNlib/YKZXm/src/gemm.jl:32
      ...
    Stacktrace:
     [1] macro expansion
       @ ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:156 [inlined]
     [2] (::NNlib.var"#752#threadsfor_fun#391"{Float32, Array{Float32, 3}, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CUDA.CuArray{Float32, 5}, DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}, Int64, Int64, Int64, UnitRange{Int64}})(onethread::Bool)
       @ NNlib ./threadingconstructs.jl:81
     [3] #invokelatest#2
       @ ./essentials.jl:708 [inlined]
     [4] invokelatest
       @ ./essentials.jl:706 [inlined]
     [5] macro expansion
       @ ./threadingconstructs.jl:86 [inlined]
     [6] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}; col::Array{Float32, 3}, alpha::Float32, beta::Float32)
       @ NNlib ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:148
     [7] ∇conv_data_im2col!
       @ ~/.julia/packages/NNlib/YKZXm/src/impl/conv_im2col.jl:127 [inlined]
     [8] (::NNlib.var"#162#166"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}, CUDA.CuArray{Float32, 5}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib ./threadingconstructs.jl:169
Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base ./task.jl:369
  [2] macro expansion
    @ ./task.jl:388 [inlined]
  [3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:228
  [4] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 3, 3, 1, (1, 1, 1), (1, 1, 1, 1, 0, 0), (1, 1, 1), false})
    @ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:217
  [5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::CUDA.CuArray{Float32, 4}, cdims::DenseConvDims{2, (3, 3), 3, 3, 1, (1, 1), (1, 1, 1, 1), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/YKZXm/src/conv.jl:151
  [6] ∇conv_data!
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:151 [inlined]
  [7] #∇conv_data#89
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:104 [inlined]
  [8] ∇conv_data
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:101 [inlined]
  [9] #204
    @ ~/.julia/packages/NNlib/YKZXm/src/conv.jl:313 [inlined]
 [10] unthunk
    @ ~/.julia/packages/ChainRulesCore/BYuIz/src/differentials/thunks.jl:192 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:55 [inlined]
 [12] map
    @ ./tuple.jl:215 [inlined]
 [13] map
    @ ./tuple.jl:216 [inlined]
 [14] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:56 [inlined]
 [15] ZBack
    @ ~/.julia/packages/Zygote/TaBlo/src/compiler/chainrules.jl:91 [inlined]
 [16] Pullback
    @ ~/.julia/packages/Flux/Zz9RI/src/layers/conv.jl:165 [inlined]
 [17] (::typeof((λ)))(Δ::FillArrays.Ones{Float32, 4, NTuple{4, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:6 [inlined]
 [19] (::typeof((encode)))(Δ::Vector{CUDA.CuArray{Float32, 4}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/projects/Segmentation.jl/src/mwe.jl:21 [inlined]
 [21] (::typeof((λ)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [22] (::Zygote.var"#90#91"{Zygote.Params, typeof((λ)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:348
 [23] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [24] main()
    @ Main ~/projects/Segmentation.jl/src/mwe.jl:20
 [25] top-level scope
    @ ~/projects/Segmentation.jl/src/mwe.jl:26
in expression starting at /home/pxl-th/projects/Segmentation.jl/src/mwe.jl:26

ToucheSir commented on July 26, 2024

This doesn't resolve the underlying issue, but you shouldn't have to use any mutation or explicit loops to write an EfficientNet-style architecture. Here's an equivalent version that works:

using Flux

encode(encoder, x) = map(f -> f(x), encoder)

function decode(decoder, features)
    nskips = length(features) - 1
    skip_blocks, rest_blocks = decoder[1:nskips], decoder[nskips+1:end]
    xs = foldl(zip(skip_blocks, features[nskips:-1:1]); init=features[end]) do acc, (f, x)
        f(cat(acc, x; dims=ndims(x) - 1)) # ndims(x) - 1 == 3 here, but is more general 
    end
    return rest_blocks(xs)
end

function main()
    device = gpu
    x = randn(Float32, 10, 10, 3, 1) |> device

    # the encoder blocks are not chained here (each is applied to the input independently), so don't make them a Chain
    encoder = (BatchNorm(3), BatchNorm(3), BatchNorm(3)) |> device |> trainmode!
    decoder = Chain(BatchNorm(6), BatchNorm(9)) |> device |> trainmode!
    θ = params(encoder, decoder)

    gradient(θ) do
        features = encode(encoder, x)
        out = decode(decoder, features)
        sum(out)
    end
end
main()
