convnet primitives

This commit is contained in:
Mike Innes 2018-03-06 19:58:05 +00:00
parent 0802b4d5cf
commit 36baa7ec2c

View File

@ -14,9 +14,27 @@ shape(::typeof(broadcast), f, xs...) =
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
shape(::typeof(reshape), x::Shape{T}, i...) where T =
Shape{T}(Base._reshape_uncolon(x, i))
inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)
# NNlib
using NNlib
using ..Tracker: _conv, _maxpool
shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))
inplace!(::typeof(_conv), y, x, w, stride, pad) =
NNlib.conv!(y, x, w, stride = stride, pad = pad)
shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
Shape{T}(NNlib.pdims(size(x), k, pad, k))
inplace!(::typeof(_maxpool), y, x, k, pad) =
NNlib.maxpool!(y, x, k, pad = pad)