use @ fix in a few places
This commit is contained in:
parent
5e84d52ee7
commit
ac57fc3c26
@ -13,6 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
|
||||
param, params, mapleaves, cpu, gpu
|
||||
|
||||
@reexport using NNlib
|
||||
using NNlib: @fix
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
|
@ -75,7 +75,7 @@ treelike(Dense)
|
||||
|
||||
function (a::Dense)(x)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
σ.(W*x .+ b)
|
||||
@fix σ.(W*x .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Dense)
|
||||
|
@ -1,11 +1,11 @@
|
||||
using NNlib: log_fast, logsoftmax, logσ
|
||||
using NNlib: logsoftmax, logσ
|
||||
|
||||
# Cost functions
|
||||
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
|
||||
return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
end
|
||||
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
@ -49,8 +49,10 @@ include("numeric.jl")
|
||||
param(x::Number) = TrackedReal(float(x))
|
||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||
|
||||
import NNlib.cudata
|
||||
import Adapt.adapt
|
||||
|
||||
cudata(x::TrackedArray) = data(x)
|
||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user