add == and < for tracked arrays
This commit is contained in:
parent
2e1ed4c3fc
commit
86c7c9246e
@ -115,6 +115,13 @@ function (a::Dropout)(x)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
setmode!(m, mode::Symbol)
|
||||||
|
|
||||||
|
Change the mode of model `m` to `mode`. Possible values for `mode` are
|
||||||
|
`:train` and `:eval`.
|
||||||
|
This has an affect only if `m` contains [`Dropout`](@ref) of `BatchNorm` layers.
|
||||||
|
"""
|
||||||
setmode!(a, mode::Symbol) = nothing
|
setmode!(a, mode::Symbol) = nothing
|
||||||
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
|
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
|
||||||
setmode!(a::Dropout, mode::Symbol) = a.mode = mode
|
setmode!(a::Dropout, mode::Symbol) = a.mode = mode
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
import Base: <, ==
|
||||||
export TrackedArray, param, back!
|
export TrackedArray, param, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
@ -54,6 +54,13 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|||||||
|
|
||||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||||
|
|
||||||
|
==(x::TrackedArray, y) = data(x) == y
|
||||||
|
==(y, x::TrackedArray) = y == data(x)
|
||||||
|
==(x::TrackedScalar, y) = data(x)[] == y
|
||||||
|
==(y, x::TrackedScalar) = y == data(x)[]
|
||||||
|
<(x::TrackedScalar, y) = data(x)[] < y
|
||||||
|
<(x, y::TrackedScalar) = x < data(y)[]
|
||||||
|
|
||||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||||
print(io, "TrackedArray{…,$A}")
|
print(io, "TrackedArray{…,$A}")
|
||||||
|
|
||||||
|
@ -16,8 +16,8 @@
|
|||||||
m = Chain(Dense(100,100),
|
m = Chain(Dense(100,100),
|
||||||
Dropout(0.9))
|
Dropout(0.9))
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a.data[] == 0, y) > 50
|
@test count(a->a == 0, y) > 50
|
||||||
setmode!(m, :eval)
|
setmode!(m, :eval)
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a.data[] == 0, y) == 0
|
@test count(a->a == 0, y) == 0
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user