diff --git a/src/jit/lib.jl b/src/jit/lib.jl index 42c8cac3..5301e579 100644 --- a/src/jit/lib.jl +++ b/src/jit/lib.jl @@ -14,9 +14,27 @@ shape(::typeof(broadcast), f, xs...) = inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...) +shape(::typeof(reshape), x::Shape{T}, i...) where T = + Shape{T}(Base._reshape_uncolon(x, i)) + +inplace!(::typeof(reshape), y, x, i...) = copy!(y, x) + # NNlib using NNlib +using ..Tracker: _conv, _maxpool shape(::typeof(softmax), x) = x inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x) + +shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T = + Shape{T}(NNlib.cdims(size(x), size(w), pad, stride)) + +inplace!(::typeof(_conv), y, x, w, stride, pad) = + NNlib.conv!(y, x, w, stride = stride, pad = pad) + +shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T = + Shape{T}(NNlib.pdims(size(x), k, pad, k)) + +inplace!(::typeof(_maxpool), y, x, k, pad) = + NNlib.maxpool!(y, x, k, pad = pad)