diff --git a/src/onehot.jl b/src/onehot.jl index 21524135..307611cc 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -28,7 +28,7 @@ Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j] Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs -# handle special case for when we want the entire column without allocating +# handle special case for when we want the entire column function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray} res = similar(xs, size(xs, 1), 1) if length(ot) == size(xs, 1)