kf/tpu_wip

This commit is contained in:
Keno Fischer 2018-10-09 21:28:29 -04:00
parent 02ecca4c61
commit f98d289579
2 changed files with 16 additions and 11 deletions

View File

@ -17,14 +17,12 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct Conv{N,F,A,V}
struct Conv{F,A,V,Stride,Pad,Dilation}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{N,Int}
dilation::NTuple{N,Int}
end
Conv(σ::F, weight::A, bias::V, stride, pad, dilation) where {F,A,V} = Conv{F,A,V,stride,pad,dilation}(σ, weight, bias)
Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N} =
@ -35,13 +33,15 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
stride = stride, pad = pad, dilation = dilation)
@treelike Conv
children(c::Conv) = (c.σ, c.weight, c.bias)
mapchildren(f, c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation}) where {stride, pad, dilation} =
Conv(f(c.σ), f(c.weight), f(c.bias), stride, pad, dilation)
function (c::Conv)(x)
function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride, pad, dilation}
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1)
σ.(conv(x, c.weight, stride = stride, pad = pad, dilation = dilation) .+ b)
end
function Base.show(io::IO, l::Conv)

View File

@ -33,11 +33,16 @@ end
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
a.active || return x
function rand_similar(x::AbstractArray)
y = similar(x)
rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p)
y
end
function (a::Dropout)(x)
a.active || return x
y = rand_similar(x)
y = _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end