make Maxout trainable

This commit is contained in:
Lyndon White 2019-03-25 16:02:46 +00:00
parent eeed8b24c3
commit f0cc4a328d
2 changed files with 8 additions and 0 deletions

View File

@ -167,6 +167,8 @@ function Maxout(f, n_alts)
return Maxout(over)
end
@treelike Maxout
function (mo::Maxout)(input::AbstractArray)
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end

View File

@ -53,5 +53,11 @@ using Test, Random
target = [0.5, 0.7].*input
@test mo(input) == target
end
@testset "params" begin
mo = Maxout(()->Dense(32, 64), 4)
ps = params(mo)
@test length(ps) == 8 #4 alts, each with weight and bias
end
end
end