fix/test native multi outputs
This commit is contained in:
parent
982c105c57
commit
d9910070b4
|
@ -72,7 +72,7 @@ function process_type(ex)
|
|||
$(build_type(T, params))
|
||||
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
|
||||
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
|
||||
($self::$(esc(T)))($(args...)) = first($self(map(batchone, ($(args...),))...))
|
||||
($self::$(esc(T)))($(args...)) = unbatchone($self(map(batchone, ($(args...),))...))
|
||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||
nothing
|
||||
|
|
|
@ -37,6 +37,8 @@ function rebatch(xs)
|
|||
Batch{T,B}(xs)
|
||||
end
|
||||
|
||||
rebatch(xs::Tuple) = map(rebatch, xs)
|
||||
|
||||
convertel(T::Type, xs::Batch) =
|
||||
isa(eltype(eltype(xs)), T) ? xs :
|
||||
Batch(map(x->convertel(T, x), xs))
|
||||
|
|
|
@ -7,6 +7,15 @@
|
|||
end
|
||||
end
|
||||
|
||||
@net type Multi
|
||||
W
|
||||
V
|
||||
x -> (x*W, x*V)
|
||||
end
|
||||
|
||||
Multi(in::Integer, out::Integer) =
|
||||
Multi(randn(in, out), randn(in, out))
|
||||
|
||||
@testset "Basics" begin
|
||||
|
||||
xs = randn(10)
|
||||
|
@ -36,4 +45,9 @@ let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
|||
@test e.trace[end-1].func == Symbol("Flux.Affine")
|
||||
end
|
||||
|
||||
let m = Multi(10, 15)
|
||||
x = rand(10)
|
||||
@test all(isapprox.(m(x), (m.W.x' * x, m.V.x' * x)))
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue