AD skeleton

This commit is contained in:
Mike J Innes 2017-08-18 16:50:27 +01:00
parent f8482ff80c
commit d9c30db2e3
2 changed files with 25 additions and 0 deletions

View File

@ -20,6 +20,7 @@ include("core.jl")
import .FluxCore: graph
include("utils.jl")
include("grad/track.jl")
include("params.jl")
include("compiler/code.jl")

24
src/grad/track.jl Normal file
View File

@ -0,0 +1,24 @@
struct Call{F,As<:Tuple}
func::F
args::As
end
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
back!(::Void, Δ) = nothing
mutable struct Var{T}
f::Call
x::T
Δ::T
end
Var(x::T, Δ::T) where {T} = Var(Call(nothing), x, Δ)
Var(x::AbstractArray) = Var(x, zeros(x))
Var(x::Number) = Var(x, zero(x))
function back!(x::Var, Δ)
x.Δ .+= Δ
back!(x.f, Δ)
end