address comments

This commit is contained in:
CarloLucibello 2017-10-25 02:35:27 +02:00
parent 536ab3861d
commit 711ea09d99
2 changed files with 12 additions and 7 deletions

View File

@ -121,4 +121,4 @@ end
Set model `m` in test mode if `val=true`, and in training mode otherwise.
This has an affect only if `m` contains [`Dropout`](@ref) or `BatchNorm` layers.
"""
testmode!(m, val::Bool=true) = prefor(x -> x isa Dropout && (x.testmode = val), m)
testmode!(m, val::Bool=true) = prefor(x -> :testmode fieldnames(x) && (x.testmode = val), m)

View File

@ -55,12 +55,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
==(x::TrackedArray, y) = data(x) == y
==(y, x::TrackedArray) = y == data(x)
==(x::TrackedScalar, y) = data(x)[] == y
==(y, x::TrackedScalar) = y == data(x)[]
<(x::TrackedScalar, y) = data(x)[] < y
<(x, y::TrackedScalar) = x < data(y)[]
#to be merged with data in the future
unbox(x::TrackedArray) = data(x)
unbox(x::TrackedScalar) = data(x)[]
==(x::TrackedArray, y) = unbox(x) == y
==(y, x::TrackedArray) = y == unbox(x)
==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x)
<(x::TrackedScalar, y) = unbox(x) < y
<(x, y::TrackedScalar) = x < unbox(y)
<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y)
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")