update capacitors
This commit is contained in:
parent
7af64398d5
commit
498a66e7b6
@ -81,21 +81,16 @@ function process_type(ex)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
macro net(ex)
|
|
||||||
isexpr(ex, :type) ? process_type(ex) :
|
|
||||||
isexpr(ex, :->, :function) ? error("@net functions not implemented") :
|
|
||||||
error("Unsupported model expression $ex")
|
|
||||||
end
|
|
||||||
|
|
||||||
function process_anon(ex)
|
function process_anon(ex)
|
||||||
args, body = process_func(ex)
|
args, body = process_func(ex)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
:(Flux.Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args))))))
|
:(Capacitor($(DataFlow.constructor(mapconst(esc, makegraph(body, args)[1])))))
|
||||||
end
|
end
|
||||||
|
|
||||||
macro ml(ex)
|
macro net(ex)
|
||||||
@capture(shortdef(ex), ((xs__,) -> body_ ) | (f_(xs__,) = body_)) ||
|
ex = shortdef(ex)
|
||||||
error("@ml requires a function definition")
|
isexpr(ex, :type) ? process_type(ex) :
|
||||||
ex = process_anon(:($(xs...,) -> $body))
|
@capture(ex, (__,) -> _) ? process_anon(ex) :
|
||||||
f == nothing ? ex : :($(esc(f)) = $ex)
|
@capture(ex, _(__) = _) ? error("@net functions not implemented") :
|
||||||
|
error("Unsupported model expression $ex")
|
||||||
end
|
end
|
||||||
|
@ -111,6 +111,7 @@ struct Capacitor <: Model
|
|||||||
graph::IVertex{Any}
|
graph::IVertex{Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::Capacitor)(xs...) = interpret(reifyparams(m.graph), xs...)
|
# TODO: batching
|
||||||
|
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||||
|
|
||||||
graph(cap::Capacitor) = cap.graph
|
graph(cap::Capacitor) = cap.graph
|
||||||
|
@ -23,6 +23,10 @@ d = Affine(10, 20)
|
|||||||
|
|
||||||
@test d(xs) ≈ (xs'*d.W.x + d.b.x)[1,:]
|
@test d(xs) ≈ (xs'*d.W.x + d.b.x)[1,:]
|
||||||
|
|
||||||
|
d1 = @net x -> x * d.W + d.b
|
||||||
|
|
||||||
|
@test d(xs) == d1(xs)
|
||||||
|
|
||||||
let
|
let
|
||||||
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
||||||
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
@test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
||||||
|
Loading…
Reference in New Issue
Block a user