From ac57fc3c26595daac468e4b849e6d7293b4cb94a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 1 Mar 2018 16:31:20 +0000 Subject: [PATCH] use @ fix in a few places --- src/Flux.jl | 1 + src/layers/basic.jl | 2 +- src/layers/stateless.jl | 4 ++-- src/tracker/Tracker.jl | 2 ++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 110fd966..47dd5be3 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -13,6 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D, param, params, mapleaves, cpu, gpu @reexport using NNlib +using NNlib: @fix include("tracker/Tracker.jl") using .Tracker diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f93e6818..ad374643 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -75,7 +75,7 @@ treelike(Dense) function (a::Dense)(x) W, b, σ = a.W, a.b, a.σ - σ.(W*x .+ b) + @fix σ.(W*x .+ b) end function Base.show(io::IO, l::Dense) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 34683fbf..798e7b33 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,11 +1,11 @@ -using NNlib: log_fast, logsoftmax, logσ +using NNlib: logsoftmax, logσ # Cost functions mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) - return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2) + return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2) end @deprecate logloss(x, y) crossentropy(x, y) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 1c467e1e..b00e97db 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -49,8 +49,10 @@ include("numeric.jl") param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) +import NNlib.cudata import Adapt.adapt +cudata(x::TrackedArray) = data(x) adapt(T, xs::TrackedArray) = param(adapt(T, data(xs))) end