generic tree functions
This commit is contained in:
parent
2ec8401d2c
commit
4bafa2b374
@ -21,6 +21,7 @@ using .Optimise
|
|||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("onehot.jl")
|
include("onehot.jl")
|
||||||
|
include("tree.jl")
|
||||||
|
|
||||||
include("layers/stateless.jl")
|
include("layers/stateless.jl")
|
||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
|
@ -22,7 +22,8 @@ end
|
|||||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||||
@forward Chain.layers Base.start, Base.next, Base.done
|
@forward Chain.layers Base.start, Base.next, Base.done
|
||||||
|
|
||||||
Optimise.children(c::Chain) = c.layers
|
children(c::Chain) = c.layers
|
||||||
|
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||||
|
|
||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
|
|
||||||
@ -53,7 +54,7 @@ end
|
|||||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||||
Dense(σ, param(init(out, in)), param(init(out)))
|
Dense(σ, param(init(out, in)), param(init(out)))
|
||||||
|
|
||||||
Optimise.children(d::Dense) = (d.W, d.b)
|
treelike(Dense)
|
||||||
|
|
||||||
(a::Dense)(x) = a.σ.(a.W*x .+ a.b)
|
(a::Dense)(x) = a.σ.(a.W*x .+ a.b)
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ function (m::Recur)(xs...)
|
|||||||
return y
|
return y
|
||||||
end
|
end
|
||||||
|
|
||||||
Optimise.children(m::Recur) = (m.cell,)
|
treelike(Recur)
|
||||||
|
|
||||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ _truncate(x::AbstractArray) = x
|
|||||||
_truncate(x::TrackedArray) = x.data
|
_truncate(x::TrackedArray) = x.data
|
||||||
_truncate(x::Tuple) = _truncate.(x)
|
_truncate(x::Tuple) = _truncate.(x)
|
||||||
|
|
||||||
truncate!(m) = foreach(truncate!, Optimise.children(m))
|
truncate!(m) = foreach(truncate!, children(m))
|
||||||
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
||||||
|
|
||||||
# Vanilla RNN
|
# Vanilla RNN
|
||||||
@ -44,7 +44,7 @@ end
|
|||||||
|
|
||||||
hidden(m::RNNCell) = m.h
|
hidden(m::RNNCell) = m.h
|
||||||
|
|
||||||
Optimise.children(m::RNNCell) = (m.d, m.h)
|
treelike(RNNCell)
|
||||||
|
|
||||||
function Base.show(io::IO, m::RNNCell)
|
function Base.show(io::IO, m::RNNCell)
|
||||||
print(io, "RNNCell(", m.d, ")")
|
print(io, "RNNCell(", m.d, ")")
|
||||||
@ -82,8 +82,7 @@ end
|
|||||||
|
|
||||||
hidden(m::LSTMCell) = (m.h, m.c)
|
hidden(m::LSTMCell) = (m.h, m.c)
|
||||||
|
|
||||||
Optimise.children(m::LSTMCell) =
|
treelike(LSTMCell)
|
||||||
(m.forget, m.input, m.output, m.cell, m.h, m.c)
|
|
||||||
|
|
||||||
Base.show(io::IO, m::LSTMCell) =
|
Base.show(io::IO, m::LSTMCell) =
|
||||||
print(io, "LSTMCell(",
|
print(io, "LSTMCell(",
|
||||||
|
@ -3,15 +3,19 @@ module Optimise
|
|||||||
export update!, params, train!,
|
export update!, params, train!,
|
||||||
SGD
|
SGD
|
||||||
|
|
||||||
include("params.jl")
|
struct Param{T}
|
||||||
|
x::T
|
||||||
|
Δ::T
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
||||||
|
|
||||||
include("optimisers.jl")
|
include("optimisers.jl")
|
||||||
include("interface.jl")
|
include("interface.jl")
|
||||||
include("train.jl")
|
include("train.jl")
|
||||||
|
|
||||||
using Flux.Tracker: TrackedArray
|
using Flux.Tracker: TrackedArray
|
||||||
|
|
||||||
params(ps, p::TrackedArray) = push!(ps, p)
|
|
||||||
|
|
||||||
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
|
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,18 +0,0 @@
|
|||||||
using DataFlow: OSet
|
|
||||||
|
|
||||||
children(x) = ()
|
|
||||||
|
|
||||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
|
||||||
|
|
||||||
function params(m)
|
|
||||||
ps = OSet()
|
|
||||||
params(ps, m)
|
|
||||||
return collect(ps)
|
|
||||||
end
|
|
||||||
|
|
||||||
struct Param{T}
|
|
||||||
x::T
|
|
||||||
Δ::T
|
|
||||||
end
|
|
||||||
|
|
||||||
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
|
20
src/tree.jl
Normal file
20
src/tree.jl
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
children(x) = ()
|
||||||
|
mapchildren(f, x) = x
|
||||||
|
|
||||||
|
function treelike(T, fs = fieldnames(T))
|
||||||
|
@eval begin
|
||||||
|
children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||||
|
mapchildren(f, x::$T) = $T(f.(children(x))...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
using DataFlow: OSet
|
||||||
|
|
||||||
|
params(ps, p::AbstractArray) = push!(ps, p)
|
||||||
|
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||||
|
|
||||||
|
function params(m)
|
||||||
|
ps = OSet()
|
||||||
|
params(ps, m)
|
||||||
|
return collect(ps)
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user