diff --git a/src/grad/lib.jl b/src/grad/lib.jl new file mode 100644 index 00000000..a1e35c67 --- /dev/null +++ b/src/grad/lib.jl @@ -0,0 +1,10 @@ +import Base: * + +a::MatrixVar * b::Union{MatrixVar,AbstractMatrix} = Var(Call(*, a, b)) +a::Union{MatrixVar,AbstractMatrix} * b::MatrixVar = Var(Call(*, a, b)) + +function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray) + back!(a, A_mul_Bt(Δ, data(b))) + back!(b, At_mul_B(data(a), Δ)) + return +end diff --git a/src/grad/track.jl b/src/grad/track.jl index 5c378585..5938e806 100644 --- a/src/grad/track.jl +++ b/src/grad/track.jl @@ -1,3 +1,5 @@ +data(x) = x + struct Call{F,As<:Tuple} func::F args::As @@ -5,20 +7,39 @@ end Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) -back!(c::Call, Δ) = back!(c.func, Δ, c.args...) -back!(::Void, Δ) = nothing +(c::Call)() = c.func(data.(c.args)...) -mutable struct Var{T} +back!(c::Call, Δ) = back!(c.func, Δ, c.args...) + +back!(f, Δ) = nothing + +struct Var{T,N,A} <: AbstractArray{T,N} f::Call - x::T - Δ::T + x::A + Δ::A end -Var(x::T, Δ::T) where {T} = Var(Call(nothing), x, Δ) -Var(x::AbstractArray) = Var(x, zeros(x)) -Var(x::Number) = Var(x, zero(x)) +ScalarVar{T,A} = Var{T,0,A} +VectorVar{T,A} = Var{T,1,A} +MatrixVar{T,A} = Var{T,2,A} + +Var(c::Call, x::A, Δ::A) where A <: AbstractArray = + Var{eltype(A),ndims(A),A}(c, x, Δ) + +Var(c::Call, x::AbstractArray) = Var(c, x, zeros(x)) + +Var(c::Call) = Var(c, c()) + +Var(x::AbstractArray) = Var(Call(nothing), x) + +data(x::Var) = x.x +grad(x::Var) = x.Δ function back!(x::Var, Δ) x.Δ .+= Δ back!(x.f, Δ) end + +for f in :[Base.size, Base.getindex].args + @eval @inline $f(x::Var, a...) = $f(data(x), a...) +end