diff --git a/Manifest.toml b/Manifest.toml
index 8f2f0fad..06348d88 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -111,6 +111,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
 version = "0.10.3"
 
+[[IRTools]]
+deps = ["InteractiveUtils", "MacroTools", "Test"]
+git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53"
+uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
+version = "0.1.2"
+
 [[InteractiveUtils]]
 deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -300,3 +306,11 @@ deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
 git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
 uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 version = "0.8.1"
+
+[[Zygote]]
+deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
+git-tree-sha1 = "7fcb55117550e1c195a646947135cc9aac1e2afc"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/Zygote.jl.git"
+uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
+version = "0.1.0+"
diff --git a/Project.toml b/Project.toml
index 85972f07..bd4820e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -22,6 +22,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 NNlib = "0.6"
diff --git a/src/Flux.jl b/src/Flux.jl
index eccdd6a7..ef43edeb 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -12,9 +12,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanP
 
 @reexport using NNlib
 
-using Tracker
-using Tracker: data
-export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
+using Zygote
 
 include("optimise/Optimise.jl")
 using .Optimise
diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index fac35a72..214cc108 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -196,33 +196,5 @@ end
 (BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
   BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
 
-batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
+@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
   batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl
index 09f6d43c..7ad14102 100644
--- a/src/cuda/curnn.jl
+++ b/src/cuda/curnn.jl
@@ -221,7 +221,6 @@ end
 # Interface
 
 import ..Flux: Flux, relu
-import ..Tracker: TrackedArray
 using .CuArrays.CUDAnative
 using .CuArrays: @cuindex, cudims
 
@@ -236,10 +235,9 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   return dst
 end
 
-CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
+CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
+CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
+CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
 CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
 
 function copyparams!(m::CuRNNs, d::RNNDesc)
@@ -267,37 +265,28 @@ function desc(rnn)
   return d
 end
 
-import Flux.Tracker
-import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
+using Zygote: @adjoint
 
-istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
-
-function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
+function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h)
   return result[2], result[1]
 end
 
-function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
+function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h)
   return result[2], result[1]
 end
 
-function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h[1], h[2])
+function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h[1], h[2])
   return (result[2], result[3]), result[1]
 end
 
-(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
-@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
+@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
   reserve, result = forwardTrain(desc(m), data(x), data(h))
   result, function (Δ)
     y, ho = result
@@ -309,7 +298,7 @@ end
   end
 end
 
-@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
+@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
   reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
   result, function (Δ)
     y, ho = result
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 61bbec4e..03e3b323 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -42,21 +42,6 @@ end
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
-_truncate(x::AbstractArray) = Tracker.data(x)
-_truncate(x::Tuple) = _truncate.(x)
-
-"""
-    truncate!(rnn)
-
-Truncates the gradient of the hidden state in recurrent layers. The value of the
-state is preserved. See also `reset!`.
-
-Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
-
-  rnn.state = Tracker.data(rnn.state)
-"""
-truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
-
 """
     reset!(rnn)
 
diff --git a/src/onehot.jl b/src/onehot.jl
index 172591f6..333922fa 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -129,10 +129,6 @@ function argmax(xs...)
   return onecold(xs...)
 end
 
-# Ambiguity hack
-
-a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
-a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
-
-onecold(x::TrackedVector, l...) = onecold(data(x), l...)
-onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
+# TODO probably still want this as a custom adjoint Zygote
+# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
+# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index ab8be578..bd965f00 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -1,9 +1,9 @@
 using Juno
-import Flux.Tracker: Params, gradient, data, update!
+import Zygote: Params, gradient
 import Base.depwarn
 
 function update!(opt, x, x̄)
-  update!(x, -apply!(opt, x, data(x̄)))
+  update!(x, -apply!(opt, x, x̄))
 end
 
 function update!(opt, xs::Params, gs)
@@ -12,15 +12,6 @@ function update!(opt, xs::Params, gs)
   end
 end
 
-# Added as an internal API but everyone started using it.
-function _update_params!(opt, xs)
-  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
-  for x in xs
-    update!(opt, x, Tracker.grad(x))
-    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
-  end
-end
-
 # Callback niceties
 call(f, xs...) = f(xs...)
 runall(f) = f
diff --git a/src/treelike.jl b/src/treelike.jl
index 443a91e2..07935e55 100644
--- a/src/treelike.jl
+++ b/src/treelike.jl
@@ -1,5 +1,5 @@
 import Adapt: adapt, adapt_storage
-import .Tracker: IdSet
+import .Zygote: IdSet
 
 children(x) = ()
 mapchildren(f, x) = x
@@ -39,7 +39,7 @@ end
 function params(m)
   ps = Params()
   prefor(p ->
-    Tracker.istracked(p) && Tracker.isleaf(p) &&
+    p isa AbstractArray{<:Real} &&
      !any(p′ -> p′ === p, ps) && push!(ps, p),
    m)
  return ps
@@ -80,8 +80,6 @@ f64(m) = paramtype(Float64, m)
 
 function mapparams(f, m)
   mapleaves(m) do x
-    Tracker.istracked(x) ? param(f(Tracker.data(x))) :
-      x isa Union{AbstractArray,Number} ? f(x) :
-      x
+    x isa Union{AbstractArray,Number} ? f(x) : x
   end
 end
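
Follow-up note on the `onecold` TODO in src/onehot.jl above: the removed Tracker methods only unwrapped tracked arrays before calling `onecold`, so the Zygote-era equivalent is to mark `onecold` as non-differentiable. A minimal sketch, not part of this patch and assuming Zygote's `@nograd` utility is available in the pinned version:

    using Zygote

    # Sketch: onecold returns indices/labels, so it has no meaningful gradient.
    # Declaring it non-differentiable mirrors the old `onecold(data(x), l...)` methods.
    Zygote.@nograd onecold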