autodiff stuff
This commit is contained in:
parent
c082695425
commit
eaa77cc5a6
@ -13,11 +13,13 @@ abstract Activation <: Model
|
|||||||
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
||||||
update!(m::Model, η) = m
|
update!(m::Model, η) = m
|
||||||
|
|
||||||
include("utils.jl")
|
include("rt/diff.jl")
|
||||||
|
|
||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
include("activation.jl")
|
include("activation.jl")
|
||||||
include("layers/input.jl")
|
include("layers/input.jl")
|
||||||
include("layers/dense.jl")
|
include("layers/dense.jl")
|
||||||
include("layers/sequence.jl")
|
include("layers/sequence.jl")
|
||||||
|
include("utils.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
@ -3,6 +3,8 @@ export Sigmoid
|
|||||||
σ(x) = 1/(1+exp(-x))
|
σ(x) = 1/(1+exp(-x))
|
||||||
σ′(x) = σ(x)*(1-σ(x))
|
σ′(x) = σ(x)*(1-σ(x))
|
||||||
|
|
||||||
|
∇₁(::typeof(σ)) = σ′
|
||||||
|
|
||||||
type Sigmoid <: Activation
|
type Sigmoid <: Activation
|
||||||
in::Vector{Float32}
|
in::Vector{Float32}
|
||||||
out::Vector{Float32}
|
out::Vector{Float32}
|
||||||
|
35
src/rt/diff.jl
Normal file
35
src/rt/diff.jl
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax
|
||||||
|
|
||||||
|
vertex(a...) = IVertex{Any}(a...)
|
||||||
|
|
||||||
|
∇(f, a::IVertex) =
|
||||||
|
(v(∇₁(f), a),)
|
||||||
|
|
||||||
|
∇(::typeof(+), a::IVertex, b::IVertex) =
|
||||||
|
v(1), v(1)
|
||||||
|
|
||||||
|
∇(::typeof(-), a::IVertex, b::IVertex) =
|
||||||
|
v(1), v(-1)
|
||||||
|
|
||||||
|
∇(::typeof(*), a::IVertex, b::IVertex) =
|
||||||
|
v(transpose, b), v(transpose, a)
|
||||||
|
|
||||||
|
function ∇v(v::IVertex, chain::Vector{IVertex}, out = d())
|
||||||
|
if isconstant(v)
|
||||||
|
@assert !haskey(out, value(v))
|
||||||
|
out[value(v)] = length(chain) == 1 ?
|
||||||
|
first(chain) :
|
||||||
|
foldl((x, y) -> vertex(*, x, y), chain)
|
||||||
|
else
|
||||||
|
∇s = ∇(value(v), inputs(v)...)
|
||||||
|
for (v′, ∇′) in zip(inputs(v), ∇s)
|
||||||
|
∇v(v′, (value(∇′) ≠ 1 ? push!(copy(chain), ∇′) : chain), out)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return out
|
||||||
|
end
|
||||||
|
|
||||||
|
∇v(v::Vertex, chain::Vector) = ∇v(convert(IVertex, v), convert(Vector{IVertex}, chain))
|
||||||
|
|
||||||
|
∇v(v::Vertex, ∂::Vertex) = ∇v(v, [∂])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user