batching for basic layers

This commit is contained in:
Mike J Innes 2017-01-24 17:23:42 +05:30
parent 568b8d7e48
commit 56c5784d83
3 changed files with 9 additions and 3 deletions

View File

@ -54,7 +54,7 @@ function process_type(ex)
self = esc(:self)
quote
$(build_type(T, params))
$(esc(:(self::$T)))($(args...),) = interpret(reifyparams(graph($self)), $(args...))
$(esc(:(self::$T)))($(args...),) = interpmodel($self, $(args...))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing

View File

@ -24,3 +24,9 @@ function imap(cb, ctx, ::typeof(map), f, xs...)
end
imap(f, args...) = f(args...)
function interpmodel(m, args::Batch...)
rebatch(interpret(reifyparams(graph(m)), map(rawbatch, args)...))
end
interpmodel(m, args...) = unbatchone(interpmodel(m, batchone(args)...))

View File

@ -1,8 +1,8 @@
xs = randn(10)' # TODO: batching semantics
xs = randn(10)
d = Affine(10, 20)
@test d(xs) == xs*d.W.x + d.b.x
@test d(xs) == (xs'*d.W.x + d.b.x)[1,:]
let
@capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))