AD skeleton
This commit is contained in:
parent
f8482ff80c
commit
d9c30db2e3
|
@ -20,6 +20,7 @@ include("core.jl")
|
|||
import .FluxCore: graph
|
||||
|
||||
include("utils.jl")
|
||||
include("grad/track.jl")
|
||||
include("params.jl")
|
||||
|
||||
include("compiler/code.jl")
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
args::As
|
||||
end
|
||||
|
||||
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
||||
|
||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||
back!(::Void, Δ) = nothing
|
||||
|
||||
mutable struct Var{T}
|
||||
f::Call
|
||||
x::T
|
||||
Δ::T
|
||||
end
|
||||
|
||||
Var(x::T, Δ::T) where {T} = Var(Call(nothing), x, Δ)
|
||||
Var(x::AbstractArray) = Var(x, zeros(x))
|
||||
Var(x::Number) = Var(x, zero(x))
|
||||
|
||||
function back!(x::Var, Δ)
|
||||
x.Δ .+= Δ
|
||||
back!(x.f, Δ)
|
||||
end
|
Loading…
Reference in New Issue