anonymous models
This commit is contained in:
parent
51d14cef20
commit
6808a92793
@ -19,6 +19,7 @@ include("compiler/code.jl")
|
|||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
include("activation.jl")
|
include("activation.jl")
|
||||||
include("layers/params.jl")
|
include("layers/params.jl")
|
||||||
|
include("layers/anon.jl")
|
||||||
include("layers/input.jl")
|
include("layers/input.jl")
|
||||||
include("layers/dense.jl")
|
include("layers/dense.jl")
|
||||||
include("layers/sequence.jl")
|
include("layers/sequence.jl")
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
import Flow: mapconst, cse
|
import Flow: mapconst, cse
|
||||||
|
|
||||||
function process_func(ex, params)
|
export @model
|
||||||
|
|
||||||
|
function process_func(ex, params = [])
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = Flow.il(graphm(body))
|
body = Flow.il(graphm(body))
|
||||||
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
||||||
@ -46,7 +48,7 @@ function build_forward(body, args)
|
|||||||
cse(deref_params(body))
|
cse(deref_params(body))
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_backward(body, x, params)
|
function build_backward(body, x, params = [])
|
||||||
Δs = invert(body)
|
Δs = invert(body)
|
||||||
back = IVertex{Any}(Flow.Do())
|
back = IVertex{Any}(Flow.Do())
|
||||||
for param in params
|
for param in params
|
||||||
@ -79,7 +81,22 @@ function process_type(ex)
|
|||||||
end |> esc
|
end |> esc
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function process_anon(ex)
|
||||||
|
args, body = process_func(ex)
|
||||||
|
layers = Set{Symbol}()
|
||||||
|
Flow.prefor(body) do v
|
||||||
|
isexpr(value(v), Symbol) && push!(layers, value(v))
|
||||||
|
end
|
||||||
|
@assert length(args) == 1
|
||||||
|
:(Capacitor(
|
||||||
|
($(args...)) -> $(syntax(build_forward(body, args))),
|
||||||
|
(Δ, $(args...)) -> $(syntax(build_backward(body, args[1]))),
|
||||||
|
η -> $(map(p -> :(update!($p, η)), layers)...),
|
||||||
|
$(Flow.constructor(makegraph(body, args))))) |> esc
|
||||||
|
end
|
||||||
|
|
||||||
macro model(ex)
|
macro model(ex)
|
||||||
isexpr(ex, :type) ? process_type(ex) :
|
isexpr(ex, :type) ? process_type(ex) :
|
||||||
|
isexpr(ex, :->, :function) ? process_anon(ex) :
|
||||||
error("Unsupported model expression $ex")
|
error("Unsupported model expression $ex")
|
||||||
end
|
end
|
||||||
|
16
src/layers/anon.jl
Normal file
16
src/layers/anon.jl
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
export Capacitor
|
||||||
|
|
||||||
|
type Capacitor <: Model
|
||||||
|
forward::Function
|
||||||
|
backward::Function
|
||||||
|
update::Function
|
||||||
|
graph::IVertex{Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
(cap::Capacitor)(args...) = cap.forward(args...)
|
||||||
|
|
||||||
|
back!(cap::Capacitor, args...) = cap.backward(args...)
|
||||||
|
|
||||||
|
update!(cap::Capacitor, η) = cap.update(η)
|
||||||
|
|
||||||
|
graph(cap::Capacitor) = cap.graph
|
Loading…
Reference in New Issue
Block a user