diff --git a/src/tracker/array.jl b/src/tracker/array.jl index d23062d4..89bd05e4 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -199,7 +199,7 @@ end # NNlib using NNlib -import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, logσ, ∇logσ, conv2d, pool +import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool softmax(xs::TrackedArray) = track(softmax, xs) @@ -209,10 +209,6 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) -logσ(xs::TrackedArray) = TrackedArray(Call(logσ, xs)) - -back(::typeof(logσ), Δ, xs) = @back(xs, ∇logσ(Δ, data(xs))) - # TODO: can store kwargs efficiently in namedtuples _conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)