Remove track_kw

This commit is contained in:
Avik Pal 2018-08-10 03:21:05 +05:30
parent 3f6c065523
commit 3affed8ef0
2 changed files with 8 additions and 8 deletions

View File

@ -181,31 +181,31 @@ end
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T}, batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T}, batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T}, batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T}, batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T}, batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track_kw(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) = @grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing) batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)

View File

@ -266,7 +266,7 @@ function desc(rnn)
end end
import Flux.Tracker import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies, track_kw import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))