From 6808a9279370d4462e7adfb07f1a1d3ac7b68341 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 23 Aug 2016 16:22:11 +0100 Subject: [PATCH] anonymous models --- src/Flux.jl | 1 + src/compiler/code.jl | 21 +++++++++++++++++++-- src/layers/anon.jl | 16 ++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 src/layers/anon.jl diff --git a/src/Flux.jl b/src/Flux.jl index 39895a14..970c6378 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,6 +19,7 @@ include("compiler/code.jl") include("cost.jl") include("activation.jl") include("layers/params.jl") +include("layers/anon.jl") include("layers/input.jl") include("layers/dense.jl") include("layers/sequence.jl") diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 8c632b8c..5a869f8e 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -2,7 +2,9 @@ import Flow: mapconst, cse -function process_func(ex, params) +export @model + +function process_func(ex, params = []) @capture(shortdef(ex), (args__,) -> body_) body = Flow.il(graphm(body)) body = mapconst(x -> x in params ? :(self.$x) : x, body) @@ -46,7 +48,7 @@ function build_forward(body, args) cse(deref_params(body)) end -function build_backward(body, x, params) +function build_backward(body, x, params = []) Δs = invert(body) back = IVertex{Any}(Flow.Do()) for param in params @@ -79,7 +81,22 @@ function process_type(ex) end |> esc end +function process_anon(ex) + args, body = process_func(ex) + layers = Set{Symbol}() + Flow.prefor(body) do v + isexpr(value(v), Symbol) && push!(layers, value(v)) + end + @assert length(args) == 1 + :(Capacitor( + ($(args...)) -> $(syntax(build_forward(body, args))), + (Δ, $(args...)) -> $(syntax(build_backward(body, args[1]))), + η -> $(map(p -> :(update!($p, η)), layers)...), + $(Flow.constructor(makegraph(body, args))))) |> esc +end + macro model(ex) isexpr(ex, :type) ? process_type(ex) : + isexpr(ex, :->, :function) ? process_anon(ex) : error("Unsupported model expression $ex") end diff --git a/src/layers/anon.jl b/src/layers/anon.jl new file mode 100644 index 00000000..9cf9f795 --- /dev/null +++ b/src/layers/anon.jl @@ -0,0 +1,16 @@ +export Capacitor + +type Capacitor <: Model + forward::Function + backward::Function + update::Function + graph::IVertex{Any} +end + +(cap::Capacitor)(args...) = cap.forward(args...) + +back!(cap::Capacitor, args...) = cap.backward(args...) + +update!(cap::Capacitor, η) = cap.update(η) + +graph(cap::Capacitor) = cap.graph