rnn stuff
This commit is contained in:
parent
a93c440c1e
commit
e837bb0745
@ -18,11 +18,16 @@ type ChainSeq
|
|||||||
ChainSeq(xs...) = new([xs...])
|
ChainSeq(xs...) = new([xs...])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@forward ChainSeq.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||||
|
@forward ChainSeq.layers Base.start, Base.next, Base.done
|
||||||
|
|
||||||
Optimise.children(c::ChainSeq) = c.layers
|
Optimise.children(c::ChainSeq) = c.layers
|
||||||
|
|
||||||
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
|
(c::ChainSeq)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||||
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
|
(c::ChainSeq)(s::Seq) = Seq([c(x) for x in s.data])
|
||||||
|
|
||||||
|
Base.getindex(c::ChainSeq, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
function Base.show(io::IO, c::ChainSeq)
|
function Base.show(io::IO, c::ChainSeq)
|
||||||
print(io, "ChainSeq(")
|
print(io, "ChainSeq(")
|
||||||
join(io, c.layers, ", ")
|
join(io, c.layers, ", ")
|
||||||
@ -44,9 +49,18 @@ function (m::Recur)(xs...)
|
|||||||
return y
|
return y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
||||||
|
|
||||||
|
Optimise.children(m::Recur) = (m.cell,)
|
||||||
|
|
||||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
(m::Recur)(s::Seq) = Seq([m(x) for x in s.data])
|
_truncate(x::AbstractArray) = x
|
||||||
|
_truncate(x::TrackedArray) = x.data
|
||||||
|
_truncate(x::Tuple) = _truncate.(x)
|
||||||
|
|
||||||
|
truncate!(m) = foreach(truncate!, Optimise.children(m))
|
||||||
|
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
||||||
|
|
||||||
# Vanilla RNN
|
# Vanilla RNN
|
||||||
|
|
||||||
@ -65,6 +79,8 @@ end
|
|||||||
|
|
||||||
hidden(m::RNNCell) = m.h
|
hidden(m::RNNCell) = m.h
|
||||||
|
|
||||||
|
Optimise.children(m::RNNCell) = (m.d, m.h)
|
||||||
|
|
||||||
function Base.show(io::IO, m::RNNCell)
|
function Base.show(io::IO, m::RNNCell)
|
||||||
print(io, "RNNCell(", m.d, ")")
|
print(io, "RNNCell(", m.d, ")")
|
||||||
end
|
end
|
||||||
|
@ -2,7 +2,7 @@ module Tracker
|
|||||||
|
|
||||||
using Base: RefValue
|
using Base: RefValue
|
||||||
|
|
||||||
export track, back!
|
export TrackedArray, track, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
istracked(x) = false
|
istracked(x) = false
|
||||||
|
Loading…
Reference in New Issue
Block a user