diff --git a/REQUIRE b/REQUIRE index df9c6322..7164de5a 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,4 +1,4 @@ -julia 0.7- +julia 0.7 Juno MacroTools 0.3.3 NNlib diff --git a/src/layers/basic.jl b/src/layers/basic.jl index d461c95c..f7344484 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -21,8 +21,8 @@ struct Chain Chain(xs...) = new([xs...]) end -@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push! -@forward Chain.layers Base.start, Base.next, Base.done +@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push! +@forward Chain.layers Base.iterate children(c::Chain) = c.layers mapchildren(f, c::Chain) = Chain(f.(c.layers)...) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 04f5c231..774123b4 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -70,7 +70,7 @@ struct Params Params(xs) = new(IdSet(xs)) end -@forward Params.params Base.start, Base.next, Base.done +@forward Params.params Base.iterate, Base.length function Base.show(io::IO, ps::Params) print(io, "Params([") @@ -86,6 +86,8 @@ Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") Grads() = Grads(IdDict()) +@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate + Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps)) Base.getindex(g::Grads, x::Tracked) = g.grads[x] @@ -94,7 +96,6 @@ function Base.getindex(g::Grads, x) g[tracker(x)] end -@forward Grads.grads Base.setindex!, Base.haskey accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl index 1bbfec09..62d5190e 100644 --- a/src/tracker/idset.jl +++ b/src/tracker/idset.jl @@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}() @forward IdSet.dict Base.length -Base.start(s::IdSet) = start(keys(s.dict)) -Base.next(s::IdSet, st) = next(keys(s.dict), st) -Base.done(s::IdSet, st) = done(keys(s.dict), st) +function Base.iterate(v::IdSet, state...) + y = Base.iterate(keys(v.dict), state...) + y === nothing && return nothing + return (y[1], y[2]) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6d698784..fcda4e82 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Flux, Test, Random +using Random -srand(0) +Random.seed!(0) @testset "Flux" begin diff --git a/test/tracker.jl b/test/tracker.jl index d504f0a4..768cc4f7 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -5,13 +5,13 @@ using NNlib: conv using Printf: @sprintf using LinearAlgebra: diagm, dot, LowerTriangular, norm using Statistics: mean, std +using Random # using StatsBase gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) -gradtest(f, dims...) = gradtest(f, rand.(dims)...) +gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @testset "Tracker" begin - @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2) diff --git a/test/utils.jl b/test/utils.jl index 6fb28e31..5e1b0ef0 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,6 +1,7 @@ using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian using StatsBase: std using Dates +using Random @testset "Throttle" begin @testset "default behaviour" begin @@ -61,7 +62,7 @@ end @testset "Initialization" begin # Set random seed so that these tests don't fail randomly - srand(0) + Random.seed!(0) # initn() should yield a kernel with stddev ~= 1e-2 v = initn(10, 10) @test std(v) > 0.9*1e-2