2017-03-17 16:34:51 +00:00
|
|
|
export Chain, @Chain
|
2016-08-25 21:49:21 +00:00
|
|
|
|
|
|
|
"""
    Chain(layers...)

Model that holds an ordered sequence of layers. The layers are stored in an
untyped vector so that layers of different types can be mixed.
"""
type Chain <: Model
  # Deliberately `Any`: a chain is typically heterogeneous (different layer types).
  layers::Vector{Any}
  Chain(layers...) = new(Any[layers...])
end
|
|
|
|
|
2017-03-17 16:34:51 +00:00
|
|
|
# Delegate indexing, `first`/`last`/`endof`, `length` and `push!` to the
# underlying `layers` vector via `@forward` (defined outside this file).
# `length` is forwarded so that generic iteration consumers such as `collect`
# and `map` work: they query `length` because the default iterator size trait
# is `HasLength()`.
# NOTE(review): the forwarded `push!` presumably returns `c.layers` rather than
# the chain itself — confirm against `@forward` before relying on its result.
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.length, Base.push!

# Pre-1.0 iteration protocol: iterating a `Chain` iterates its layers.
@forward Chain.layers Base.start, Base.next, Base.done
|
2016-08-25 21:49:21 +00:00
|
|
|
|
|
|
|
# Run input `x` through the chain: each layer is called on the previous
# layer's output, from the first layer to the last. An empty chain returns `x`.
function (s::Chain)(x)
  out = x
  for layer in s.layers
    out = layer(out)
  end
  return out
end
|
2016-12-15 21:37:39 +00:00
|
|
|
# Propagate `Δ` backwards through the chain: layers are visited from last to
# first, and the value returned by each layer's `back!` is passed to the one
# before it. Returns the value produced by the first layer (or `Δ` if empty).
function back!(s::Chain, Δ)
  grad = Δ
  for i in length(s.layers):-1:1
    grad = back!(s.layers[i], grad)
  end
  return grad
end
|
2016-08-25 21:49:21 +00:00
|
|
|
# Call `update!(layer, η)` on every layer in order; `η` (presumably a learning
# rate — confirm) is passed through unchanged. Returns `nothing`.
function update!(s::Chain, η)
  for layer in s.layers
    update!(layer, η)
  end
  return nothing
end
|
|
|
|
|
|
|
|
# Graph representation of the chain: start from `constant(inputnode(1))` and
# wrap the running vertex with `vertex(layer, v)` for each layer, first to last.
# (`vertex`, `constant` and `inputnode` are defined outside this file.)
function graph(s::Chain)
  v = constant(inputnode(1))
  for m in s.layers
    v = vertex(m, v)
  end
  return v
end
|
2016-09-06 17:03:39 +00:00
|
|
|
|
2017-02-28 16:42:48 +00:00
|
|
|
# Indexing a chain with an array of indices (e.g. a range) yields a new `Chain`
# made of the selected layers; other index types go to the `getindex` that is
# forwarded to `layers`.
function Base.getindex(c::Chain, i::AbstractArray)
  selected = c.layers[i]
  return Chain(selected...)
end
|
2017-03-17 16:34:51 +00:00
|
|
|
|
|
|
|
# Chain Macros
|
|
|
|
|
|
|
|
# Default for shape-aware construction used by `@Chain`, which passes the
# inferred input shape as `in` (wrapped in a tuple). This fallback ignores `in`
# and simply calls `f` with the remaining positional and keyword arguments.
function inferred(f, in, args...; kws...)
  return f(args...; kws...)
end
|
|
|
|
|
|
|
|
# `inferchain` is the hook `@Chain` uses for shape inference, so the
# interface's behaviour can be overridden separately from plain `infer`.
# This generic method just delegates to `infer`. As an example of an override:
# `infer(Affine(10, 20), nothing)` would normally give a shape error, but for
# the `@Chain` interface the errors are ignored and `(1, 20)` is returned.
# NOTE(review): that override is not in this method — presumably it is a more
# specific method defined elsewhere; confirm.
function inferchain(f, xs...)
  return infer(f, xs...)
end
|
|
|
|
|
|
|
|
"""
    @Chain(layer, layers...)

Build a `Chain` starting from `layer`, with shape inference for the layers
that follow.

Each extra layer written as a call `f(args...)` is rewritten to
`inferred(f, (shape,), args...)`. Here `shape` is the value `inferchain`
returns for the most recently added layer, presumably its output shape.
Layer expressions that are not calls are appended unchanged.
"""
macro Chain(x, xs...)
  # Rewrite a constructor call `f(args...)` into a shape-aware `inferred` call.
  # The capture names (`f`, `args`) and the argument name (`ex`) must not
  # collide with the macro's own arguments: `@capture` assigns its pattern
  # variables, and inside this closure an assignment to `xs` would rebind
  # (and box) the macro's vararg.
  inferconstructor(ex) =
    @capture(ex, f_(args__)) ? :(inferred($(esc(f)), (shape,), $(esc.(args)...))) : esc(ex)
  # Each extra layer first updates `shape` from the current last layer,
  # then appends the (possibly shape-aware) new layer.
  steps = [:(shape = inferchain(c.layers[end], shape);
             push!(c, $layer)) for layer in inferconstructor.(xs)]
  @q let
    shape = nothing
    c = Chain($(esc(x)))
    $(steps...)
    c
  end
end
|