parent
0ef6456903
commit
5a023a9ccc
@ -10,7 +10,6 @@ export Chain, Dense, RNN, LSTM, GRU, Conv,
|
|||||||
params, mapleaves, cpu, gpu
|
params, mapleaves, cpu, gpu
|
||||||
|
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
using NNlib: @fix
|
|
||||||
|
|
||||||
include("tracker/Tracker.jl")
|
include("tracker/Tracker.jl")
|
||||||
using .Tracker
|
using .Tracker
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
module CUDA
|
module CUDA
|
||||||
|
|
||||||
using CuArrays
|
using ..CuArrays
|
||||||
|
|
||||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
CuArrays.cudnn_available() && include("cudnn.jl")
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
||||||
cudnnDataType, TensorDesc, FilterDesc
|
cudnnDataType, TensorDesc, FilterDesc
|
||||||
|
|
||||||
using LinearAlgebra
|
using LinearAlgebra
|
||||||
|
|
||||||
mutable struct DropoutDesc
|
mutable struct DropoutDesc
|
||||||
ptr::Ptr{Nothing}
|
ptr::Ptr{Nothing}
|
||||||
@ -243,8 +243,8 @@ end
|
|||||||
|
|
||||||
import ..Flux: Flux, relu
|
import ..Flux: Flux, relu
|
||||||
import ..Tracker: TrackedArray
|
import ..Tracker: TrackedArray
|
||||||
using CUDAnative
|
using .CuArrays.CUDAnative
|
||||||
using CuArrays: @cuindex, cudims
|
using .CuArrays: @cuindex, cudims
|
||||||
|
|
||||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||||
function kernel(dst, src)
|
function kernel(dst, src)
|
||||||
@ -326,7 +326,7 @@ end
|
|||||||
h_ = hBatch(x, data(h))
|
h_ = hBatch(x, data(h))
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
|
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -341,6 +341,6 @@ end
|
|||||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||||
nobacksies(:RNN,
|
nobacksies(:RNN,
|
||||||
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
|
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
|
||||||
dWi.', dWh.', db))
|
transpose(dWi), transpose(dWh), db))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -77,7 +77,7 @@ end
|
|||||||
|
|
||||||
function (a::Dense)(x)
|
function (a::Dense)(x)
|
||||||
W, b, σ = a.W, a.b, a.σ
|
W, b, σ = a.W, a.b, a.σ
|
||||||
@fix σ.(W*x .+ b)
|
σ.(W*x .+ b)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, l::Dense)
|
function Base.show(io::IO, l::Dense)
|
||||||
|
@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
|
|||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||||
|
|
||||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
@ -33,8 +33,9 @@ import Adapt.adapt
|
|||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
import CuArrays: CuArray, cudaconvert
|
import .CuArrays: CuArray, cudaconvert
|
||||||
Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
|
import Base.Broadcast: BroadcastStyle, ArrayStyle
|
||||||
|
BroadcastStyle(::Type{<:OneHotMatrix{<:CuArray}}) = ArrayStyle{CuArray}()
|
||||||
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -381,3 +381,32 @@ function Base.copy(bc::Broadcast.Broadcasted{TrackedStyle})
|
|||||||
bc = Broadcast.flatten(bc)
|
bc = Broadcast.flatten(bc)
|
||||||
∇broadcast(bc.f, bc.args...)
|
∇broadcast(bc.f, bc.args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
using Requires
|
||||||
|
|
||||||
|
# https://github.com/FluxML/Flux.jl/issues/353
|
||||||
|
@init @eval Base.Broadcast begin
|
||||||
|
function flatten(bc::Broadcasted{Style}) where {Style}
|
||||||
|
isflat(bc) && return bc
|
||||||
|
args = cat_nested(bc)
|
||||||
|
let makeargs = make_makeargs(bc), f = bc.f
|
||||||
|
newf = @inline function(args::Vararg{Any,N}) where N
|
||||||
|
f(makeargs(args...)...)
|
||||||
|
end
|
||||||
|
return Broadcasted{Style}(newf, args, bc.axes)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}})
|
||||||
|
bc = t[1]
|
||||||
|
let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f
|
||||||
|
let makeargs = make_makeargs(makeargs, bc.args)
|
||||||
|
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
|
||||||
|
return @inline function(args::Vararg{Any,N}) where N
|
||||||
|
args1 = makeargs(args...)
|
||||||
|
a, b = headargs(args1...), tailargs(args1...)
|
||||||
|
(f(a...), b...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@ -14,6 +14,7 @@ cx = gpu(x)
|
|||||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||||
|
@test (cx .+ 1) isa CuArray
|
||||||
|
|
||||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
@ -3,6 +3,9 @@ using Random
|
|||||||
|
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
|
||||||
|
# So we can use the system CuArrays
|
||||||
|
insert!(LOAD_PATH, 2, "@v#.#")
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Flux" begin
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
@ -12,7 +15,7 @@ include("layers/stateless.jl")
|
|||||||
include("optimise.jl")
|
include("optimise.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
|
||||||
if Base.find_package("CuArrays") ≠ nothing
|
if Base.find_package("CuArrays") != nothing
|
||||||
include("cuda/cuda.jl")
|
include("cuda/cuda.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user