remove back!, update!
commit f8482ff80c
parent 21089fea9c
@@ -17,7 +17,7 @@ export @net, unroll, unroll1, @shapes,
 # Zero Flux Given

 include("core.jl")
-import .FluxCore: back!, update!, graph
+import .FluxCore: graph

 include("utils.jl")
 include("params.jl")
@@ -66,7 +66,6 @@ function process_type(ex)
   quote
     $(build_type(T, params))
     $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
-    $(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
     $(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
     nothing
   end
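Note: the removed line above is what made the `@net` macro emit a parameter-update method for each generated type. As a rough sketch of the expansion it produced (the type `Affine` and its fields `W` and `b` are hypothetical here, not the macro's literal output):

    # Sketch: for pnames == [:W, :b], the deleted quote line generated
    # one update! call per parameter field, roughly:
    Flux.update!(self::Affine, η) = (update!(self.W, η); update!(self.b, η))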
@@ -22,12 +22,6 @@ function (m::Stateful)(xs...)
   return y
 end

-function back!(m::Stateful, Δ, x)
-  back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
-end
-
-update!(m::Stateful, η) = update!(m.model, η)
-
 # Seq Models

 struct SeqModel
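Note: the removed `back!(m::Stateful, ...)` relied on the convention that the wrapped model's backward pass takes the saved state alongside the input and returns gradients in the same shape. An annotated restatement of what it did (a sketch of the old behaviour, not new API):

    # Seed the output-state slots with zero gradients, backprop through the
    # wrapped model together with the saved input state, then drop the state
    # gradients, keeping only the gradient w.r.t. the data input x:
    Δstate_and_x = back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)
    Δx = Δstate_and_x[2:end]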
@@ -52,14 +46,6 @@ function (m::SeqModel)(xs...)
   reseq(m.model(xs...))
 end

-function back!(m::SeqModel, args...)
-  args = seqtuple(args, 0)
-  # TODO: reseq
-  back!(m.model, args...)
-end
-
-update!(m::SeqModel, η) = update!(m.model, η)
-
 graph(m::SeqModel) = graph(m.model)

 # Recurrent Graphs
src/core.jl (17 lines removed)
@@ -3,23 +3,6 @@

 module FluxCore

-"""
-    back!(model, ΔY, X...) => ΔX
-
-Backpropagate the gradient `ΔY` through the model `model`, accumulating the
-gradients of any parameters. Returns the gradient of the input `X`. Gradients
-may be arrays or tuples of arrays (for multiple inputs/outputs).
-"""
-back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
-
-"""
-    update!(model, η) => m
-
-Update the parameters of the model `m` using the accumulated gradients from
-`back!`, using the learning rate `η`.
-"""
-update!(m, η) = m
-
 """
     graph(model) => ::IVertex{Any} | nothing

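Note: the two removed docstrings were the whole of the old training contract. A minimal, self-contained sketch of a layer satisfying it (illustrative only: `MyAffine` and its gradient fields are hypothetical, and the column-vector convention is an assumption):

    # y = W*x .+ b, with gradients accumulated by back! and applied by update!.
    struct MyAffine
      W::Matrix{Float64}
      b::Vector{Float64}
      ΔW::Matrix{Float64}   # gradient accumulators
      Δb::Vector{Float64}
    end

    MyAffine(in, out) = MyAffine(randn(out, in), zeros(out), zeros(out, in), zeros(out))

    (a::MyAffine)(x) = a.W * x .+ a.b

    # back!(model, ΔY, X...) => ΔX: accumulate parameter gradients, return ΔX.
    function back!(a::MyAffine, Δy, x)
      a.ΔW .+= Δy * x'      # outer product: ∂L/∂W
      a.Δb .+= Δy           # ∂L/∂b
      return a.W' * Δy      # ∂L/∂x, to be passed on to earlier layers
    end

    # update!(model, η) => m: take a gradient step, then clear the accumulators.
    function update!(a::MyAffine, η)
      a.W .-= η .* a.ΔW
      a.b .-= η .* a.Δb
      a.ΔW .= 0
      a.Δb .= 0
      return a
    end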
@@ -7,18 +7,6 @@ end
 @forward Chain.layers Base.start, Base.next, Base.done

 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
-update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
-
-function back!(s::Chain, Δ, x)
-  crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
-    push!(crumbs, layer(crumbs[end]))
-  end
-
-  foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
-    x, layer = pack
-    back!(layer, Δ, x)
-  end
-end
-
 graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
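Note: the removed `back!(s::Chain, ...)` recorded each layer's input on the forward pass (the "crumbs"), then folded the gradient back through the layers in reverse. The same idea as a standalone sketch (the name `chain_back!` is hypothetical; it assumes each layer implements the `back!(layer, Δ, x)` contract sketched above):

    # Backprop through a sequence of layers by replaying the forward pass.
    function chain_back!(layers, Δ, x)
      inputs = [x]                      # inputs[i] is the input layers[i] saw
      for l in layers[1:end-1]
        push!(inputs, l(inputs[end]))
      end
      for i in length(layers):-1:1      # right-to-left gradient fold
        Δ = back!(layers[i], Δ, inputs[i])
      end
      return Δ                          # gradient w.r.t. the original input x
    end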
@@ -25,24 +25,6 @@ function test_recurrence(bk)
   end
 end

-function test_back(bk)
-  @testset "Backward Pass" begin
-    xs, ys = rand(1, 20), rand(1, 20)
-    d = Affine(20, 10)
-    dm = bk(d)
-    d′ = deepcopy(d)
-    @test dm(xs) ≈ d(xs)
-    @test dm(xs) ≈ d′(xs)
-
-    Δ = back!(dm, randn(1, 10), xs)
-    @test length(Δ[1]) == 20
-    update!(dm, 0.1)
-
-    @test dm(xs) ≈ d(xs)
-    @test !(dm(xs) ≈ d′(xs))
-  end
-end
-
 function test_stacktrace(bk)
   @testset "Stack Traces" begin
     model = TLP(Affine(10, 20), Affine(21, 15))
@@ -1,5 +1,5 @@
 using Flux, DataFlow, MacroTools, Base.Test
-using Flux: graph, Param, squeeze, unsqueeze, stack, back!, update!, flatten
+using Flux: graph, Param, squeeze, unsqueeze, stack, update!, flatten
 using DataFlow: Line, Frame

 @testset "Flux" begin