add == and < for tracked arrays
This commit is contained in:
parent
2e1ed4c3fc
commit
86c7c9246e
@ -115,6 +115,13 @@ function (a::Dropout)(x)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
setmode!(m, mode::Symbol)
|
||||
|
||||
Change the mode of model `m` to `mode`. Possible values for `mode` are
|
||||
`:train` and `:eval`.
|
||||
This has an affect only if `m` contains [`Dropout`](@ref) of `BatchNorm` layers.
|
||||
"""
|
||||
setmode!(a, mode::Symbol) = nothing
|
||||
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
|
||||
setmode!(a::Dropout, mode::Symbol) = a.mode = mode
|
||||
|
@ -1,5 +1,5 @@
|
||||
module Tracker
|
||||
|
||||
import Base: <, ==
|
||||
export TrackedArray, param, back!
|
||||
|
||||
data(x) = x
|
||||
@ -54,6 +54,13 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
==(x::TrackedArray, y) = data(x) == y
|
||||
==(y, x::TrackedArray) = y == data(x)
|
||||
==(x::TrackedScalar, y) = data(x)[] == y
|
||||
==(y, x::TrackedScalar) = y == data(x)[]
|
||||
<(x::TrackedScalar, y) = data(x)[] < y
|
||||
<(x, y::TrackedScalar) = x < data(y)[]
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
|
@ -16,8 +16,8 @@
|
||||
m = Chain(Dense(100,100),
|
||||
Dropout(0.9))
|
||||
y = m(x)
|
||||
@test count(a->a.data[] == 0, y) > 50
|
||||
@test count(a->a == 0, y) > 50
|
||||
setmode!(m, :eval)
|
||||
y = m(x)
|
||||
@test count(a->a.data[] == 0, y) == 0
|
||||
@test count(a->a == 0, y) == 0
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user