From 2d33f19346b48dd76559926b62ba1dd7cd978ba7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 29 Nov 2017 16:45:50 +0000 Subject: [PATCH 1/3] onehot unk arg --- src/onehot.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index f8061063..f94fb93e 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -42,7 +42,14 @@ function onehot(l, labels) OneHotVector(i, length(labels)) end -onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls]) +function onehot(l, labels, unk) + i = findfirst(labels, l) + i > 0 || return onehot(unk, labels) + OneHotVector(i, length(labels)) +end + +onehotbatch(ls, labels, unk...) = + OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls]) argmax(y::AbstractVector, labels = 1:length(y)) = labels[findfirst(y, maximum(y))] From 19039f48819835bf01ea6f2f69792f53dfe7d4f8 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 30 Nov 2017 13:37:38 +0000 Subject: [PATCH 2/3] export sigmoid --- src/Flux.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Flux.jl b/src/Flux.jl index df4b1636..7671ddd2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -12,7 +12,7 @@ export Chain, Dense, RNN, LSTM, Dropout, LayerNorm, param, params, mapleaves using NNlib -export σ, relu, leakyrelu, elu, swish, softmax +export σ, sigmoid, relu, leakyrelu, elu, swish, softmax include("tracker/Tracker.jl") using .Tracker From cab235a57863558aa060a28776f8934d5a0a0ed4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 30 Nov 2017 13:51:31 +0000 Subject: [PATCH 3/3] gpu compat --- src/tracker/Tracker.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 3a64fcb7..74ed2d75 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -40,7 +40,7 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) isleaf(x::TrackedArray) = x.f == Call(nothing) -param(xs) = TrackedArray(AbstractFloat.(xs)) +param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs)) param(xs::Real) = param(fill(xs)) istracked(x::TrackedArray) = true