add == and < for tracked arrays

This commit is contained in:
CarloLucibello 2017-10-23 11:41:08 +02:00
parent 2e1ed4c3fc
commit 86c7c9246e
3 changed files with 17 additions and 3 deletions

View File

@ -115,6 +115,13 @@ function (a::Dropout)(x)
end
end
"""
setmode!(m, mode::Symbol)
Change the mode of model `m` to `mode`. Possible values for `mode` are
`:train` and `:eval`.
This has an affect only if `m` contains [`Dropout`](@ref) of `BatchNorm` layers.
"""
setmode!(a, mode::Symbol) = nothing
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
setmode!(a::Dropout, mode::Symbol) = a.mode = mode

View File

@ -1,5 +1,5 @@
module Tracker
import Base: <, ==
export TrackedArray, param, back!
data(x) = x
@ -54,6 +54,13 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
==(x::TrackedArray, y) = data(x) == y
==(y, x::TrackedArray) = y == data(x)
==(x::TrackedScalar, y) = data(x)[] == y
==(y, x::TrackedScalar) = y == data(x)[]
<(x::TrackedScalar, y) = data(x)[] < y
<(x, y::TrackedScalar) = x < data(y)[]
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")

View File

@ -16,8 +16,8 @@
m = Chain(Dense(100,100),
Dropout(0.9))
y = m(x)
@test count(a->a.data[] == 0, y) > 50
@test count(a->a == 0, y) > 50
setmode!(m, :eval)
y = m(x)
@test count(a->a.data[] == 0, y) == 0
@test count(a->a == 0, y) == 0
end