use @ fix in a few places
This commit is contained in:
parent
5e84d52ee7
commit
ac57fc3c26
@ -13,6 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv, Conv2D,
|
|||||||
param, params, mapleaves, cpu, gpu
|
param, params, mapleaves, cpu, gpu
|
||||||
|
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
|
using NNlib: @fix
|
||||||
|
|
||||||
include("tracker/Tracker.jl")
|
include("tracker/Tracker.jl")
|
||||||
using .Tracker
|
using .Tracker
|
||||||
|
@ -75,7 +75,7 @@ treelike(Dense)
|
|||||||
|
|
||||||
function (a::Dense)(x)
|
function (a::Dense)(x)
|
||||||
W, b, σ = a.W, a.b, a.σ
|
W, b, σ = a.W, a.b, a.σ
|
||||||
σ.(W*x .+ b)
|
@fix σ.(W*x .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::Dense)
|
function Base.show(io::IO, l::Dense)
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
using NNlib: log_fast, logsoftmax, logσ
|
using NNlib: logsoftmax, logσ
|
||||||
|
|
||||||
# Cost functions
|
# Cost functions
|
||||||
|
|
||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||||
|
|
||||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
|
return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
@ -49,8 +49,10 @@ include("numeric.jl")
|
|||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
|
import NNlib.cudata
|
||||||
import Adapt.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
|
cudata(x::TrackedArray) = data(x)
|
||||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user