Compare commits

...

8 Commits

Author SHA1 Message Date
Keno Fischer
58e299eafb OneHotMatrix WIP 2019-02-13 16:48:13 -05:00
Keno Fischer
be0133fb67 Make LSTMCell non-mutable 2019-01-28 20:26:04 -05:00
Keno Fischer
770f601897 Some memory improvements to OneHotMatrix
1. Parameterize OneHotVector on Integer type, to avoid using more memory
   than required for vectors of them.
2. Switch OneHotMatrix from storing a vector of OneHotVectors to only storing
   the data and the size of the vector (reconstructing the vector locally), thus
   saving half the memory required and eliminating a transpose operation for
   matmul with OneHotMatrix on TPU.
2019-01-28 20:23:50 -05:00
Elliot Saba
a7143553df Change name to σ² for better consistency 2019-01-20 23:57:19 +00:00
Elliot Saba
c5d5a5c2a8 Cleanup BatchNorm implementation
This provides greater datatype persistence
2019-01-20 23:50:13 +00:00
Keno Fischer
943deea92d Improvements for ResNet 2018-10-28 15:30:25 -04:00
Keno Fischer
77bb2a66de Use lower level conv interface 2018-10-23 16:15:39 -04:00
Keno Fischer
f98d289579 kf/tpu_wip 2018-10-09 21:28:29 -04:00
5 changed files with 83 additions and 86 deletions
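
The 770f601897 commit message above makes two storage claims. A quick standalone illustration of the first one, comparing a one-hot vector hard-coded to UInt32 fields with one parameterized on its index type (toy types for illustration, not the Flux definitions):

# Toy types only: FixedOneHot mirrors the old fixed-UInt32 layout,
# ParamOneHot{T} mirrors the new integer-parameterized layout.
struct FixedOneHot
  ix::UInt32
  of::UInt32
end

struct ParamOneHot{T<:Integer}
  ix::T
  of::T
end

sizeof(FixedOneHot)          # 8 bytes per element
sizeof(ParamOneHot{UInt8})   # 2 bytes per element when 255 classes suffice

The second point, storing only the column indices plus a shared height, is sketched after the one-hot file diff below.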

View File

@@ -17,14 +17,12 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
-struct Conv{N,F,A,V}
+struct Conv{F,A,V,Stride,Pad,Dilation}
   σ::F
   weight::A
   bias::V
-  stride::NTuple{N,Int}
-  pad::NTuple{N,Int}
-  dilation::NTuple{N,Int}
 end
+Conv(σ::F, weight::A, bias::V, stride, pad, dilation) where {F,A,V} = Conv{F,A,V,stride,pad,dilation}(σ, weight, bias)
 Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
      stride = 1, pad = 0, dilation = 1) where {T,N} =
@@ -35,13 +33,15 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
        stride = stride, pad = pad, dilation = dilation)
-@treelike Conv
+children(c::Conv) = (c.σ, c.weight, c.bias)
+mapchildren(f, c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation}) where {stride, pad, dilation} =
+  Conv(f(c.σ), f(c.weight), f(c.bias), stride, pad, dilation)
-function (c::Conv)(x)
+function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride, pad, dilation}
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+  σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1)
+  σ.(NNlib._conv{pad, stride, dilation}()(x, c.weight) .+ b)
 end
 function Base.show(io::IO, l::Conv)
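
The Conv change above moves `stride`, `pad`, and `dilation` out of run-time fields and into type parameters, so they appear as compile-time constants in the call method, which a tracing backend can treat as constants. A minimal sketch of that general pattern with a made-up `Scaler` type, not Flux's API:

# Hypothetical example: lift a configuration value into a type parameter.
struct Scaler{factor}
  data::Vector{Float64}
end

# Convenience constructor that moves the keyword into the type.
Scaler(data; factor = 2) = Scaler{factor}(data)

# `factor` is recovered from the type in the signature, so each value of
# `factor` gets its own specialized method body.
(s::Scaler{factor})(x) where {factor} = factor .* s.data .* x

s = Scaler([1.0, 2.0, 3.0]; factor = 3)
s(2.0)   # -> [6.0, 12.0, 18.0]

The trade-off is that every distinct stride/pad/dilation combination now compiles its own method, which is usually acceptable for layer hyperparameters.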

View File

@@ -33,11 +33,16 @@ end
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
-function (a::Dropout)(x)
-  a.active || return x
+function rand_similar(x::AbstractArray)
   y = similar(x)
   rand!(y)
-  y .= _dropout_kernel.(y, a.p, 1 - a.p)
+  y
+end
+function (a::Dropout)(x)
+  a.active || return x
+  y = rand_similar(x)
+  y = _dropout_kernel.(y, a.p, 1 - a.p)
   return x .* y
 end
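
The `_dropout_kernel` above implements inverted dropout: surviving entries are scaled by `1/(1 - p)` so the expected activation is unchanged. A quick standalone check of that property (toy function, not Flux's `Dropout` layer):

using Statistics

# Keep each entry with probability 1 - p and rescale the survivors.
function toy_dropout(x; p = 0.5)
  mask = rand(size(x)...) .> p
  return x .* mask ./ (1 - p)
end

x = ones(100_000)
mean(toy_dropout(x; p = 0.3))   # ≈ 1.0 in expectation
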
@@ -67,7 +72,7 @@ function Base.show(io::IO, l::LayerNorm)
 end
 """
-    BatchNorm(channels::Integer, σ = identity;
+    BatchNorm(channels::Integer, σ² = identity;
               initβ = zeros, initγ = ones,
               ϵ = 1e-8, momentum = .1)
@@ -97,11 +102,11 @@ m = Chain(
 ```
 """
 mutable struct BatchNorm{F,V,W,N}
-  λ::F  # activation function
-  β::V  # bias
-  γ::V  # scale
-  μ::W  # moving mean
-  σ::W  # moving std
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ::W  # moving mean
+  σ²::W # moving std
   ϵ::N
   momentum::N
   active::Bool
@@ -118,37 +123,40 @@ function (BN::BatchNorm)(x)
   γ, β = BN.γ, BN.β
   dims = length(size(x))
   channels = size(x, dims-1)
-  affine_shape = ones(Int, dims)
-  affine_shape[end-1] = channels
-  m = prod(size(x)[1:end-2]) * size(x)[end]
+  affine_shape = let dims=dims, channels=channels
+    ntuple(i->i == dims - 1 ? channels : 1, dims)
+  end
+  m = let sz = size(x)
+    prod(ntuple(i->sz[i], dims-2)) * sz[end]
+  end
   if !BN.active
     μ = reshape(BN.μ, affine_shape...)
-    σ = reshape(BN.σ, affine_shape...)
+    σ² = reshape(BN.σ², affine_shape...)
   else
-    T = eltype(x)
+    T = eltype(data(x))
     ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
-    σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
+    meansub = (x .- μ)
+    σ² = mean(meansub .* meansub, dims = axes)
     # update moving mean/std
-    mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
+    mtm = convert(T, data(BN.momentum))
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(reshape(μ, :))
+    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* data(reshape(σ², :)) .* m ./ (m - 1))
   end
-  let λ = BN.λ
-    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
+  let λ = BN.λ, ϵ = eltype(data(σ²))(BN.ϵ)
+    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ reshape(β, affine_shape...))
   end
 end
 children(BN::BatchNorm) =
-  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
 mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
-  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
 _testmode!(BN::BatchNorm, test) = (BN.active = !test)
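
The new code above stores the running variance σ² and applies `sqrt` (with ϵ added) only at normalisation time; the running estimate gets the Bessel correction `m / (m - 1)`. A framework-free sketch of that update for a single channel, using a hypothetical helper rather than Flux's BatchNorm:

using Statistics

# One training-mode step for a vector of activations `x`; returns the
# normalised batch plus the updated running mean and running variance.
function batchnorm_step(x, μrun, σ²run; momentum = 0.1, ϵ = 1e-8)
  m  = length(x)
  μ  = mean(x)
  σ² = mean((x .- μ) .^ 2)                  # biased batch variance
  x̂  = (x .- μ) ./ sqrt(σ² + ϵ)             # sqrt taken only here
  μrun  = (1 - momentum) * μrun  + momentum * μ
  σ²run = (1 - momentum) * σ²run + momentum * σ² * m / (m - 1)  # unbiased running estimate
  return x̂, μrun, σ²run
end

x̂, μrun, σ²run = batchnorm_step(randn(32), 0.0, 1.0)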

View File

@@ -112,12 +112,12 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 # LSTM
-mutable struct LSTMCell{A,V}
+struct LSTMCell{A,B,V,W,Z}
   Wi::A
-  Wh::A
+  Wh::B
   b::V
-  h::V
-  c::V
+  h::W
+  c::Z
 end
 function LSTMCell(in::Integer, out::Integer;
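
The LSTMCell change above drops `mutable` and gives every field its own type parameter, so the weights and the state can each carry their own concrete array type. As a loose illustration of working with an immutable cell, here is a toy recurrent cell whose state is returned and threaded through calls rather than mutated in place (hypothetical type, not Flux's cell API):

# Immutable toy cell: calling it produces a new state instead of mutating.
struct SimpleCell{A,V}
  W::A
  b::V
end

function (c::SimpleCell)(h, x)
  h′ = tanh.(c.W * vcat(h, x) .+ c.b)
  return h′, h′          # (new state, output)
end

c = SimpleCell(randn(4, 4 + 3), zeros(4))
h = zeros(4)
h, y = c(h, randn(3))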

View File

@@ -1,58 +1,74 @@
 import Base: *
-struct OneHotVector <: AbstractVector{Bool}
-  ix::UInt32
-  of::UInt32
+struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
+  ix::T
+  of::T
 end
-Base.size(xs::OneHotVector) = (Int64(xs.of),)
+Base.size(xs::OneHotVector) = (Int(xs.of),)
 Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
 A::AbstractMatrix * b::OneHotVector = A[:, b.ix]
-struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
-  height::Int
+"""
+A matrix of one-hot column vectors
+"""
+struct OneHotMatrix{height, A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
   data::A
 end
+Flux.OneHotMatrix{height}(data::AbstractVector{<:Integer}) where {height} =
+  OneHotMatrix{height, typeof(data)}(data)
+Flux.OneHotMatrix(height, data) = OneHotMatrix{height}(data)
-Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
+function OneHotMatrix(xs::Vector{<:OneHotVector})
+  height = length(xs[1])
+  OneHotMatrix(height, map(xs) do x
+    length(x) == height || error("All one hot vectors must be the same length")
+    x.ix
+  end)
+end
-Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
-A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
-Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
-batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
+Base.size(xs::OneHotMatrix{height}) where {height} = (height, length(xs.data))
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], size(xs)[1])
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(size(xs)[1], xs.data[i])
+A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]
+Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
+batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)
 import Adapt.adapt
-adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
+adapt(T, xs::OneHotMatrix) = OneHotMatrix(size(xs)[1], adapt(T, xs.data))
 @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
-  cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
+  cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(size(x)[1], cudaconvert(x.data))
 end
-function onehot(l, labels)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || error("Value $l is not in labels")
-  OneHotVector(i, length(labels))
+function onehotidx(l, labels)
+  i = findfirst(isequal(l), labels)
+  i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
+  i
 end
-function onehot(l, labels, unk)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || return onehot(unk, labels)
-  OneHotVector(i, length(labels))
+function onehotidx(l, labels, unk)
+  i = findfirst(isequal(l), labels)
+  i !== nothing || return onehotidx(unk, labels)
+  i
 end
+onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
 onehotbatch(ls, labels, unk...) =
-  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
+  OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])
 onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
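
With the new layout above, a OneHotMatrix stores only the hot-row index of each column (the height lives in the type), so `A * B` becomes the column gather `A[:, B.data]` rather than a dense matmul. A standalone toy version of that idea (not the Flux types themselves):

# Toy one-hot matrix: a height plus one integer index per column.
struct ToyOneHotMatrix{T<:Integer} <: AbstractMatrix{Bool}
  height::Int
  ixs::Vector{T}
end
Base.size(m::ToyOneHotMatrix) = (m.height, length(m.ixs))
Base.getindex(m::ToyOneHotMatrix, i::Integer, j::Integer) = i == m.ixs[j]

# Multiplying by a one-hot matrix just selects columns of A.
Base.:*(A::AbstractMatrix, m::ToyOneHotMatrix) = A[:, m.ixs]

A = reshape(1:12, 3, 4)                  # 3×4 matrix
m = ToyOneHotMatrix(4, UInt8[2, 4, 1])
A * m                                    # == A[:, [2, 4, 1]]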

View File

@@ -421,30 +421,3 @@ function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
 end
 using Requires
-# https://github.com/FluxML/Flux.jl/issues/353
-@init Requires.isprecompiling() || @eval Base.Broadcast begin
-  function flatten(bc::Broadcasted{Style}) where {Style}
-    isflat(bc) && return bc
-    args = cat_nested(bc)
-    let makeargs = make_makeargs(bc), f = bc.f
-      newf = @inline function(args::Vararg{Any,N}) where N
-        f(makeargs(args...)...)
-      end
-      return Broadcasted{Style}(newf, args, bc.axes)
-    end
-  end
-  @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
-    bc = t[1]
-    let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
-      let makeargs = make_makeargs(makeargs, bc.args)
-        headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
-        return @inline function(args::Vararg{Any,N}) where N
-          args1 = makeargs(args...)
-          a, b = headargs(args1...), tailargs(args1...)
-          (f(a...), b...)
-        end
-      end
-    end
-  end
-end