rnn stuff
This commit is contained in:
parent
a93c440c1e
commit
e837bb0745
@ -18,11 +18,16 @@ type ChainSeq
|
||||
ChainSeq(xs...) = new([xs...])
|
||||
end
|
||||
|
||||
@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward ChainSeq.layers Base.start, Base.next, Base.done
|
||||
|
||||
Optimise.children(c::ChainSeq) = c.layers
|
||||
|
||||
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
|
||||
|
||||
Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
function Base.show(io::IO, c::ChainSeq)
|
||||
print(io, "ChainSeq(")
|
||||
join(io, c.layers, ", ")
|
||||
@ -44,9 +49,18 @@ function (m::Recur)(xs...)
|
||||
return y
|
||||
end
|
||||
|
||||
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
||||
|
||||
Optimise.children(m::Recur) = (m.cell,)
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
||||
_truncate(x::AbstractArray) = x
|
||||
_truncate(x::TrackedArray) = x.data
|
||||
_truncate(x::Tuple) = _truncate.(x)
|
||||
|
||||
truncate!(m) = foreach(truncate!, Optimise.children(m))
|
||||
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
||||
|
||||
# Vanilla RNN
|
||||
|
||||
@ -65,6 +79,8 @@ end
|
||||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
Optimise.children(m::RNNCell) = (m.d, m.h)
|
||||
|
||||
function Base.show(io::IO, m::RNNCell)
|
||||
print(io, "RNNCell(", m.d, ")")
|
||||
end
|
||||
|
@ -2,7 +2,7 @@ module Tracker
|
||||
|
||||
using Base: RefValue
|
||||
|
||||
export track, back!
|
||||
export TrackedArray, track, back!
|
||||
|
||||
data(x) = x
|
||||
istracked(x) = false
|
||||
|
Loading…
Reference in New Issue
Block a user