generic tree functions
This commit is contained in:
parent
2ec8401d2c
commit
4bafa2b374
@ -21,6 +21,7 @@ using .Optimise
|
||||
|
||||
include("utils.jl")
|
||||
include("onehot.jl")
|
||||
include("tree.jl")
|
||||
|
||||
include("layers/stateless.jl")
|
||||
include("layers/basic.jl")
|
||||
|
@ -22,7 +22,8 @@ end
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward Chain.layers Base.start, Base.next, Base.done
|
||||
|
||||
Optimise.children(c::Chain) = c.layers
|
||||
children(c::Chain) = c.layers
|
||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
|
||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||
|
||||
@ -53,7 +54,7 @@ end
|
||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||
Dense(σ, param(init(out, in)), param(init(out)))
|
||||
|
||||
Optimise.children(d::Dense) = (d.W, d.b)
|
||||
treelike(Dense)
|
||||
|
||||
(a::Dense)(x) = a.σ.(a.W*x .+ a.b)
|
||||
|
||||
|
@ -16,7 +16,7 @@ function (m::Recur)(xs...)
|
||||
return y
|
||||
end
|
||||
|
||||
Optimise.children(m::Recur) = (m.cell,)
|
||||
treelike(Recur)
|
||||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
@ -24,7 +24,7 @@ _truncate(x::AbstractArray) = x
|
||||
_truncate(x::TrackedArray) = x.data
|
||||
_truncate(x::Tuple) = _truncate.(x)
|
||||
|
||||
truncate!(m) = foreach(truncate!, Optimise.children(m))
|
||||
truncate!(m) = foreach(truncate!, children(m))
|
||||
truncate!(m::Recur) = (m.state = _truncate(m.state))
|
||||
|
||||
# Vanilla RNN
|
||||
@ -44,7 +44,7 @@ end
|
||||
|
||||
hidden(m::RNNCell) = m.h
|
||||
|
||||
Optimise.children(m::RNNCell) = (m.d, m.h)
|
||||
treelike(RNNCell)
|
||||
|
||||
function Base.show(io::IO, m::RNNCell)
|
||||
print(io, "RNNCell(", m.d, ")")
|
||||
@ -82,8 +82,7 @@ end
|
||||
|
||||
hidden(m::LSTMCell) = (m.h, m.c)
|
||||
|
||||
Optimise.children(m::LSTMCell) =
|
||||
(m.forget, m.input, m.output, m.cell, m.h, m.c)
|
||||
treelike(LSTMCell)
|
||||
|
||||
Base.show(io::IO, m::LSTMCell) =
|
||||
print(io, "LSTMCell(",
|
||||
|
@ -3,15 +3,19 @@ module Optimise
|
||||
export update!, params, train!,
|
||||
SGD
|
||||
|
||||
include("params.jl")
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δ::T
|
||||
end
|
||||
|
||||
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
||||
|
||||
include("optimisers.jl")
|
||||
include("interface.jl")
|
||||
include("train.jl")
|
||||
|
||||
using Flux.Tracker: TrackedArray
|
||||
|
||||
params(ps, p::TrackedArray) = push!(ps, p)
|
||||
|
||||
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
|
||||
|
||||
end
|
||||
|
@ -1,18 +0,0 @@
|
||||
using DataFlow: OSet
|
||||
|
||||
children(x) = ()
|
||||
|
||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||
|
||||
function params(m)
|
||||
ps = OSet()
|
||||
params(ps, m)
|
||||
return collect(ps)
|
||||
end
|
||||
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δ::T
|
||||
end
|
||||
|
||||
convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
20
src/tree.jl
Normal file
20
src/tree.jl
Normal file
@ -0,0 +1,20 @@
|
||||
children(x) = ()
|
||||
mapchildren(f, x) = x
|
||||
|
||||
function treelike(T, fs = fieldnames(T))
|
||||
@eval begin
|
||||
children(x::$T) = ($([:(x.$f) for f in fs]...),)
|
||||
mapchildren(f, x::$T) = $T(f.(children(x))...)
|
||||
end
|
||||
end
|
||||
|
||||
using DataFlow: OSet
|
||||
|
||||
params(ps, p::AbstractArray) = push!(ps, p)
|
||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||
|
||||
function params(m)
|
||||
ps = OSet()
|
||||
params(ps, m)
|
||||
return collect(ps)
|
||||
end
|
Loading…
Reference in New Issue
Block a user