Update track function
This commit is contained in:
parent
b4ba7df03a
commit
6a41f823c8
@ -181,31 +181,31 @@ end
|
|||||||
|
|
||||||
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||||
|
@ -266,7 +266,7 @@ function desc(rnn)
|
|||||||
end
|
end
|
||||||
|
|
||||||
import Flux.Tracker
|
import Flux.Tracker
|
||||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies, track_kw
|
||||||
|
|
||||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user