From f98d289579f931a76eff4ae40a7cfb96ca8fdb87 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Tue, 9 Oct 2018 21:28:29 -0400 Subject: [PATCH] kf/tpu_wip --- src/layers/conv.jl | 16 ++++++++-------- src/layers/normalise.jl | 11 ++++++++--- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index dbf8ccf9..e3c6ae67 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -17,14 +17,12 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. Takes the keyword arguments `pad`, `stride` and `dilation`. """ -struct Conv{N,F,A,V} +struct Conv{F,A,V,Stride,Pad,Dilation} σ::F weight::A bias::V - stride::NTuple{N,Int} - pad::NTuple{N,Int} - dilation::NTuple{N,Int} end +Conv(σ::F, weight::A, bias::V, stride, pad, dilation) where {F,A,V} = Conv{F,A,V,stride,pad,dilation}(σ, weight, bias) Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} = @@ -35,13 +33,15 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ, stride = stride, pad = pad, dilation = dilation) -@treelike Conv +children(c::Conv) = (c.σ, c.weight, c.bias) +mapchildren(f, c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation}) where {stride, pad, dilation} = + Conv(f(c.σ), f(c.weight), f(c.bias), stride, pad, dilation) -function (c::Conv)(x) +function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride, pad, dilation} # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) - σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b) + σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1) + σ.(conv(x, c.weight, stride = stride, pad = pad, dilation = dilation) .+ b) end function Base.show(io::IO, l::Conv) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 164f6fa7..40e186dc 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -33,11 +33,16 @@ end _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) -function (a::Dropout)(x) - a.active || return x +function rand_similar(x::AbstractArray) y = similar(x) rand!(y) - y .= _dropout_kernel.(y, a.p, 1 - a.p) + y +end + +function (a::Dropout)(x) + a.active || return x + y = rand_similar(x) + y = _dropout_kernel.(y, a.p, 1 - a.p) return x .* y end