Compare commits
8 Commits
master...kf/tpu_wip
58e299eafb
be0133fb67
770f601897
a7143553df
c5d5a5c2a8
943deea92d
77bb2a66de
f98d289579
```diff
@@ -17,14 +17,12 @@ be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.
 Takes the keyword arguments `pad`, `stride` and `dilation`.
 """
-struct Conv{N,F,A,V}
+struct Conv{F,A,V,Stride,Pad,Dilation}
   σ::F
   weight::A
   bias::V
-  stride::NTuple{N,Int}
-  pad::NTuple{N,Int}
-  dilation::NTuple{N,Int}
 end
+Conv(σ::F, weight::A, bias::V, stride, pad, dilation) where {F,A,V} = Conv{F,A,V,stride,pad,dilation}(σ, weight, bias)

 Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
      stride = 1, pad = 0, dilation = 1) where {T,N} =
@@ -35,13 +33,15 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init =
   Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
        stride = stride, pad = pad, dilation = dilation)

-@treelike Conv
+children(c::Conv) = (c.σ, c.weight, c.bias)
+mapchildren(f, c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation}) where {stride, pad, dilation} =
+  Conv(f(c.σ), f(c.weight), f(c.bias), stride, pad, dilation)

-function (c::Conv)(x)
+function (c::Conv{<:Any, <:Any, <:Any, stride, pad, dilation})(x) where {stride, pad, dilation}
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-  σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
-  σ.(conv(x, c.weight, stride = c.stride, pad = c.pad, dilation = c.dilation) .+ b)
+  σ, b = c.σ, reshape(c.bias, map(_->1, stride)..., :, 1)
+  σ.(NNlib._conv{pad, stride, dilation}()(x, c.weight) .+ b)
 end

 function Base.show(io::IO, l::Conv)
```
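This hunk lifts `stride`, `pad` and `dilation` out of runtime fields and into `Conv`'s type parameters, making them compile-time constants wherever a particular layer is used; given the branch name, this is presumably so the TPU/XLA tracer sees static configuration. Note that `NNlib._conv{pad, stride, dilation}` is a work-in-progress internal on this branch, not a released NNlib API. A minimal, self-contained sketch of the same value-type pattern (assumed rationale, not Flux's actual code):

```julia
# `Stride` and `Pad` live in the type, so methods can read them statically
# via dispatch instead of loading fields at runtime.
struct Window{Stride,Pad} end
Window(stride::Int, pad::Int) = Window{stride,pad}()

# S and P are compile-time constants inside this method body.
stridepad(::Window{S,P}) where {S,P} = (S, P)

w = Window(2, 1)
@assert stridepad(w) === (2, 1)
```

The trade-off is one compiled specialization per distinct `(stride, pad, dilation)` combination.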
```diff
@@ -33,11 +33,16 @@ end
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

-function (a::Dropout)(x)
-  a.active || return x
+function rand_similar(x::AbstractArray)
   y = similar(x)
   rand!(y)
-  y .= _dropout_kernel.(y, a.p, 1 - a.p)
+  y
+end
+
+function (a::Dropout)(x)
+  a.active || return x
+  y = rand_similar(x)
+  y = _dropout_kernel.(y, a.p, 1 - a.p)
   return x .* y
 end
```
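Factoring the `rand!` call into `rand_similar` confines in-place mutation to one small helper, leaving the dropout arithmetic as pure broadcasts, which is presumably easier to trace on backends that cannot express mutation. A self-contained sketch of the resulting path (a hypothetical standalone version, with an `active` flag standing in for the layer's field):

```julia
using Random

# Inverted dropout: keep with probability 1-p and rescale by 1/(1-p).
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

# The only mutation, isolated from the pure broadcasts below.
function rand_similar(x::AbstractArray)
  y = similar(x)
  rand!(y)
  y
end

dropout(x, p; active = true) =
  active ? x .* _dropout_kernel.(rand_similar(x), p, 1 - p) : x

x = ones(Float32, 4, 4)
@assert size(dropout(x, 0.5f0)) == size(x)
@assert dropout(x, 0.5f0; active = false) === x
```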
```diff
@@ -67,7 +72,7 @@ function Base.show(io::IO, l::LayerNorm)
 end

 """
-    BatchNorm(channels::Integer, σ = identity;
+    BatchNorm(channels::Integer, σ² = identity;
               initβ = zeros, initγ = ones,
               ϵ = 1e-8, momentum = .1)
```
````diff
@@ -97,11 +102,11 @@ m = Chain(
 ```
 """
 mutable struct BatchNorm{F,V,W,N}
-  λ::F  # activation function
-  β::V  # bias
-  γ::V  # scale
-  μ::W  # moving mean
-  σ::W  # moving std
+  λ::F   # activation function
+  β::V   # bias
+  γ::V   # scale
+  μ::W   # moving mean
+  σ²::W  # moving variance
   ϵ::N
   momentum::N
   active::Bool
````
```diff
@@ -118,37 +123,40 @@ function (BN::BatchNorm)(x)
   γ, β = BN.γ, BN.β
   dims = length(size(x))
   channels = size(x, dims-1)
-  affine_shape = ones(Int, dims)
-  affine_shape[end-1] = channels
-  m = prod(size(x)[1:end-2]) * size(x)[end]
+  affine_shape = let dims=dims, channels=channels
+    ntuple(i->i == dims - 1 ? channels : 1, dims)
+  end
+  m = let sz = size(x)
+    prod(ntuple(i->sz[i], dims-2)) * sz[end]
+  end

   if !BN.active
     μ = reshape(BN.μ, affine_shape...)
-    σ = reshape(BN.σ, affine_shape...)
+    σ² = reshape(BN.σ², affine_shape...)
   else
-    T = eltype(x)
+    T = eltype(data(x))

-    ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
-    σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
+    meansub = (x .- μ)
+    σ² = mean(meansub .* meansub, dims = axes)

     # update moving mean/std
-    mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
+    mtm = convert(T, data(BN.momentum))
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(reshape(μ, :))
+    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* data(reshape(σ², :)) .* m ./ (m - 1))
   end

-  let λ = BN.λ
-    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
+  let λ = BN.λ, ϵ = eltype(data(σ²))(BN.ϵ)
+    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ reshape(β, affine_shape...))
   end
 end

 children(BN::BatchNorm) =
-  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)

 mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
-  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)

 _testmode!(BN::BatchNorm, test) = (BN.active = !test)
```
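Two changes here: the shape bookkeeping switches from mutated `Vector{Int}`s to `ntuple`s (no in-place mutation, friendlier to tracing), and the layer tracks the running variance `σ²` instead of the standard deviation, so the `sqrt` and `ϵ` move to the point of normalisation. A minimal inference-path sketch of the resulting arithmetic (an assumed simplification, not Flux's full implementation):

```julia
# Inference-time batch norm over the channel dimension (dims-1), tracking
# variance rather than std; ϵ goes inside the sqrt, as in the diff above.
function batchnorm_infer(x::AbstractArray, γ, β, μ, σ², ϵ)
  dims = ndims(x)
  # Tuple shape instead of a mutated Vector{Int}: no allocation, no mutation.
  shape = ntuple(i -> i == dims - 1 ? length(μ) : 1, dims)
  x̂ = (x .- reshape(μ, shape)) ./ sqrt.(reshape(σ², shape) .+ ϵ)
  reshape(γ, shape) .* x̂ .+ reshape(β, shape)
end

x = randn(Float32, 5, 3, 2)  # width × channels × batch
y = batchnorm_infer(x, ones(Float32, 3), zeros(Float32, 3),
                    zeros(Float32, 3), ones(Float32, 3), 1f-8)
@assert size(y) == size(x)
```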
```diff
@@ -112,12 +112,12 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

 # LSTM

-mutable struct LSTMCell{A,V}
+struct LSTMCell{A,B,V,W,Z}
   Wi::A
-  Wh::A
+  Wh::B
   b::V
-  h::V
-  c::V
+  h::W
+  c::Z
 end

 function LSTMCell(in::Integer, out::Integer;
```
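`LSTMCell` also becomes immutable, and every field gets its own type parameter, so `Wi`, `Wh`, `h` and `c` no longer need to share one array type. A toy illustration of why the extra parameters help (a hypothetical `Cell`, not Flux's type):

```julia
using LinearAlgebra

# With one parameter per field, heterogeneous array types coexist without
# conversion -- e.g. a lazily transposed weight next to a plain Matrix.
struct Cell{A,B}
  Wi::A
  Wh::B
end

c = Cell(randn(Float32, 3, 4), transpose(randn(Float32, 3, 3)))
@assert c.Wi isa Matrix{Float32}
@assert c.Wh isa Transpose  # would be a MethodError if both fields were ::A
```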
```diff
@@ -1,58 +1,74 @@
 import Base: *

-struct OneHotVector <: AbstractVector{Bool}
-  ix::UInt32
-  of::UInt32
+struct OneHotVector{T <: Integer} <: AbstractVector{Bool}
+  ix::T
+  of::T
 end

-Base.size(xs::OneHotVector) = (Int64(xs.of),)
+Base.size(xs::OneHotVector) = (Int(xs.of),)

 Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix

 A::AbstractMatrix * b::OneHotVector = A[:, b.ix]

-struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
-  height::Int
+"""
+A matrix of one-hot column vectors
+"""
+struct OneHotMatrix{height, A<:AbstractVector{<:Integer}} <: AbstractMatrix{Bool}
   data::A
 end
+Flux.OneHotMatrix{height}(data::AbstractVector{<:Integer}) where {height} =
+  OneHotMatrix{height, typeof(data)}(data)
+Flux.OneHotMatrix(height, data) = OneHotMatrix{height}(data)

-Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
+function OneHotMatrix(xs::Vector{<:OneHotVector})
+  height = length(xs[1])
+  OneHotMatrix(height, map(xs) do x
+    length(x) == height || error("All one hot vectors must be the same length")
+    x.ix
+  end)
+end

-Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
-Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])
+Base.size(xs::OneHotMatrix{height}) where {height} = (height, length(xs.data))

-A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = OneHotVector(xs.data[i], size(xs)[1])
+Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs[:, j][i]
+Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(size(xs)[1], xs.data[i])

-Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
+A::AbstractMatrix * B::OneHotMatrix = A[:, B.data]

-batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
+Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
+
+batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(xs)

 import Adapt.adapt

-adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
+adapt(T, xs::OneHotMatrix) = OneHotMatrix(size(xs)[1], adapt(T, xs.data))

 @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
   import .CuArrays: CuArray, cudaconvert
   import Base.Broadcast: BroadcastStyle, ArrayStyle
   BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
-  cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
+  cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(size(x)[1], cudaconvert(x.data))
 end
```
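With the hot indices stored as a plain integer vector and the height moved into the type, construction and matmul simplify; a hedged usage sketch, assuming this hunk's definitions are loaded:

```julia
m = OneHotMatrix(5, [2, 4, 1])  # 5×3 Bool matrix, one hot index per column
@assert size(m) == (5, 3)
@assert m[2, 1] && !m[1, 1]

W = randn(3, 5)
@assert W * m == W[:, [2, 4, 1]]  # matmul is just column selection
```

The hunk continues with the `onehot`/`onehotidx` split: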
```diff
-function onehot(l, labels)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || error("Value $l is not in labels")
-  OneHotVector(i, length(labels))
+function onehotidx(l, labels)
+  i = findfirst(isequal(l), labels)
+  i !== nothing || error("Value $(repr(l; context=:limited=>true)) is not in labels")
+  i
 end

-function onehot(l, labels, unk)
-  i = something(findfirst(isequal(l), labels), 0)
-  i > 0 || return onehot(unk, labels)
-  OneHotVector(i, length(labels))
+function onehotidx(l, labels, unk)
+  i = findfirst(isequal(l), labels)
+  i !== nothing || return onehotidx(unk, labels)
+  i
 end

+onehot(l, labels, unk...) = OneHotVector(onehotidx(l, labels, unk...), length(labels))
+
 onehotbatch(ls, labels, unk...) =
-  OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])
+  OneHotMatrix(length(labels), [onehotidx(l, labels, unk...) for l in ls])

 onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
```
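The new index-first API separates finding a label's index from building the vector, so `onehotbatch` can collect raw indices directly. Assuming the definitions above:

```julia
labels = ['a', 'b', 'c']

@assert onehotidx('b', labels) == 2
@assert onehot('c', labels) == [false, false, true]

# With an `unk` fallback, unseen values map to the fallback label:
@assert onehot('z', labels, 'a') == [true, false, false]
```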
```diff
@@ -421,30 +421,3 @@ function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
 end
-
-using Requires
-
-# https://github.com/FluxML/Flux.jl/issues/353
-@init Requires.isprecompiling() || @eval Base.Broadcast begin
-  function flatten(bc::Broadcasted{Style}) where {Style}
-    isflat(bc) && return bc
-    args = cat_nested(bc)
-    let makeargs = make_makeargs(bc), f = bc.f
-      newf = @inline function(args::Vararg{Any,N}) where N
-        f(makeargs(args...)...)
-      end
-      return Broadcasted{Style}(newf, args, bc.axes)
-    end
-  end
-  @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
-    bc = t[1]
-    let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
-      let makeargs = make_makeargs(makeargs, bc.args)
-        headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
-        return @inline function(args::Vararg{Any,N}) where N
-          args1 = makeargs(args...)
-          a, b = headargs(args1...), tailargs(args1...)
-          (f(a...), b...)
-        end
-      end
-    end
-  end
-end
```
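This final hunk deletes the `@eval Base.Broadcast` patch that worked around broadcast fusion breaking tracked arrays (FluxML/Flux.jl#353); presumably the workaround is no longer needed on this branch.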