break all the things

This commit is contained in:
Mike J Innes 2019-03-08 12:06:09 +00:00 committed by Elliot Saba
parent e991228047
commit aa4d221f8c
9 changed files with 42 additions and 96 deletions

View File

@ -111,6 +111,12 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
uuid = "f6369f11-7733-5829-9624-2563aa707210" uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.3" version = "0.10.3"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.1.2"
[[InteractiveUtils]] [[InteractiveUtils]]
deps = ["Markdown"] deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@ -300,3 +306,11 @@ deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e" git-tree-sha1 = "5f6f663890dfb9bad6af75a86a43f67904e5050e"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.8.1" version = "0.8.1"
[[Zygote]]
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
git-tree-sha1 = "7fcb55117550e1c195a646947135cc9aac1e2afc"
repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.1.0+"

View File

@ -22,6 +22,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat] [compat]
NNlib = "0.6" NNlib = "0.6"

View File

@ -12,9 +12,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanP
@reexport using NNlib @reexport using NNlib
using Tracker using Zygote
using Tracker: data
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
include("optimise/Optimise.jl") include("optimise/Optimise.jl")
using .Optimise using .Optimise

View File

@ -196,33 +196,5 @@ end
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} = (BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)) BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, @adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing) batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)

View File

@ -221,7 +221,6 @@ end
# Interface # Interface
import ..Flux: Flux, relu import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using .CuArrays.CUDAnative using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims using .CuArrays: @cuindex, cudims
@ -236,10 +235,9 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
return dst return dst
end end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::RNNDesc) function copyparams!(m::CuRNNs, d::RNNDesc)
@ -267,37 +265,28 @@ function desc(rnn)
return d return d
end end
import Flux.Tracker using Zygote: @adjoint
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
result = forward(desc(m), x, h)
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1] return result[2], result[1]
end end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? result = forward(desc(m), x, h)
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1] return result[2], result[1]
end end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? result = forward(desc(m), x, h[1], h[2])
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1] return (result[2], result[3]), result[1]
end end
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b) @adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h)) reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ) result, function (Δ)
y, ho = result y, ho = result
@ -309,7 +298,7 @@ end
end end
end end
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b) @adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...) reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ) result, function (Δ)
y, ho = result y, ho = result

View File

@ -42,21 +42,6 @@ end
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
_truncate(x::AbstractArray) = Tracker.data(x)
_truncate(x::Tuple) = _truncate.(x)
"""
truncate!(rnn)
Truncates the gradient of the hidden state in recurrent layers. The value of the
state is preserved. See also `reset!`.
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
rnn.state = Tracker.data(rnn.state)
"""
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
""" """
reset!(rnn) reset!(rnn)

View File

@ -129,10 +129,6 @@ function argmax(xs...)
return onecold(xs...) return onecold(xs...)
end end
# Ambiguity hack # TODO probably still want this as a custom adjoint Zygote
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b) # onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)

View File

@ -1,9 +1,9 @@
using Juno using Juno
import Flux.Tracker: Params, gradient, data, update! import Zygote: Params, gradient
import Base.depwarn import Base.depwarn
function update!(opt, x, ) function update!(opt, x, )
update!(x, -apply!(opt, x, data())) update!(x, -apply!(opt, x, ))
end end
function update!(opt, xs::Params, gs) function update!(opt, xs::Params, gs)
@ -12,15 +12,6 @@ function update!(opt, xs::Params, gs)
end end
end end
# Added as an internal API but everyone started using it.
function _update_params!(opt, xs)
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
for x in xs
update!(opt, x, Tracker.grad(x))
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
end
end
# Callback niceties # Callback niceties
call(f, xs...) = f(xs...) call(f, xs...) = f(xs...)
runall(f) = f runall(f) = f

View File

@ -1,5 +1,5 @@
import Adapt: adapt, adapt_storage import Adapt: adapt, adapt_storage
import .Tracker: IdSet import .Zygote: IdSet
children(x) = () children(x) = ()
mapchildren(f, x) = x mapchildren(f, x) = x
@ -39,7 +39,7 @@ end
function params(m) function params(m)
ps = Params() ps = Params()
prefor(p -> prefor(p ->
Tracker.istracked(p) && Tracker.isleaf(p) && p isa AbstractArray{<:Real} &&
!any(p -> p === p, ps) && push!(ps, p), !any(p -> p === p, ps) && push!(ps, p),
m) m)
return ps return ps
@ -80,8 +80,6 @@ f64(m) = paramtype(Float64, m)
function mapparams(f, m) function mapparams(f, m)
mapleaves(m) do x mapleaves(m) do x
Tracker.istracked(x) ? param(f(Tracker.data(x))) : x isa Union{AbstractArray,Number} ? f(x) : x
x isa Union{AbstractArray,Number} ? f(x) :
x
end end
end end