AD skeleton
This commit is contained in:
parent
f8482ff80c
commit
d9c30db2e3
@ -20,6 +20,7 @@ include("core.jl")
|
|||||||
import .FluxCore: graph
|
import .FluxCore: graph
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
include("grad/track.jl")
|
||||||
include("params.jl")
|
include("params.jl")
|
||||||
|
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
|
24
src/grad/track.jl
Normal file
24
src/grad/track.jl
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
struct Call{F,As<:Tuple}
|
||||||
|
func::F
|
||||||
|
args::As
|
||||||
|
end
|
||||||
|
|
||||||
|
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
||||||
|
|
||||||
|
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||||
|
back!(::Void, Δ) = nothing
|
||||||
|
|
||||||
|
mutable struct Var{T}
|
||||||
|
f::Call
|
||||||
|
x::T
|
||||||
|
Δ::T
|
||||||
|
end
|
||||||
|
|
||||||
|
Var(x::T, Δ::T) where {T} = Var(Call(nothing), x, Δ)
|
||||||
|
Var(x::AbstractArray) = Var(x, zeros(x))
|
||||||
|
Var(x::Number) = Var(x, zero(x))
|
||||||
|
|
||||||
|
function back!(x::Var, Δ)
|
||||||
|
x.Δ .+= Δ
|
||||||
|
back!(x.f, Δ)
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user