make Maxout trainable
This commit is contained in:
parent
eeed8b24c3
commit
f0cc4a328d
@ -167,6 +167,8 @@ function Maxout(f, n_alts)
|
|||||||
return Maxout(over)
|
return Maxout(over)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@treelike Maxout
|
||||||
|
|
||||||
function (mo::Maxout)(input::AbstractArray)
|
function (mo::Maxout)(input::AbstractArray)
|
||||||
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
|
||||||
end
|
end
|
||||||
|
@ -53,5 +53,11 @@ using Test, Random
|
|||||||
target = [0.5, 0.7].*input
|
target = [0.5, 0.7].*input
|
||||||
@test mo(input) == target
|
@test mo(input) == target
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "params" begin
|
||||||
|
mo = Maxout(()->Dense(32, 64), 4)
|
||||||
|
ps = params(mo)
|
||||||
|
@test length(ps) == 8 #4 alts, each with weight and bias
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user