Updates for julia 1.0
This commit is contained in:
parent
62d594af43
commit
5186e3ba18
|
@ -21,8 +21,8 @@ struct Chain
|
|||
Chain(xs...) = new([xs...])
|
||||
end
|
||||
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward Chain.layers Base.start, Base.next, Base.done
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
|
||||
@forward Chain.layers Base.iterate
|
||||
|
||||
children(c::Chain) = c.layers
|
||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
|
|
|
@ -70,7 +70,7 @@ struct Params
|
|||
Params(xs) = new(IdSet(xs))
|
||||
end
|
||||
|
||||
@forward Params.params Base.start, Base.next, Base.done
|
||||
@forward Params.params Base.iterate
|
||||
|
||||
function Base.show(io::IO, ps::Params)
|
||||
print(io, "Params([")
|
||||
|
@ -86,6 +86,8 @@ Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
|
|||
|
||||
Grads() = Grads(IdDict())
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey, Base.length, Base.iterate
|
||||
|
||||
Grads(ps::Params) = Grads(IdDict(tracker(p) => init_grad(data(p)) for p in ps))
|
||||
|
||||
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
|
||||
|
@ -94,7 +96,6 @@ function Base.getindex(g::Grads, x)
|
|||
g[tracker(x)]
|
||||
end
|
||||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
|||
|
||||
@forward IdSet.dict Base.length
|
||||
|
||||
Base.start(s::IdSet) = start(keys(s.dict))
|
||||
Base.next(s::IdSet, st) = next(keys(s.dict), st)
|
||||
Base.done(s::IdSet, st) = done(keys(s.dict), st)
|
||||
function iterate(v::IdSet, state...)
|
||||
y = iterate(keys(v.dict), state...)
|
||||
y === nothing && return nothing
|
||||
return (y[1], y[2])
|
||||
end
|
|
@ -1,6 +1,7 @@
|
|||
using Flux, Test, Random
|
||||
using Random
|
||||
|
||||
srand(0)
|
||||
Random.seed!(0)
|
||||
|
||||
@testset "Flux" begin
|
||||
|
||||
|
|
|
@ -5,13 +5,13 @@ using NNlib: conv
|
|||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
using Statistics: mean, std
|
||||
using Random
|
||||
# using StatsBase
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
||||
|
||||
@testset "Tracker" begin
|
||||
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
||||
using StatsBase: std
|
||||
using Dates
|
||||
using Random
|
||||
|
||||
@testset "Throttle" begin
|
||||
@testset "default behaviour" begin
|
||||
|
@ -61,7 +62,7 @@ end
|
|||
|
||||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
srand(0)
|
||||
Random.seed!(0)
|
||||
# initn() should yield a kernel with stddev ~= 1e-2
|
||||
v = initn(10, 10)
|
||||
@test std(v) > 0.9*1e-2
|
||||
|
|
Loading…
Reference in New Issue