convnet primitives
This commit is contained in:
parent
0802b4d5cf
commit
36baa7ec2c
@ -14,9 +14,27 @@ shape(::typeof(broadcast), f, xs...) =
|
||||
|
||||
inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)
|
||||
|
||||
shape(::typeof(reshape), x::Shape{T}, i...) where T =
|
||||
Shape{T}(Base._reshape_uncolon(x, i))
|
||||
|
||||
inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
using ..Tracker: _conv, _maxpool
|
||||
|
||||
shape(::typeof(softmax), x) = x
|
||||
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)
|
||||
|
||||
shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
|
||||
Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))
|
||||
|
||||
inplace!(::typeof(_conv), y, x, w, stride, pad) =
|
||||
NNlib.conv!(y, x, w, stride = stride, pad = pad)
|
||||
|
||||
shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
|
||||
Shape{T}(NNlib.pdims(size(x), k, pad, k))
|
||||
|
||||
inplace!(::typeof(_maxpool), y, x, k, pad) =
|
||||
NNlib.maxpool!(y, x, k, pad = pad)
|
||||
|
Loading…
Reference in New Issue
Block a user