kf/tpu_wip
This commit is contained in:
parent
02ecca4c61
commit
f98d289579
@ -17,14 +17,12 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
|
||||
|
||||
Takes the keyword arguments `pad`, `stride` and `dilation`.
|
||||
"""
|
||||
struct Conv{N,F,A,V}
|
||||
struct Conv{F,A,V,Stride,Pad,Dilation}
|
||||
σ::F
|
||||
weight::A
|
||||
bias::V
|
||||
stride::NTuple{N,Int}
|
||||
pad::NTuple{N,Int}
|
||||
dilation::NTuple{N,Int}
|
||||
end
|
||||
Conv(σ::F, weight::A, bias::V, stride, pad, dilation) where {F,A,V} = Conv{F,A,V,stride,pad,dilation}(σ, weight, bias)
|
||||
|
||||
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
||||
stride = 1, pad = 0, dilation = 1) where {T,N} =
|
||||
@ -35,13 +33,15 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike Conv
|
||||
children(c::Conv) = (c.σ, c.weight, c.bias)
|
||||
mapchildren(f, c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation}) where {stride, pad, dilation} =
|
||||
Conv(f(c.σ), f(c.weight), f(c.bias), stride, pad, dilation)
|
||||
|
||||
function (c::Conv)(x)
|
||||
function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride, pad, dilation}
|
||||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1)
|
||||
σ.(conv(x, c.weight, stride = stride, pad = pad, dilation = dilation) .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Conv)
|
||||
|
@ -33,11 +33,16 @@ end
|
||||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
function (a::Dropout)(x)
|
||||
a.active || return x
|
||||
function rand_similar(x::AbstractArray)
|
||||
y = similar(x)
|
||||
rand!(y)
|
||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||
y
|
||||
end
|
||||
|
||||
function (a::Dropout)(x)
|
||||
a.active || return x
|
||||
y = rand_similar(x)
|
||||
y = _dropout_kernel.(y, a.p, 1 - a.p)
|
||||
return x .* y
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user