diff --git a/examples/MNIST.jl b/examples/MNIST.jl
index 88970f0f..e768e166 100644
--- a/examples/MNIST.jl
+++ b/examples/MNIST.jl
@@ -1,5 +1,5 @@
 using Flux, MNIST
-using Flux: accuracy
+using Flux: accuracy, onehot
 
 data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
 train = data[1:50_000]
diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl
index 63c4f4e2..b34e23a7 100644
--- a/src/backend/tensorflow/graph.jl
+++ b/src/backend/tensorflow/graph.jl
@@ -28,10 +28,13 @@
 graph(::typeof(all), x, dim=nothing) = TensorFlow.reduce_all(x;axis=dim)
 graph(::typeof(any), x, dim=nothing) = TensorFlow.reduce_any(x;axis=dim)
 graph(::typeof(mean), x, dim=nothing) = TensorFlow.reduce_mean(x;axis=dim)
 graph(::typeof(svd), x) = svd(x)
+graph(::typeof(size), x, dim) = TensorFlow.size(x,convert(Tensor{Int32}, dim))
+graph(::typeof(size), x) = TensorFlow.size(x)
+graph(::typeof(chol), args...) = TensorFlow.transpose(TensorFlow.cholesky(args...))
 
 for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos,
            sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj,
-           inv, det)
+           inv, det, transpose, permutedims, cat, length, diag, diagm)
   @eval graph(::typeof($op), args...) = $op(args...)
 end
diff --git a/src/core.jl b/src/core.jl
index d3953849..66e33440 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -6,11 +6,11 @@ module FluxCore
 """
     back!(model, ΔY, X...) => ΔX
 
-Backpropagate the gradient `ΔY` through the model `m`, accumulating the
+Backpropagate the gradient `ΔY` through the model `model`, accumulating the
 gradients of any parameters. Returns the gradient of the input `X`. Gradients
 may be arrays or tuples of arrays (for multiple inputs/outputs).
 """
-back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
+back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
 
 """
     update!(model, η) => m
diff --git a/src/layers/affine.jl b/src/layers/affine.jl
index 9608efcc..ca79c004 100644
--- a/src/layers/affine.jl
+++ b/src/layers/affine.jl
@@ -9,3 +9,16 @@ Affine(in::Integer, out::Integer; init = initn) =
 
 inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
   Affine(in[1][2], out)
+
+function back!(m::Affine, Δ, x)
+  W, b = m.W, m.b
+  W.Δx[:] = x' * Δ
+  b.Δx[:] = sum(Δ, 1)
+  Δ * W.x'
+end
+
+function update!(m::Affine, η)
+  update!(m.W, η)
+  update!(m.b, η)
+  m
+end
diff --git a/src/layers/control.jl b/src/layers/control.jl
index 7851f902..d0c5e61b 100644
--- a/src/layers/control.jl
+++ b/src/layers/control.jl
@@ -7,9 +7,19 @@ end
 @forward Chain.layers Base.start, Base.next, Base.done
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
 update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
 
+function back!(s::Chain, Δ, x)
+  crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
+    push!(crumbs, layer(crumbs[end]))
+  end
+
+  foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
+    x, layer = pack
+    back!(layer, Δ, x)
+  end
+end
+
 graph(s::Chain) = foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl
index 178de761..1dcfdf53 100644
--- a/test/backend/tensorflow.jl
+++ b/test/backend/tensorflow.jl
@@ -29,6 +29,24 @@
   @test A ≈ u*diagm(s)*transpose(v)
   @test tf(@net x -> inv(x))(A) ≈ inv(A)
   @test tf(@net x -> det(x))(A) ≈ det(A)
+  A = randn(Float32,(6,3))
+  @test tf(@net x -> transpose(x))(A) ≈ transpose(A)
+  A = randn(Float32,(6,3,2))
+  @test tf(@net (x,y) -> permutedims(x,y))(A,[3,2,1]) ≈ permutedims(A,[3,2,1])
+  A1 = randn(Float32,(4,1))
+  A2 = randn(Float32,(4,1))
+  @test tf(@net (x,y) -> cat(2,x,y))(A1,A2) ≈ cat(2,A1,A2)
+  @test tf(@net x -> length(x))(A1) == length(A1)
+  A = randn(Float32,(5,5))
+  @test tf(@net x -> diag(x))(A) ≈ diag(A)
+  A = randn(Float32,(5,))
+  @test tf(@net x -> diagm(x))(A) ≈ diagm(A)
+  A = randn(4,5)
+  @test tf(@net x -> size(x))(A) == [4,5]
+  @test tf(@net (x,y) -> size(x,y))(A,1) == 4
+  A = randn(6,5)
+  A = A'*A
+  @test tf(@net x -> chol(x))(A) ≈ chol(A)
 end
 
 end
diff --git a/test/optimizer.jl b/test/optimizer.jl
new file mode 100644
index 00000000..57f1d011
--- /dev/null
+++ b/test/optimizer.jl
@@ -0,0 +1,38 @@
+@testset "training julia models" begin
+
+  @testset "linear regression" begin
+    srand(0)
+
+    model = Affine(10, 1)
+
+    truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'
+
+    data = map(1:256) do i
+      x = rand(Float32, 10)
+      x, truth * x + 3rand(Float32)
+    end
+
+    Flux.train!(model, data, epoch=5)
+
+    @test cor(reshape.((model.W.x, truth), 10)...) > .99
+  end
+
+  @testset "logistic regression" begin
+    srand(0)
+
+    model = Chain(Affine(10, 1), σ)
+
+    truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'
+
+    data = map(1:256) do i
+      x = rand(Float32, 10)
+      x, truth * x + 2rand(Float32) > 5f0
+    end
+
+    Flux.train!(model, data, epoch=10)
+
+    @test cor(reshape.((model.layers[1].W.x, truth), 10)...) > .99
+  end
+
+end
+
diff --git a/test/runtests.jl b/test/runtests.jl
index 8dd1dd8e..1e4981f7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,5 +15,7 @@ include("backend/common.jl")
 
 include("basic.jl")
 include("recurrent.jl")
+include("optimizer.jl")
+
 @tfonly include("backend/tensorflow.jl")
 @mxonly include("backend/mxnet.jl")
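Illustrative note (not part of the diff): a minimal sketch of how the new back!/update! methods on Affine might be exercised by hand, outside of Flux.train!. It assumes Affine's parameters are containers exposing `x` (value) and `Δx` (accumulated gradient) fields, as the implementation above suggests, and uses a simple squared-error gradient purely for illustration.

    using Flux

    m = Affine(10, 1)
    x = rand(Float32, 1, 10)      # a single sample as a row vector
    y = rand(Float32, 1, 1)       # target

    ŷ = m(x)                      # forward pass
    Δ = ŷ .- y                    # gradient of ½‖ŷ - y‖² with respect to ŷ

    Flux.back!(m, Δ, x)           # accumulates m.W.Δx and m.b.Δx, returns the gradient w.r.t. x
    Flux.update!(m, 0.1)          # apply one gradient step with learning rate η = 0.1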