Fix Optimizers
This commit is contained in:
parent
355091b9d1
commit
5db7a3a3ad
|
@ -9,7 +9,7 @@ struct Param{T}
|
|||
Δ::T
|
||||
end
|
||||
|
||||
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))
|
||||
Param(x::AbstractArray) = Param(x, zero(x))
|
||||
|
||||
include("optimisers.jl")
|
||||
include("interface.jl")
|
||||
|
@ -17,6 +17,7 @@ include("train.jl")
|
|||
|
||||
using Flux.Tracker: TrackedArray
|
||||
|
||||
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
||||
Param(x::TrackedArray) = Param(x.data, x.grad)
|
||||
# Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
||||
|
||||
end
|
||||
|
|
|
@ -20,13 +20,8 @@ Base.similar(s::IdSet, T::Type) = IdSet{T}()
|
|||
|
||||
@forward IdSet.dict Base.length
|
||||
|
||||
<<<<<<< HEAD
|
||||
function iterate(v::IdSet, state...)
|
||||
y = iterate(keys(v.dict), state...)
|
||||
=======
|
||||
function Base.iterate(v::IdSet, state...)
|
||||
y = Base.iterate(keys(v.dict), state...)
|
||||
>>>>>>> 837e03613f98ff9b949815018cba02a3682dab3c
|
||||
y === nothing && return nothing
|
||||
return (y[1], y[2])
|
||||
end
|
||||
end
|
||||
|
|
|
@ -5,21 +5,15 @@ Random.seed!(0)
|
|||
|
||||
@testset "Flux" begin
|
||||
|
||||
println("Testing")
|
||||
include("utils.jl")
|
||||
# println("Testing")
|
||||
# include("tracker.jl")
|
||||
println("Testing")
|
||||
include("tracker.jl")
|
||||
include("layers/normalisation.jl")
|
||||
println("Testing")
|
||||
include("layers/stateless.jl")
|
||||
println("Testing")
|
||||
include("optimise.jl")
|
||||
println("Testing")
|
||||
include("data.jl")
|
||||
|
||||
# if Base.find_in_path("CuArrays") ≠ nothing
|
||||
# include("cuda/cuda.jl")
|
||||
# end
|
||||
if Base.find_in_path("CuArrays") ≠ nothing
|
||||
include("cuda/cuda.jl")
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue