diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index a7b98129..f454910b 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -18,11 +18,16 @@ type ChainSeq ChainSeq(xs...) = new([xs...]) end +@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! +@forward ChainSeq.layers Base.start, Base.next, Base.done + Optimise.children(c::ChainSeq) = c.layers (c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers) (c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data]) +Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...) + function Base.show(io::IO, c::ChainSeq) print(io, "ChainSeq(") join(io, c.layers, ", ") @@ -44,9 +49,18 @@ function (m::Recur)(xs...) return y end +(m::Recur)(s::Seq) = Seq([m(x) for x in s.data]) + +Optimise.children(m::Recur) = (m.cell,) + Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") -(m::Recur)(s::Seq) = Seq([m(x) for x in s.data]) +_truncate(x::AbstractArray) = x +_truncate(x::TrackedArray) = x.data +_truncate(x::Tuple) = _truncate.(x) + +truncate!(m) = foreach(truncate!, Optimise.children(m)) +truncate!(m::Recur) = (m.state = _truncate(m.state)) # Vanilla RNN @@ -65,6 +79,8 @@ end hidden(m::RNNCell) = m.h +Optimise.children(m::RNNCell) = (m.d, m.h) + function Base.show(io::IO, m::RNNCell) print(io, "RNNCell(", m.d, ")") end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index d0e33941..601707af 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -2,7 +2,7 @@ module Tracker using Base: RefValue -export track, back! +export TrackedArray, track, back! data(x) = x istracked(x) = false