break all the things
This commit is contained in:
parent
e991228047
commit
aa4d221f8c
|
@ -111,6 +111,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
||||||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||||
version = "0.10.3"
|
version = "0.10.3"
|
||||||
|
|
||||||
|
[[IRTools]]
|
||||||
|
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
||||||
|
git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53"
|
||||||
|
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
|
||||||
|
version = "0.1.2"
|
||||||
|
|
||||||
[[InteractiveUtils]]
|
[[InteractiveUtils]]
|
||||||
deps = ["Markdown"]
|
deps = ["Markdown"]
|
||||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||||
|
@ -300,3 +306,11 @@ deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
|
||||||
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
|
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
|
||||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||||
version = "0.8.1"
|
version = "0.8.1"
|
||||||
|
|
||||||
|
[[Zygote]]
|
||||||
|
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
|
||||||
|
git-tree-sha1 = "7fcb55117550e1c195a646947135cc9aac1e2afc"
|
||||||
|
repo-rev = "master"
|
||||||
|
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||||
|
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||||
|
version = "0.1.0+"
|
||||||
|
|
|
@ -22,6 +22,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||||
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||||
|
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||||
|
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||||
|
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||||
|
|
||||||
[compat]
|
[compat]
|
||||||
NNlib = "0.6"
|
NNlib = "0.6"
|
||||||
|
|
|
@ -12,9 +12,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanP
|
||||||
|
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
|
|
||||||
using Tracker
|
using Zygote
|
||||||
using Tracker: data
|
|
||||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
|
||||||
|
|
||||||
include("optimise/Optimise.jl")
|
include("optimise/Optimise.jl")
|
||||||
using .Optimise
|
using .Optimise
|
||||||
|
|
|
@ -196,33 +196,5 @@ end
|
||||||
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||||
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
|
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
|
||||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
|
||||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
|
||||||
|
|
||||||
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
|
||||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||||
|
|
|
@ -221,7 +221,6 @@ end
|
||||||
# Interface
|
# Interface
|
||||||
|
|
||||||
import ..Flux: Flux, relu
|
import ..Flux: Flux, relu
|
||||||
import ..Tracker: TrackedArray
|
|
||||||
using .CuArrays.CUDAnative
|
using .CuArrays.CUDAnative
|
||||||
using .CuArrays: @cuindex, cudims
|
using .CuArrays: @cuindex, cudims
|
||||||
|
|
||||||
|
@ -236,10 +235,9 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||||
return dst
|
return dst
|
||||||
end
|
end
|
||||||
|
|
||||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
|
||||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||||
|
|
||||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||||
|
@ -267,37 +265,28 @@ function desc(rnn)
|
||||||
return d
|
return d
|
||||||
end
|
end
|
||||||
|
|
||||||
import Flux.Tracker
|
using Zygote: @adjoint
|
||||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
|
||||||
|
|
||||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
|
result = forward(desc(m), x, h)
|
||||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
|
||||||
result = istrain(m, h, x) ?
|
|
||||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h)
|
|
||||||
return result[2], result[1]
|
return result[2], result[1]
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = istrain(m, h, x) ?
|
result = forward(desc(m), x, h)
|
||||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h)
|
|
||||||
return result[2], result[1]
|
return result[2], result[1]
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = istrain(m, h, x) ?
|
result = forward(desc(m), x, h[1], h[2])
|
||||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h[1], h[2])
|
|
||||||
return (result[2], result[3]), result[1]
|
return (result[2], result[3]), result[1]
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
|
||||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
||||||
result, function (Δ)
|
result, function (Δ)
|
||||||
y, ho = result
|
y, ho = result
|
||||||
|
@ -309,7 +298,7 @@ end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
||||||
result, function (Δ)
|
result, function (Δ)
|
||||||
y, ho = result
|
y, ho = result
|
||||||
|
|
|
@ -42,21 +42,6 @@ end
|
||||||
|
|
||||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
_truncate(x::AbstractArray) = Tracker.data(x)
|
|
||||||
_truncate(x::Tuple) = _truncate.(x)
|
|
||||||
|
|
||||||
"""
|
|
||||||
truncate!(rnn)
|
|
||||||
|
|
||||||
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
|
||||||
state is preserved. See also `reset!`.
|
|
||||||
|
|
||||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
|
||||||
|
|
||||||
rnn.state = Tracker.data(rnn.state)
|
|
||||||
"""
|
|
||||||
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
reset!(rnn)
|
reset!(rnn)
|
||||||
|
|
||||||
|
|
|
@ -129,10 +129,6 @@ function argmax(xs...)
|
||||||
return onecold(xs...)
|
return onecold(xs...)
|
||||||
end
|
end
|
||||||
|
|
||||||
# Ambiguity hack
|
# TODO probably still want this as a custom adjoint Zygote
|
||||||
|
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
||||||
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
|
# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
||||||
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
|
|
||||||
|
|
||||||
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
|
||||||
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
using Juno
|
using Juno
|
||||||
import Flux.Tracker: Params, gradient, data, update!
|
import Zygote: Params, gradient
|
||||||
import Base.depwarn
|
import Base.depwarn
|
||||||
|
|
||||||
function update!(opt, x, x̄)
|
function update!(opt, x, x̄)
|
||||||
update!(x, -apply!(opt, x, data(x̄)))
|
update!(x, -apply!(opt, x, x̄))
|
||||||
end
|
end
|
||||||
|
|
||||||
function update!(opt, xs::Params, gs)
|
function update!(opt, xs::Params, gs)
|
||||||
|
@ -12,15 +12,6 @@ function update!(opt, xs::Params, gs)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# Added as an internal API but everyone started using it.
|
|
||||||
function _update_params!(opt, xs)
|
|
||||||
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
|
|
||||||
for x in xs
|
|
||||||
update!(opt, x, Tracker.grad(x))
|
|
||||||
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
# Callback niceties
|
# Callback niceties
|
||||||
call(f, xs...) = f(xs...)
|
call(f, xs...) = f(xs...)
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import Adapt: adapt, adapt_storage
|
import Adapt: adapt, adapt_storage
|
||||||
import .Tracker: IdSet
|
import .Zygote: IdSet
|
||||||
|
|
||||||
children(x) = ()
|
children(x) = ()
|
||||||
mapchildren(f, x) = x
|
mapchildren(f, x) = x
|
||||||
|
@ -39,7 +39,7 @@ end
|
||||||
function params(m)
|
function params(m)
|
||||||
ps = Params()
|
ps = Params()
|
||||||
prefor(p ->
|
prefor(p ->
|
||||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
p isa AbstractArray{<:Real} &&
|
||||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||||
m)
|
m)
|
||||||
return ps
|
return ps
|
||||||
|
@ -80,8 +80,6 @@ f64(m) = paramtype(Float64, m)
|
||||||
|
|
||||||
function mapparams(f, m)
|
function mapparams(f, m)
|
||||||
mapleaves(m) do x
|
mapleaves(m) do x
|
||||||
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
x isa Union{AbstractArray,Number} ? f(x) : x
|
||||||
x isa Union{AbstractArray,Number} ? f(x) :
|
|
||||||
x
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue