Updates for julia 1.0

This commit is contained in:
Josh Christie 2018-08-11 10:51:07 +01:00
parent 62d594af43
commit 5186e3ba18
6 changed files with 16 additions and 11 deletions

View File

@ -21,8 +21,8 @@ struct Chain
Chain(xs...) = new([xs...])
end
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
@forward Chain.layers Base.start, Base.next, Base.done
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
@forward Chain.layers Base.iterate
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)

View File

@ -70,7 +70,7 @@ struct Params
Params(xs) = new(IdSet(xs))
end
@forward Params.params Base.start, Base.next, Base.done
@forward Params.params Base.iterate
function Base.show(io::IO, ps::Params)
print(io, "Params([")
@ -86,6 +86,8 @@ Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(IdDict())
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
@ -94,7 +96,6 @@ function Base.getindex(g::Grads, x)
g[tracker(x)]
end
@forward Grads.grads Base.setindex!, Base.haskey
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

View File

@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()
@forward IdSet.dict Base.length
Base.start(s::IdSet) = start(keys(s.dict))
Base.next(s::IdSet, st) = next(keys(s.dict), st)
Base.done(s::IdSet, st) = done(keys(s.dict), st)
function iterate(v::IdSet, state...)
y = iterate(keys(v.dict), state...)
y === nothing && return nothing
return (y[1], y[2])
end

View File

@ -1,6 +1,7 @@
using Flux, Test, Random
using Random
srand(0)
Random.seed!(0)
@testset "Flux" begin

View File

@ -5,13 +5,13 @@ using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "Tracker" begin
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)

View File

@ -1,6 +1,7 @@
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
using StatsBase: std
using Dates
using Random
@testset "Throttle" begin
@testset "default behaviour" begin
@ -61,7 +62,7 @@ end
@testset "Initialization" begin
# Set random seed so that these tests don't fail randomly
srand(0)
Random.seed!(0)
# initn() should yield a kernel with stddev ~= 1e-2
v = initn(10, 10)
@test std(v) > 0.9*1e-2