diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 6e2c9e75..302da233 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -181,31 +181,31 @@ end batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = - track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) @grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) = batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index ed65f5e7..f58e3b05 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -266,7 +266,7 @@ function desc(rnn) end import Flux.Tracker -import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies +import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies, track_kw istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))