From 09974caba07e8e3784deb61927ec48a560ec456a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sun, 20 Aug 2017 13:35:35 +0100 Subject: [PATCH] extend affine with activation --- src/Flux.jl | 2 +- src/layers/basic.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index fdcd6194..b85809da 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -5,7 +5,7 @@ module Flux using Juno using Lazy: @forward -export Chain, Affine, σ, softmax +export Chain, Linear, σ, softmax # Zero Flux Given diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a866cf4c..59349aba 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -15,15 +15,15 @@ Compiler.graph(s::Chain) = Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) -# Affine +# Linear -struct Affine{S,T} +struct Linear{F,S,T} + σ::F W::S b::T end -Affine(in::Integer, out::Integer; init = initn) = - Affine(track(init(out, in)), - track(init(out))) +Linear(in::Integer, out::Integer, σ = identity; init = initn) = + Linear(σ, track(init(out, in)), track(init(out))) -(a::Affine)(x) = a.W*x .+ a.b +(a::Linear)(x) = a.σ.(a.W*x .+ a.b)