use @ fix in a few places

This commit is contained in:
Mike J Innes 2018-03-01 16:31:20 +00:00
parent 5e84d52ee7
commit ac57fc3c26
4 changed files with 6 additions and 3 deletions

View File

@ -13,6 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
param, params, mapleaves, cpu, gpu
@reexport using NNlib
using NNlib: @fix
include("tracker/Tracker.jl")
using .Tracker

View File

@ -75,7 +75,7 @@ treelike(Dense)
function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ
σ.(W*x .+ b)
@fix σ.(W*x .+ b)
end
function Base.show(io::IO, l::Dense)

View File

@ -1,11 +1,11 @@
using NNlib: log_fast, logsoftmax, logσ
using NNlib: logsoftmax, logσ
# Cost functions
mse(, y) = sum(( .- y).^2)/length(y)
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* log_fast.() .* weight) / size(y, 2)
return @fix -sum(y .* log.() .* weight) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)

View File

@ -49,8 +49,10 @@ include("numeric.jl")
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
import NNlib.cudata
import Adapt.adapt
cudata(x::TrackedArray) = data(x)
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end