commit 86683e5991
@@ -10,7 +10,6 @@ export Chain, Dense, RNN, LSTM, GRU, Conv,
   params, mapleaves, cpu, gpu
 
 @reexport using NNlib
-using NNlib: @fix
 
 include("tracker/Tracker.jl")
 using .Tracker
@@ -1,6 +1,6 @@
 module CUDA
 
-using CuArrays
+using ..CuArrays
 
 CuArrays.cudnn_available() && include("cudnn.jl")
 
@@ -1,7 +1,7 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
+using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
   cudnnDataType, TensorDesc, FilterDesc
 
 using LinearAlgebra
 
 mutable struct DropoutDesc
   ptr::Ptr{Nothing}
@@ -19,8 +19,9 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   desc = DropoutDesc(d[], states)
   @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
     desc,libcudnn_handle[],ρ,states,length(states),seed)
-  finalizer(desc, x ->
-    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
+  finalizer(desc) do x
+    @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return desc
 end
 
@@ -45,10 +46,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
 # LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
 
 function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
+  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
   wx = slice(0, (input, hidden*n))
   wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) + (1:hidden*n)]
+  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
   (wx, wh), bias
 end
 
@@ -88,8 +89,9 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
   rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
-  finalizer(rd, x ->
-    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
+  finalizer(rd) do x
+    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return rd
 end
 
@@ -243,8 +245,8 @@ end
 
 import ..Flux: Flux, relu
 import ..Tracker: TrackedArray
-using CUDAnative
-using CuArrays: @cuindex, cudims
+using .CuArrays.CUDAnative
+using .CuArrays: @cuindex, cudims
 
 function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)
@@ -326,7 +328,7 @@ end
     h_ = hBatch(x, data(h))
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
   end
 end
 
@@ -341,6 +343,6 @@ end
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
     nobacksies(:RNN,
       (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-       dWi.', dWh.', db))
+       transpose(dWi), transpose(dWh), db))
   end
 end
@@ -77,7 +77,7 @@ end
 
 function (a::Dense)(x)
   W, b, σ = a.W, a.b, a.σ
-  @fix σ.(W*x .+ b)
+  σ.(W*x .+ b)
 end
 
 function Base.show(io::IO, l::Dense)
@@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
 
 function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
+  -sum(y .* log.(ŷ) .* weight) / size(y, 2)
 end
 
 @deprecate logloss(x, y) crossentropy(x, y)
@@ -33,8 +33,9 @@ import Adapt.adapt
 adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
 
 @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
-  import CuArrays: CuArray, cudaconvert
-  Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
+  import .CuArrays: CuArray, cudaconvert
+  import Base.Broadcast: BroadcastStyle, ArrayStyle
+  BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
   cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
 end
 
@@ -370,14 +370,53 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
   track(Call(back, tracker.(args)), y)
 end
 
-using Base.Broadcast: BroadcastStyle
+using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted
 
 struct TrackedStyle <: BroadcastStyle end
 
 Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle()
 Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle()
 
-function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
-  bc = Broadcast.flatten(bc)
-  ∇broadcast(bc.f, bc.args...)
+# We have to re-build the original broadcast struct to get the appropriate array
+# style. We need this primarily to support CuArrays' broadcasting fixes.
+broadcast_rebuild(xs) = data(xs)
+
+broadcast_rebuild(bc::Broadcasted) =
+  broadcasted(bc.f, broadcast_rebuild.(bc.args)...)
+
+preprocess(x) = x
+
+function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle})
+  bc1 = Broadcast.flatten(bc)
+  bc2 = Broadcast.flatten(broadcast_rebuild(bc))
+  ∇broadcast(bc2.f, bc1.args...)
+end
+
+using Requires
+
+# https://github.com/FluxML/Flux.jl/issues/353
+@init @eval Base.Broadcast begin
+  function flatten(bc::Broadcasted{Style}) where {Style}
+    isflat(bc) && return bc
+    args = cat_nested(bc)
+    let makeargs = make_makeargs(bc), f = bc.f
+      newf = @inline function(args::Vararg{Any,N}) where N
+        f(makeargs(args...)...)
+      end
+      return Broadcasted{Style}(newf, args, bc.axes)
+    end
+  end
+  @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
+    bc = t[1]
+    let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
+      let makeargs = make_makeargs(makeargs, bc.args)
+        headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
+        return @inline function(args::Vararg{Any,N}) where N
+          args1 = makeargs(args...)
+          a, b = headargs(args1...), tailargs(args1...)
+          (f(a...), b...)
+        end
+      end
+    end
+  end
+end
 end
@@ -14,6 +14,7 @@ cx = gpu(x)
 x = Flux.onehotbatch([1, 2, 3], 1:3)
 cx = gpu(x)
 @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
+@test (cx .+ 1) isa CuArray
 
 m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
 cm = gpu(m)
@@ -25,10 +26,9 @@ x = [1,2,3]
 cx = gpu(x)
 @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
 
-# Fails in Pkg.test ffs
-# c = gpu(Conv((2,2),3=>4))
-# l = c(gpu(rand(10,10,3,2)))
-# Flux.back!(sum(l))
+c = gpu(Conv((2,2),3=>4))
+l = c(gpu(rand(10,10,3,2)))
+Flux.back!(sum(l))
 
 end
 
@@ -1,8 +1,26 @@
+# Pkg.test runs with --check_bounds=1, forcing all bounds checks.
+# This is incompatible with CUDAnative (see JuliaGPU/CUDAnative.jl#98)
+if Base.JLOptions().check_bounds == 1
+  file = @__FILE__
+  run(```
+    $(Base.julia_cmd())
+    --color=$(Base.have_color ? "yes" : "no")
+    --compiled-modules=$(Bool(Base.JLOptions().use_compiled_modules) ? "yes" : "no")
+    --startup-file=$(Base.JLOptions().startupfile != 2 ? "yes" : "no")
+    --code-coverage=$(["none", "user", "all"][1+Base.JLOptions().code_coverage])
+    $(file)
+  ```)
+  exit()
+end
+
 using Flux, Test, Random
 using Random
 
 Random.seed!(0)
 
+# So we can use the system CuArrays
+insert!(LOAD_PATH, 2, "@v#.#")
+
 @testset "Flux" begin
 
 include("utils.jl")
@@ -12,7 +30,7 @@ include("layers/stateless.jl")
 include("optimise.jl")
 include("data.jl")
 
-if Base.find_package("CuArrays") ≠ nothing
+if Base.find_package("CuArrays") != nothing
   include("cuda/cuda.jl")
 end
 
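Two Julia 0.7 API changes recur in the hunks above: finalizer's argument order was reversed (callback first, object second), which is what the finalizer(obj) do x ... end form relies on, and adding a scalar to a range now requires broadcasting, which is why the slice offsets use .+ instead of +. A minimal standalone sketch of both patterns follows; the Handle type and its malloc/free body are hypothetical stand-ins for the CUDNN descriptors, not part of this commit:

# Julia 0.6:  finalizer(obj, f)  -- object first
# Julia 0.7+: finalizer(f, obj)  -- callback first, so a do-block supplies f
mutable struct Handle
  ptr::Ptr{Nothing}
end

function Handle()
  h = Handle(Libc.malloc(16))   # stand-in for a cudnnCreate* call
  finalizer(h) do x
    Libc.free(x.ptr)            # stand-in for the cudnnDestroy* ccall
  end
  return h
end

# Offsetting a range needs broadcasting on 0.7+:
offset = 3
@assert offset .+ (1:4) == 4:7  # plain `offset + (1:4)` is a MethodError on 1.0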