From 711ea09d99cc3cc8daf39b172c5a5be065f13d7f Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 25 Oct 2017 02:35:27 +0200 Subject: [PATCH] address comments --- src/layers/basic.jl | 2 +- src/tracker/Tracker.jl | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 841cf094..c15868ab 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -121,4 +121,4 @@ end Set model `m` in test mode if `val=true`, and in training mode otherwise. This has an affect only if `m` contains [`Dropout`](@ref) or `BatchNorm` layers. """ -testmode!(m, val::Bool=true) = prefor(x -> x isa Dropout && (x.testmode = val), m) +testmode!(m, val::Bool=true) = prefor(x -> :testmode ∈ fieldnames(x) && (x.testmode = val), m) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 1ab92f7e..90707ea5 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -55,12 +55,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -==(x::TrackedArray, y) = data(x) == y -==(y, x::TrackedArray) = y == data(x) -==(x::TrackedScalar, y) = data(x)[] == y -==(y, x::TrackedScalar) = y == data(x)[] -<(x::TrackedScalar, y) = data(x)[] < y -<(x, y::TrackedScalar) = x < data(y)[] +#to be merged with data in the future +unbox(x::TrackedArray) = data(x) +unbox(x::TrackedScalar) = data(x)[] + +==(x::TrackedArray, y) = unbox(x) == y +==(y, x::TrackedArray) = y == unbox(x) +==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x) + +<(x::TrackedScalar, y) = unbox(x) < y +<(x, y::TrackedScalar) = x < unbox(y) +<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y) Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}")