batching for basic layers
This commit is contained in:
parent
568b8d7e48
commit
56c5784d83
@ -54,7 +54,7 @@ function process_type(ex)
|
|||||||
self = esc(:self)
|
self = esc(:self)
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
$(esc(:(self::$T)))($(args...),) = interpret(reifyparams(graph($self)), $(args...))
|
$(esc(:(self::$T)))($(args...),) = interpmodel($self, $(args...))
|
||||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||||
nothing
|
nothing
|
||||||
|
@ -24,3 +24,9 @@ function imap(cb, ctx, ::typeof(map), f, xs...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
imap(f, args...) = f(args...)
|
imap(f, args...) = f(args...)
|
||||||
|
|
||||||
|
function interpmodel(m, args::Batch...)
|
||||||
|
rebatch(interpret(reifyparams(graph(m)), map(rawbatch, args)...))
|
||||||
|
end
|
||||||
|
|
||||||
|
interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...))
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
xs = randn(10)' # TODO: batching semantics
|
xs = randn(10)
|
||||||
|
|
||||||
d = Affine(10, 20)
|
d = Affine(10, 20)
|
||||||
|
|
||||||
@test d(xs) == xs*d.W.x + d.b.x
|
@test d(xs) == (xs'*d.W.x + d.b.x)[1,:]
|
||||||
|
|
||||||
let
|
let
|
||||||
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
||||||
|
Loading…
Reference in New Issue
Block a user