42 lines
898 B
Julia
42 lines
898 B
Julia
![]() |
export Dense
|
||
|
|
||
|
# Fully connected layer that owns its parameters, its gradient accumulators and
# the work buffers reused by the forward and backward passes.
# Field order is the positional argument order of the default constructor.
type Dense <: Model
    # Trainable parameters: the forward pass computes `W*x + b`.
    W::Matrix{Float32}
    b::Vector{Float32}

    # Gradient accumulators: `back!` adds into them, `update!` zeroes them.
    ∇W::Matrix{Float32}
    ∇b::Vector{Float32}

    # Most recent input, as stored by the forward pass.
    in::Vector{Float32}
    # Output buffer, overwritten by every forward pass.
    out::Vector{Float32}
    # Gradient with respect to the input, overwritten by `back!`.
    ∇in::Vector{Float32}
end
|
||
|
|
||
|
# Build a layer mapping `in` inputs to `out` outputs. The weights (`out`×`in`)
# and the bias (`out`) are drawn from `randn`; the gradient accumulators and
# the work buffers start at zero.
function Dense(in::Integer, out::Integer)
    # Weights are drawn before the bias, so the random stream is consumed in
    # a fixed order.
    weights = randn(out, in)
    bias = randn(out)
    ∇weights = zeros(out, in)
    ∇bias = zeros(out)
    input = zeros(in)
    output = zeros(out)
    ∇input = zeros(in)
    # The field-wise default constructor converts each array to the declared
    # Float32 field type.
    return Dense(weights, bias, ∇weights, ∇bias, input, output, ∇input)
end
|
||
|
|
||
|
# Deferred constructor: only the output size is fixed here. The closure handed
# to `Init` builds the layer when it is called with an indexable value whose
# first element is the input size.
# NOTE(review): `Init` is not defined in the visible source — confirm when and
# with what argument it invokes the closure.
function Dense(out::Integer)
    build = dims -> Dense(dims[1], out)
    return Init(build)
end
|
||
|
|
||
|
# Forward pass: record `x` as the layer's latest input, then overwrite the
# output buffer with `W*x + b`. Returns that buffer (`l.out`), not a copy.
function (l::Dense)(x)
    l.in = x
    y = l.out
    A_mul_B!(y, l.W, x)         # y = W * x, in place
    return map!(+, y, y, l.b)   # y += b, in place; map! returns y
end
|
||
|
|
||
|
# Backward pass for the output gradient `∇`: accumulate into `l.∇b` and
# `l.∇W`, then overwrite `l.∇in` with `W' * ∇` and return it.
function back!(l::Dense, ∇)
    map!(+, l.∇b, l.∇b, ∇)                             # ∇b += ∇
    # Both gemm! scalars are 1, so the call accumulates instead of overwriting.
    unit = one(eltype(∇))
    BLAS.gemm!('N', 'T', unit, ∇, l.in, unit, l.∇W)    # ∇W += ∇ * in'
    return At_mul_B!(l.∇in, l.W, ∇)                    # ∇in = W' * ∇
end
|
||
|
|
||
|
# Gradient-descent step with learning rate `η`: subtract `η` times each
# accumulated gradient from its parameter, then reset both accumulators to
# zero. The return value is `l.∇b`, the result of the last `fill!`.
function update!(l::Dense, η)
    descend = (θ, ∇θ) -> θ - η * ∇θ
    map!(descend, l.W, l.W, l.∇W)
    map!(descend, l.b, l.b, l.∇b)
    fill!(l.∇W, 0)
    return fill!(l.∇b, 0)
end
|
||
|
|
||
|
# Output shape of the layer as a 1-tuple: the length of the bias vector.
function shape(d::Dense)
    return (length(d.b),)
end
|