submodule

This commit is contained in:
Mike J Innes 2017-08-19 11:11:25 +01:00
parent 1889ccd316
commit 70393138bc
2 changed files with 18 additions and 17 deletions

View File

@ -15,40 +15,41 @@ back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
back!(f, Δ) = nothing back!(f, Δ) = nothing
struct Var{T,N,A} <: AbstractArray{T,N} struct TrackedArray{T,N,A} <: AbstractArray{T,N}
f::Call f::Call
x::A x::A
Δ::A Δ::A
end end
ScalarVar{T,A} = Var{T,0,A} TrackedScalar{T,A} = TrackedArray{T,0,A}
VectorVar{T,A} = Var{T,1,A} TrackedVector{T,A} = TrackedArray{T,1,A}
MatrixVar{T,A} = Var{T,2,A} TrackedMatrix{T,A} = TrackedArray{T,2,A}
Var(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
Var{eltype(A),ndims(A),A}(c, x, Δ) TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
Var(c::Call, x::AbstractArray) = Var(c, x, zeros(x)) TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, zeros(x))
Var(c::Call) = Var(c, c()) TrackedArray(c::Call) = TrackedArray(c, c())
Var(x::AbstractArray) = Var(Call(nothing), x) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
data(x::Var) = x.x track(xs) = TrackedArray(xs)
grad(x::Var) = x.Δ data(x::TrackedArray) = x.x
grad(x::TrackedArray) = x.Δ
function back!(x::Var, Δ) function back!(x::TrackedArray, Δ)
x.Δ .+= Δ x.Δ .+= Δ
back!(x.f, Δ) back!(x.f, Δ)
end end
for f in :[Base.size, Base.ndims, Base.similar].args for f in :[Base.size, Base.ndims, Base.similar].args
@eval @inline $f(x::Var, a...) = $f(data(x), a...) @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end end
function Base.showarray(io::IO, X::Var, repr::Bool = true; header = true) function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr if repr
print(io, "Var(") print(io, "TrackedArray(")
Base.showarray(io, data(X), true) Base.showarray(io, data(X), true)
print(io, ")") print(io, ")")
else else

View File

@ -1,7 +1,7 @@
import Base: * import Base: *
a::MatrixVar * b::Union{MatrixVar,AbstractMatrix} = Var(Call(*, a, b)) a::TrackedMatrix * b::Union{TrackedMatrix,AbstractMatrix} = Var(Call(*, a, b))
a::Union{MatrixVar,AbstractMatrix} * b::MatrixVar = Var(Call(*, a, b)) a::Union{TrackedMatrix,AbstractMatrix} * b::TrackedMatrix = Var(Call(*, a, b))
function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray) function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray)
back!(a, A_mul_Bt(Δ, data(b))) back!(a, A_mul_Bt(Δ, data(b)))