batching for basic layers
This commit is contained in:
parent
568b8d7e48
commit
56c5784d83
|
@ -54,7 +54,7 @@ function process_type(ex)
|
|||
self = esc(:self)
|
||||
quote
|
||||
$(build_type(T, params))
|
||||
$(esc(:(self::$T)))($(args...),) = interpret(reifyparams(graph($self)), $(args...))
|
||||
$(esc(:(self::$T)))($(args...),) = interpmodel($self, $(args...))
|
||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||
nothing
|
||||
|
|
|
@ -24,3 +24,9 @@ function imap(cb, ctx, ::typeof(map), f, xs...)
|
|||
end
|
||||
|
||||
imap(f, args...) = f(args...)
|
||||
|
||||
function interpmodel(m, args::Batch...)
|
||||
rebatch(interpret(reifyparams(graph(m)), map(rawbatch, args)...))
|
||||
end
|
||||
|
||||
interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...))
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
xs = randn(10)' # TODO: batching semantics
|
||||
xs = randn(10)
|
||||
|
||||
d = Affine(10, 20)
|
||||
|
||||
@test d(xs) == xs*d.W.x + d.b.x
|
||||
@test d(xs) == (xs'*d.W.x + d.b.x)[1,:]
|
||||
|
||||
let
|
||||
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
|
||||
|
|
Loading…
Reference in New Issue