remove back!, update!

This commit is contained in:
Mike J Innes 2017-08-18 10:18:45 +01:00
parent 21089fea9c
commit f8482ff80c
7 changed files with 2 additions and 64 deletions

View File

@@ -17,7 +17,7 @@ export @net, unroll, unroll1, @shapes,
# Zero Flux Given # Zero Flux Given
include("core.jl") include("core.jl")
import .FluxCore: back!, update!, graph import .FluxCore: graph
include("utils.jl") include("utils.jl")
include("params.jl") include("params.jl")

View File

@@ -66,7 +66,6 @@ function process_type(ex)
quote quote
$(build_type(T, params)) $(build_type(T, params))
$(esc(:((self::$T)($(args...)) = $(build_forward(body, args))))) $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params)))) $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
nothing nothing
end end

View File

@@ -22,12 +22,6 @@ function (m::Stateful)(xs...)
return y return y
end end
function back!(m::Stateful, Δ, x)
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
update!(m::Stateful, η) = update!(m.model, η)
# Seq Models # Seq Models
struct SeqModel struct SeqModel
@@ -52,14 +46,6 @@ function (m::SeqModel)(xs...)
reseq(m.model(xs...)) reseq(m.model(xs...))
end end
function back!(m::SeqModel, args...)
args = seqtuple(args, 0)
# TODO: reseq
back!(m.model, args...)
end
update!(m::SeqModel, η) = update!(m.model, η)
graph(m::SeqModel) = graph(m.model) graph(m::SeqModel) = graph(m.model)
# Recurrent Graphs # Recurrent Graphs

View File

@@ -3,23 +3,6 @@
module FluxCore module FluxCore
"""
back!(model, ΔY, X...) => ΔX
Backpropagate the gradient `ΔY` through the model `model`, accumulating the
gradients of any parameters. Returns the gradient of the input `X`. Gradients
may be arrays or tuples of arrays (for multiple inputs/outputs).
"""
back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
"""
update!(model, η) => m
Update the parameters of the model `m` using the accumulated gradients from
`back!`, using the learning rate `η`.
"""
update!(m, η) = m
""" """
graph(model) => ::IVertex{Any} | nothing graph(model) => ::IVertex{Any} | nothing

View File

@@ -7,18 +7,6 @@ end
@forward Chain.layers Base.start, Base.next, Base.done @forward Chain.layers Base.start, Base.next, Base.done
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
function back!(s::Chain, Δ, x)
crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
push!(crumbs, layer(crumbs[end]))
end
foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
x, layer = pack
back!(layer, Δ, x)
end
end
graph(s::Chain) = graph(s::Chain) =
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers) foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)

View File

@@ -25,24 +25,6 @@ function test_recurrence(bk)
end end
end end
function test_back(bk)
@testset "Backward Pass" begin
xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)
dm = bk(d)
d′ = deepcopy(d)
@test dm(xs) ≈ d(xs)
@test dm(xs) ≈ d′(xs)
Δ = back!(dm, randn(1, 10), xs)
@test length(Δ[1]) == 20
update!(dm, 0.1)
@test dm(xs) ≈ d(xs)
@test !(dm(xs) ≈ d′(xs))
end
end
function test_stacktrace(bk) function test_stacktrace(bk)
@testset "Stack Traces" begin @testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15)) model = TLP(Affine(10, 20), Affine(21, 15))

View File

@@ -1,5 +1,5 @@
using Flux, DataFlow, MacroTools, Base.Test using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, squeeze, unsqueeze, stack, back!, update!, flatten using Flux: graph, Param, squeeze, unsqueeze, stack, update!, flatten
using DataFlow: Line, Frame using DataFlow: Line, Frame
@testset "Flux" begin @testset "Flux" begin