extend affine with activation
This commit is contained in:
parent
8e59160df6
commit
09974caba0
@ -5,7 +5,7 @@ module Flux
|
||||
using Juno
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Affine, σ, softmax
|
||||
export Chain, Linear, σ, softmax
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
|
@ -15,15 +15,15 @@ Compiler.graph(s::Chain) =
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
# Affine
|
||||
# Linear
|
||||
|
||||
struct Affine{S,T}
|
||||
struct Linear{F,S,T}
|
||||
σ::F
|
||||
W::S
|
||||
b::T
|
||||
end
|
||||
|
||||
Affine(in::Integer, out::Integer; init = initn) =
|
||||
Affine(track(init(out, in)),
|
||||
track(init(out)))
|
||||
Linear(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||
Linear(σ, track(init(out, in)), track(init(out)))
|
||||
|
||||
(a::Affine)(x) = a.W*x .+ a.b
|
||||
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
||||
|
Loading…
Reference in New Issue
Block a user