diff --git a/src/Flux.jl b/src/Flux.jl index f56e5350..641a1af3 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -13,11 +13,13 @@ abstract Activation <: Model back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))") update!(m::Model, η) = m -include("utils.jl") +include("rt/diff.jl") + include("cost.jl") include("activation.jl") include("layers/input.jl") include("layers/dense.jl") include("layers/sequence.jl") +include("utils.jl") end # module diff --git a/src/activation.jl b/src/activation.jl index f0c0a82d..a9832c7e 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -3,6 +3,8 @@ export Sigmoid σ(x) = 1/(1+exp(-x)) σ′(x) = σ(x)*(1-σ(x)) +∇₁(::typeof(σ)) = σ′ + type Sigmoid <: Activation in::Vector{Float32} out::Vector{Float32} diff --git a/src/rt/diff.jl b/src/rt/diff.jl new file mode 100644 index 00000000..90a162c0 --- /dev/null +++ b/src/rt/diff.jl @@ -0,0 +1,35 @@ +import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax + +vertex(a...) = IVertex{Any}(a...) + +∇(f, a::IVertex) = + (v(∇₁(f), a),) + +∇(::typeof(+), a::IVertex, b::IVertex) = + v(1), v(1) + +∇(::typeof(-), a::IVertex, b::IVertex) = + v(1), v(-1) + +∇(::typeof(*), a::IVertex, b::IVertex) = + v(transpose, b), v(transpose, a) + +function ∇v(v::IVertex, chain::Vector{IVertex}, out = d()) + if isconstant(v) + @assert !haskey(out, value(v)) + out[value(v)] = length(chain) == 1 ? + first(chain) : + foldl((x, y) -> vertex(*, x, y), chain) + else + ∇s = ∇(value(v), inputs(v)...) + for (v′, ∇′) in zip(inputs(v), ∇s) + ∇v(v′, (value(∇′) ≠ 1 ? push!(copy(chain), ∇′) : chain), out) + end + end + return out +end + +∇v(v::Vertex, chain::Vector) = ∇v(convert(IVertex, v), convert(Vector{IVertex}, chain)) + +∇v(v::Vertex, ∂::Vertex) = ∇v(v, [∂]) +