diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 088cf1e1..0c7e1fd0 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -115,6 +115,13 @@ function (a::Dropout)(x) end end +""" + setmode!(m, mode::Symbol) + +Change the mode of model `m` to `mode`. Possible values for `mode` are +`:train` and `:eval`. +This has an affect only if `m` contains [`Dropout`](@ref) of `BatchNorm` layers. +""" setmode!(a, mode::Symbol) = nothing setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c) setmode!(a::Dropout, mode::Symbol) = a.mode = mode diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index a2a6c745..8f495f82 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,5 +1,5 @@ module Tracker - +import Base: <, == export TrackedArray, param, back! data(x) = x @@ -54,6 +54,13 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) +==(x::TrackedArray, y) = data(x) == y +==(y, x::TrackedArray) = y == data(x) +==(x::TrackedScalar, y) = data(x)[] == y +==(y, x::TrackedScalar) = y == data(x)[] +<(x::TrackedScalar, y) = data(x)[] < y +<(x, y::TrackedScalar) = x < data(y)[] + Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}") diff --git a/test/layers.jl b/test/layers.jl index ead9c343..d0a5cbe1 100644 --- a/test/layers.jl +++ b/test/layers.jl @@ -16,8 +16,8 @@ m = Chain(Dense(100,100), Dropout(0.9)) y = m(x) - @test count(a->a.data[] == 0, y) > 50 + @test count(a->a == 0, y) > 50 setmode!(m, :eval) y = m(x) - @test count(a->a.data[] == 0, y) == 0 + @test count(a->a == 0, y) == 0 end