fix/test native multi outputs

This commit is contained in:
Mike J Innes 2017-03-06 16:12:03 +00:00
parent 982c105c57
commit d9910070b4
3 changed files with 17 additions and 1 deletions

View File

@ -72,7 +72,7 @@ function process_type(ex)
$(build_type(T, params))
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
($self::$(esc(T)))($(args...)) = first($self(map(batchone, ($(args...),))...))
($self::$(esc(T)))($(args...)) = unbatchone($self(map(batchone, ($(args...),))...))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing

View File

@ -37,6 +37,8 @@ function rebatch(xs)
Batch{T,B}(xs)
end
rebatch(xs::Tuple) = map(rebatch, xs)
convertel(T::Type, xs::Batch) =
isa(eltype(eltype(xs)), T) ? xs :
Batch(map(x->convertel(T, x), xs))

View File

@ -7,6 +7,15 @@
end
end
@net type Multi
W
V
x -> (x*W, x*V)
end
Multi(in::Integer, out::Integer) =
Multi(randn(in, out), randn(in, out))
@testset "Basics" begin
xs = randn(10)
@ -36,4 +45,9 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
@test e.trace[end-1].func == Symbol("Flux.Affine")
end
let m = Multi(10, 15)
x = rand(10)
@test all(isapprox.(m(x), (m.W.x' * x, m.V.x' * x)))
end
end