extend affine with activation
This commit is contained in:
parent
8e59160df6
commit
09974caba0
@ -5,7 +5,7 @@ module Flux
|
|||||||
using Juno
|
using Juno
|
||||||
using Lazy: @forward
|
using Lazy: @forward
|
||||||
|
|
||||||
export Chain, Affine, σ, softmax
|
export Chain, Linear, σ, softmax
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
|
@ -15,15 +15,15 @@ Compiler.graph(s::Chain) =
|
|||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
# Affine
|
# Linear
|
||||||
|
|
||||||
struct Affine{S,T}
|
struct Linear{F,S,T}
|
||||||
|
σ::F
|
||||||
W::S
|
W::S
|
||||||
b::T
|
b::T
|
||||||
end
|
end
|
||||||
|
|
||||||
Affine(in::Integer, out::Integer; init = initn) =
|
Linear(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||||
Affine(track(init(out, in)),
|
Linear(σ, track(init(out, in)), track(init(out)))
|
||||||
track(init(out)))
|
|
||||||
|
|
||||||
(a::Affine)(x) = a.W*x .+ a.b
|
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
||||||
|
Loading…
Reference in New Issue
Block a user