matmul
This commit is contained in:
parent
8acc09ccf0
commit
f1dda12a54
10
src/grad/lib.jl
Normal file
10
src/grad/lib.jl
Normal file
@ -0,0 +1,10 @@
|
||||
import Base: *
|
||||
|
||||
a::MatrixVar * b::Union{MatrixVar,AbstractMatrix} = Var(Call(*, a, b))
|
||||
a::Union{MatrixVar,AbstractMatrix} * b::MatrixVar = Var(Call(*, a, b))
|
||||
|
||||
function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray)
|
||||
back!(a, A_mul_Bt(Δ, data(b)))
|
||||
back!(b, At_mul_B(data(a), Δ))
|
||||
return
|
||||
end
|
@ -1,3 +1,5 @@
|
||||
data(x) = x
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
args::As
|
||||
@ -5,20 +7,39 @@ end
|
||||
|
||||
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
||||
|
||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||
back!(::Void, Δ) = nothing
|
||||
(c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
mutable struct Var{T}
|
||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||
|
||||
back!(f, Δ) = nothing
|
||||
|
||||
struct Var{T,N,A} <: AbstractArray{T,N}
|
||||
f::Call
|
||||
x::T
|
||||
Δ::T
|
||||
x::A
|
||||
Δ::A
|
||||
end
|
||||
|
||||
Var(x::T, Δ::T) where {T} = Var(Call(nothing), x, Δ)
|
||||
Var(x::AbstractArray) = Var(x, zeros(x))
|
||||
Var(x::Number) = Var(x, zero(x))
|
||||
ScalarVar{T,A} = Var{T,0,A}
|
||||
VectorVar{T,A} = Var{T,1,A}
|
||||
MatrixVar{T,A} = Var{T,2,A}
|
||||
|
||||
Var(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
Var{eltype(A),ndims(A),A}(c, x, Δ)
|
||||
|
||||
Var(c::Call, x::AbstractArray) = Var(c, x, zeros(x))
|
||||
|
||||
Var(c::Call) = Var(c, c())
|
||||
|
||||
Var(x::AbstractArray) = Var(Call(nothing), x)
|
||||
|
||||
data(x::Var) = x.x
|
||||
grad(x::Var) = x.Δ
|
||||
|
||||
function back!(x::Var, Δ)
|
||||
x.Δ .+= Δ
|
||||
back!(x.f, Δ)
|
||||
end
|
||||
|
||||
for f in :[Base.size, Base.getindex].args
|
||||
@eval @inline $f(x::Var, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user