break all the things
commit aa4d221f8c (parent e991228047)
@@ -111,6 +111,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
 version = "0.10.3"
 
+[[IRTools]]
+deps = ["InteractiveUtils", "MacroTools", "Test"]
+git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53"
+uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
+version = "0.1.2"
+
 [[InteractiveUtils]]
 deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -300,3 +306,11 @@ deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
 git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
 uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 version = "0.8.1"
+
+[[Zygote]]
+deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
+git-tree-sha1 = "7fcb55117550e1c195a646947135cc9aac1e2afc"
+repo-rev = "master"
+repo-url = "https://github.com/FluxML/Zygote.jl.git"
+uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
+version = "0.1.0+"
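Note: the repo-rev and repo-url fields record that Zygote is tracked directly from its master branch rather than from a registered release. A minimal sketch of how such a Manifest entry comes about (standard Pkg API, not part of the diff):

using Pkg
# Adding a package by URL and branch produces repo-url/repo-rev entries like the ones above.
Pkg.add(PackageSpec(url = "https://github.com/FluxML/Zygote.jl.git", rev = "master"))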
@@ -22,6 +22,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
+ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 NNlib = "0.6"
@@ -12,9 +12,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanP
 
 @reexport using NNlib
 
-using Tracker
-using Tracker: data
-export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
+using Zygote
 
 include("optimise/Optimise.jl")
 using .Optimise
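Note: replacing `using Tracker` with `using Zygote` changes the whole differentiation model, which is why the `TrackedArray` exports disappear: Tracker wraps parameters in tracked types, while Zygote differentiates ordinary functions of ordinary arrays. A minimal sketch of the Zygote calling convention (illustrative values, not from this commit):

using Zygote

W, b, x = rand(2, 3), rand(2), rand(3)
loss(W, b) = sum(tanh.(W * x .+ b))
gW, gb = gradient(loss, W, b)  # inputs and gradients are plain arrays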
@@ -196,33 +196,5 @@ end
 (BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
   BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
 
-batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
-  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
-
-@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
+@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
   batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
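Note: the seven deleted overloads existed only to route every TrackedArray/CuArray argument combination through Tracker's `track`; Zygote hooks the function itself, so a single `@adjoint` covers all of them. The general pattern, with hypothetical `myop`/`∇myop` standing in for `batchnorm`/`∇batchnorm`:

using Zygote: @adjoint

myop(a, b) = a .* b
∇myop(a, b, Δ) = (Δ .* b, Δ .* a)  # gradients with respect to a and b

# One definition replaces a family of dispatch-based `track` methods.
@adjoint myop(a, b) = myop(a, b), Δ -> ∇myop(a, b, Δ)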
@@ -221,7 +221,6 @@ end
 # Interface
 
 import ..Flux: Flux, relu
-import ..Tracker: TrackedArray
 using .CuArrays.CUDAnative
 using .CuArrays: @cuindex, cudims
 
@@ -236,10 +235,9 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   return dst
 end
 
-CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
+CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
+CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
+CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
 CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
 
 function copyparams!(m::CuRNNs, d::RNNDesc)
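Note: `CuParam` was a Union alias that let one method accept either a raw `CuArray` or a `TrackedArray` wrapping one; with Tracker gone, the RNN aliases can name `CuArray` directly. The dispatch trick in miniature (hypothetical CPU-side names, not from this commit):

struct Tracked{T,N,A<:AbstractArray{T,N}}; data::A; end
Param{T,N} = Union{Array{T,N},Tracked{T,N,Array{T,N}}}

f(x::Param{Float32,2}) = :hit  # one method used to accept both raw and wrapped arrays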
@@ -267,37 +265,28 @@ function desc(rnn)
   return d
 end
 
-import Flux.Tracker
-import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
+using Zygote: @adjoint
 
-istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
-
-function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
+function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h)
   return result[2], result[1]
 end
 
-function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
+function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  result = forward(desc(m), x, h)
   return result[2], result[1]
 end
 
-function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h[1], h[2])
+function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
  result = forward(desc(m), x, h[1], h[2])
   return (result[2], result[3]), result[1]
 end
 
-(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
-@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
+@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
   reserve, result = forwardTrain(desc(m), data(x), data(h))
   result, function (Δ)
     y, ho = result
@@ -309,7 +298,7 @@ end
   end
 end
 
-@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
+@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
   reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
   result, function (Δ)
     y, ho = result
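Note: both RNN adjoints have the same shape: run cuDNN's training-mode forward pass, then return the result alongside a pullback closure that reuses saved forward state. A stripped-down sketch of that save-and-reuse pattern (hypothetical `expf`, not the cuDNN code):

using Zygote: @adjoint

expf(x) = exp.(x)

@adjoint function expf(x)
  y = expf(x)         # forward result, captured by the closure
  y, Δ -> (Δ .* y,)   # backward pass reuses y instead of recomputing exp
end

gradient(x -> sum(expf(x)), [0.0, 1.0])  # ([1.0, 2.718...],)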
@@ -42,21 +42,6 @@ end
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
-_truncate(x::AbstractArray) = Tracker.data(x)
-_truncate(x::Tuple) = _truncate.(x)
-
-"""
-    truncate!(rnn)
-
-Truncates the gradient of the hidden state in recurrent layers. The value of the
-state is preserved. See also `reset!`.
-
-Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
-
-    rnn.state = Tracker.data(rnn.state)
-"""
-truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
-
 """
     reset!(rnn)
 
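Note: `truncate!` only made sense under Tracker, where the hidden state carried its computation history and backpropagation-through-time had to be cut by hand; under Zygote the state is a plain array and each `gradient` call bounds its own differentiated span, while `reset!` (kept above) still resets the state's value. A sketch of the intended usage, assuming the usual Flux constructors:

m = Flux.RNN(3, 3)
for x in [rand(Float32, 3) for _ in 1:5]
  m(x)          # the state advances, but no gradient tape accumulates
end
Flux.reset!(m)  # resetting the value is still meaningful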
@@ -129,10 +129,6 @@ function argmax(xs...)
   return onecold(xs...)
 end
 
-# Ambiguity hack
-
-a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
-a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
-
-onecold(x::TrackedVector, l...) = onecold(data(x), l...)
-onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
+# TODO probably still want this as a custom adjoint Zygote
+# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
+# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
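Note: the TODO suggests `onecold` still wants special-casing under Zygote. Since `onecold` is a discrete, argmax-like lookup, the simplest treatment is to declare it non-differentiable; one possible form (an assumption, not code from this commit):

using Zygote

Zygote.@nograd onecold  # treat onecold as constant for differentiation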
@@ -1,9 +1,9 @@
 using Juno
-import Flux.Tracker: Params, gradient, data, update!
+import Zygote: Params, gradient
 import Base.depwarn
 
 function update!(opt, x, x̄)
-  update!(x, -apply!(opt, x, data(x̄)))
+  update!(x, -apply!(opt, x, x̄))
 end
 
 function update!(opt, xs::Params, gs)
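Note: dropping `data(x̄)` works because Zygote gradients arrive as plain arrays. How the pieces fit together, sketched minimally (a Descent-like optimiser is assumed; `apply!` rescales the gradient and the array method applies the step):

struct Descent; eta::Float64; end
apply!(o::Descent, x, x̄) = o.eta .* x̄

update!(x::AbstractArray, x̄) = x .+= x̄              # in-place step
update!(opt, x, x̄) = update!(x, -apply!(opt, x, x̄))

W, gW = rand(2, 2), ones(2, 2)
update!(Descent(0.1), W, gW)                         # W ← W - 0.1 * gW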
@@ -12,15 +12,6 @@ function update!(opt, xs::Params, gs)
   end
 end
 
-# Added as an internal API but everyone started using it.
-function _update_params!(opt, xs)
-  depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
-  for x in xs
-    update!(opt, x, Tracker.grad(x))
-    x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
-  end
-end
-
 # Callback niceties
 call(f, xs...) = f(xs...)
 runall(f) = f
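Note: `_update_params!` cannot survive the migration because it reads gradients off the arrays themselves (`x.tracker.grad`); Zygote instead returns a gradient collection keyed by parameter. The replacement pattern, sketched (assumes `W`, `b`, `loss`, `opt`, and the `update!` sketch above):

using Zygote

ps = Params([W, b])
gs = gradient(() -> loss(W, b), ps)  # implicit-parameter form returns a Grads lookup
for p in ps
  update!(opt, p, gs[p])             # fetch each gradient by its parameter
end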
@@ -1,5 +1,5 @@
 import Adapt: adapt, adapt_storage
-import .Tracker: IdSet
+import .Zygote: IdSet
 
 children(x) = ()
 mapchildren(f, x) = x
@@ -39,7 +39,7 @@ end
 function params(m)
   ps = Params()
   prefor(p ->
-    Tracker.istracked(p) && Tracker.isleaf(p) &&
+    p isa AbstractArray{<:Real} &&
     !any(p′ -> p′ === p, ps) && push!(ps, p),
     m)
   return ps
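Note: parameter collection no longer asks Tracker whether a leaf is tracked; any real-valued array leaf in the model tree now counts as a parameter. A quick check of the new criterion, assuming the usual Flux layers:

d = Dense(3, 2)
ps = params(d)                             # collects the weight and bias arrays
all(p -> p isa AbstractArray{<:Real}, ps)  # true by construction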
@@ -80,8 +80,6 @@ f64(m) = paramtype(Float64, m)
 
 function mapparams(f, m)
   mapleaves(m) do x
-    Tracker.istracked(x) ? param(f(Tracker.data(x))) :
-    x isa Union{AbstractArray,Number} ? f(x) :
-    x
+    x isa Union{AbstractArray,Number} ? f(x) : x
   end
 end
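Note: with tracking gone, `mapparams` reduces to mapping `f` over numeric leaves, which is all that precision helpers like the `f64` in the hunk header need. Example usage, assuming the usual Flux layers:

m64 = mapparams(x -> Float64.(x), Chain(Dense(3, 2), Dense(2, 1)))  # roughly what f64 does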