address comments
This commit is contained in:
parent
536ab3861d
commit
711ea09d99
@ -121,4 +121,4 @@ end
|
||||
Set model `m` in test mode if `val=true`, and in training mode otherwise.
|
||||
This has an affect only if `m` contains [`Dropout`](@ref) or `BatchNorm` layers.
|
||||
"""
|
||||
testmode!(m, val::Bool=true) = prefor(x -> x isa Dropout && (x.testmode = val), m)
|
||||
testmode!(m, val::Bool=true) = prefor(x -> :testmode ∈ fieldnames(x) && (x.testmode = val), m)
|
||||
|
@ -55,12 +55,17 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
==(x::TrackedArray, y) = data(x) == y
|
||||
==(y, x::TrackedArray) = y == data(x)
|
||||
==(x::TrackedScalar, y) = data(x)[] == y
|
||||
==(y, x::TrackedScalar) = y == data(x)[]
|
||||
<(x::TrackedScalar, y) = data(x)[] < y
|
||||
<(x, y::TrackedScalar) = x < data(y)[]
|
||||
#to be merged with data in the future
|
||||
unbox(x::TrackedArray) = data(x)
|
||||
unbox(x::TrackedScalar) = data(x)[]
|
||||
|
||||
==(x::TrackedArray, y) = unbox(x) == y
|
||||
==(y, x::TrackedArray) = y == unbox(x)
|
||||
==(x::TrackedArray, y::TrackedArray) = unbox(x) == unbox(x)
|
||||
|
||||
<(x::TrackedScalar, y) = unbox(x) < y
|
||||
<(x, y::TrackedScalar) = x < unbox(y)
|
||||
<(x::TrackedScalar, y::TrackedScalar) = unbox(x) < unbox(y)
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
Loading…
Reference in New Issue
Block a user