Flux.jl/src/grad/track.jl

57 lines
1.1 KiB
Julia
Raw Normal View History

2017-08-19 09:14:50 +00:00
# Generic fallback: a plain (untracked) value is its own data.
# More specific methods unwrap tracked values.
function data(x)
    return x
end
2017-08-18 15:50:27 +00:00
"""
    Call(f, args...)

Record the deferred application `f(args...)`: `func` holds `f` and `args` holds
the argument tuple exactly as passed (each element is one argument).
"""
struct Call{F,As<:Tuple}
    func::F
    args::As
    # Declaring this inner constructor suppresses Julia's implicit outer
    # constructor `Call(func::F, args::As) where As<:Tuple`. That implicit
    # method is more specific than the varargs `Call(f, args...)` below, so it
    # would treat a single tuple argument as the whole argument list
    # (`Call(f, (1, 2)).args == (1, 2)` instead of `((1, 2),)`).
    Call{F,As}(func, args) where {F,As<:Tuple} = new{F,As}(func, args)
end
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
2017-08-19 09:14:50 +00:00
# Evaluate the recorded call: unwrap every argument with `data`, then apply `func`.
function (c::Call)()
    unwrapped = map(data, c.args)
    return c.func(unwrapped...)
end
2017-08-18 15:50:27 +00:00
# Propagate `Δ` back through a recorded call: dispatch on the wrapped function,
# handing it the original (not unwrapped) arguments.
function back!(c::Call, Δ)
    fn, xs = c.func, c.args
    return back!(fn, Δ, xs...)
end
2017-08-19 09:14:50 +00:00
# Fallback for a call with no arguments, such as the leaf call `Call(nothing)`:
# there is nothing further to propagate into.
function back!(f, Δ)
    return nothing
end
"""
    Var{T,N,A} <: AbstractArray{T,N}

A node in a recorded computation: the array value `x`, a gradient buffer `Δ`
of the same array type `A`, and the `Call` `f` that produced the value.
"""
struct Var{T,N,A} <: AbstractArray{T,N}
    # Operation that produced this node.
    # NOTE(review): `Call` is a UnionAll here, so this field is not concretely
    # typed; parameterising it would change the `Var{T,N,A}` signature.
    f::Call
    x::A    # wrapped value
    Δ::A    # accumulated gradient, same array type as `x`
end
2017-08-19 09:14:50 +00:00
# Rank-specialised aliases of `Var`, parameterised by element type `T` and
# array type `A` (ranks 0, 1 and 2 respectively).
const ScalarVar = Var{T,0,A} where {T,A}
const VectorVar = Var{T,1,A} where {T,A}
const MatrixVar = Var{T,2,A} where {T,A}
# Full constructor: value `x` and gradient buffer `Δ` must share array type `A`;
# the element type and rank parameters are derived from `A`.
Var(c::Call, x::A, Δ::A) where A <: AbstractArray =
  Var{eltype(A),ndims(A),A}(c, x, Δ)
# New node with a zero-filled gradient buffer shaped like `x`.
# `zero(x)` replaces `zeros(x)`: the array-argument form of `zeros` was
# deprecated in Julia 0.7 and removed in 1.0, while `zero(::AbstractArray)`
# is available on both old and new versions.
Var(c::Call, x::AbstractArray) = Var(c, x, zero(x))
# Evaluate the recorded call to obtain the node's value.
Var(c::Call) = Var(c, c())
# Leaf node: wraps `x` with an empty recorded call (`Call(nothing)`).
Var(x::AbstractArray) = Var(Call(nothing), x)
# Unwrap the value held by a `Var`.
function data(v::Var)
    return v.x
end

# Gradient accumulated in a `Var` so far (added to in place by `back!`).
function grad(v::Var)
    return v.Δ
end
2017-08-18 15:50:27 +00:00
"""
    back!(x::Var, Δ)

Add the incoming gradient `Δ` into `x`'s gradient buffer in place, then keep
propagating `Δ` through the call that produced `x`.
"""
function back!(x::Var, Δ)
    acc = x.Δ
    acc .= acc .+ Δ
    return back!(x.f, Δ)
end
2017-08-19 09:14:50 +00:00
2017-08-19 10:00:55 +00:00
# Forward basic array queries to the wrapped value: a `Var` reports the shape
# of its data, and `similar` builds a container based on the data.
for fname in (:size, :ndims, :similar)
    @eval @inline Base.$fname(v::Var, args...) = Base.$fname(data(v), args...)
end
2017-08-19 10:00:55 +00:00
"""
    Base.showarray(io, X::Var, repr = true; header = true)

Display a `Var`. With `repr = true`, print it as `Var(<data>)`. Otherwise print
a `summary(X):` line (skipped when `header = false`) followed by the wrapped
data.
"""
function Base.showarray(io::IO, X::Var, repr::Bool = true; header = true)
    if repr
        print(io, "Var(")
        Base.showarray(io, data(X), true)
        print(io, ")")
    else
        # Previously `header` was accepted but ignored, so the summary line was
        # always printed. Honour it, as Base's own (Julia 0.6) `showarray` uses
        # `header` to control its summary line.
        header && println(io, summary(X), ":")
        Base.showarray(io, data(X), false, header = false)
    end
end