submodule
This commit is contained in:
parent
1889ccd316
commit
70393138bc
@ -15,40 +15,41 @@ back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
|||||||
|
|
||||||
back!(f, Δ) = nothing
|
back!(f, Δ) = nothing
|
||||||
|
|
||||||
struct Var{T,N,A} <: AbstractArray{T,N}
|
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||||
f::Call
|
f::Call
|
||||||
x::A
|
x::A
|
||||||
Δ::A
|
Δ::A
|
||||||
end
|
end
|
||||||
|
|
||||||
ScalarVar{T,A} = Var{T,0,A}
|
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||||
VectorVar{T,A} = Var{T,1,A}
|
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||||
MatrixVar{T,A} = Var{T,2,A}
|
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||||
|
|
||||||
Var(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||||
Var{eltype(A),ndims(A),A}(c, x, Δ)
|
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
||||||
|
|
||||||
Var(c::Call, x::AbstractArray) = Var(c, x, zeros(x))
|
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, zeros(x))
|
||||||
|
|
||||||
Var(c::Call) = Var(c, c())
|
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||||
|
|
||||||
Var(x::AbstractArray) = Var(Call(nothing), x)
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
|
||||||
|
|
||||||
data(x::Var) = x.x
|
track(xs) = TrackedArray(xs)
|
||||||
grad(x::Var) = x.Δ
|
data(x::TrackedArray) = x.x
|
||||||
|
grad(x::TrackedArray) = x.Δ
|
||||||
|
|
||||||
function back!(x::Var, Δ)
|
function back!(x::TrackedArray, Δ)
|
||||||
x.Δ .+= Δ
|
x.Δ .+= Δ
|
||||||
back!(x.f, Δ)
|
back!(x.f, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims, Base.similar].args
|
for f in :[Base.size, Base.ndims, Base.similar].args
|
||||||
@eval @inline $f(x::Var, a...) = $f(data(x), a...)
|
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.showarray(io::IO, X::Var, repr::Bool = true; header = true)
|
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||||
if repr
|
if repr
|
||||||
print(io, "Var(")
|
print(io, "TrackedArray(")
|
||||||
Base.showarray(io, data(X), true)
|
Base.showarray(io, data(X), true)
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
else
|
else
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
a::MatrixVar * b::Union{MatrixVar,AbstractMatrix} = Var(Call(*, a, b))
|
a::TrackedMatrix * b::Union{TrackedMatrix,AbstractMatrix} = Var(Call(*, a, b))
|
||||||
a::Union{MatrixVar,AbstractMatrix} * b::MatrixVar = Var(Call(*, a, b))
|
a::Union{TrackedMatrix,AbstractMatrix} * b::TrackedMatrix = Var(Call(*, a, b))
|
||||||
|
|
||||||
function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray)
|
function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray)
|
||||||
back!(a, A_mul_Bt(Δ, data(b)))
|
back!(a, A_mul_Bt(Δ, data(b)))
|
||||||
|
Loading…
Reference in New Issue
Block a user