rnn stuff

This commit is contained in:
Mike J Innes 2017-09-07 00:05:02 -04:00
parent a93c440c1e
commit e837bb0745
2 changed files with 18 additions and 2 deletions

View File

@ -18,11 +18,16 @@ type ChainSeq
ChainSeq(xs...) = new([xs...])
end
@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward ChainSeq.layers Base.start, Base.next, Base.done
Optimise.children(c::ChainSeq) = c.layers
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...)
function Base.show(io::IO, c::ChainSeq)
print(io, "ChainSeq(")
join(io, c.layers, ", ")
@ -44,9 +49,18 @@ function (m::Recur)(xs...)
return y
end
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
Optimise.children(m::Recur) = (m.cell,)
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
_truncate(x::AbstractArray) = x
_truncate(x::TrackedArray) = x.data
_truncate(x::Tuple) = _truncate.(x)
truncate!(m) = foreach(truncate!, Optimise.children(m))
truncate!(m::Recur) = (m.state = _truncate(m.state))
# Vanilla RNN
@ -65,6 +79,8 @@ end
hidden(m::RNNCell) = m.h
Optimise.children(m::RNNCell) = (m.d, m.h)
function Base.show(io::IO, m::RNNCell)
print(io, "RNNCell(", m.d, ")")
end

View File

@ -2,7 +2,7 @@ module Tracker
using Base: RefValue
export track, back!
export TrackedArray, track, back!
data(x) = x
istracked(x) = false