Fix conflict

commit dfd680646c
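In summary, the diff below updates the dependency manifest and REQUIRE (version bumps; Adapt is now required at 0.4), adapts the CUDA/cuDNN glue to Julia 1.0's `undef` array constructors and newer CuArrays handle APIs, moves `adapt` overloads to Adapt 0.4's `adapt_structure`, replaces Tracker's `derivative` with a single `gradient(f, xs...; nest = false)` entry point, adds BLAS-friendly `Dense`/`Conv` dispatch, and switches `utils.jl` and the tests to Float32 defaults.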
@@ -6,9 +6,9 @@ version = "0.2.0"
 [[Adapt]]
 deps = ["LinearAlgebra", "Test"]
-git-tree-sha1 = "a1245c11af6876245c32f82f2067bf67f7da8cee"
+git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "0.4.0"
+version = "0.4.1"

 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@@ -26,10 +26,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
 version = "0.5.2"

 [[CodecZlib]]
-deps = ["BinaryProvider", "Libdl", "Pkg", "Test", "TranscodingStreams"]
-git-tree-sha1 = "83cb3d65c37ea1364c2d5bf7bcea41843ba645dc"
+deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
+git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
 uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
-version = "0.5.0"
+version = "0.5.1"

 [[ColorTypes]]
 deps = ["FixedPointNumbers", "Random", "Test"]
@@ -38,7 +38,7 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
 version = "0.7.5"

 [[Colors]]
-deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Pkg", "Printf", "Reexport", "Test"]
+deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
 git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
 uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
 version = "0.9.5"
@@ -56,7 +56,7 @@ uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
 version = "1.3.0"

 [[DataStructures]]
-deps = ["InteractiveUtils", "OrderedCollections", "REPL", "Random", "Serialization", "Test"]
+deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
 git-tree-sha1 = "8fc6e166e24fda04b2b648d4260cdad241788c54"
 uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 version = "0.14.0"
@@ -86,16 +86,16 @@ deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

 [[FixedPointNumbers]]
-deps = ["Pkg", "Test"]
+deps = ["Test"]
 git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
 uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 version = "0.5.3"

 [[ForwardDiff]]
-deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Pkg", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
-git-tree-sha1 = "d8f3e0f19d0d546aa92eb1cd67cd3e515768d9f7"
+deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
+git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
 uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.0"
+version = "0.10.1"

 [[InteractiveUtils]]
 deps = ["LinearAlgebra", "Markdown"]
@@ -132,9 +132,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"

 [[Media]]
 deps = ["MacroTools", "Test"]
-git-tree-sha1 = "9f390271c9a43dcbe908a10b5b9632cf58cbab5b"
+git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
 uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
-version = "0.4.1"
+version = "0.5.0"

 [[Missings]]
 deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
@@ -158,7 +158,7 @@ uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
 version = "0.3.2"

 [[OrderedCollections]]
-deps = ["Pkg", "Random", "Serialization", "Test"]
+deps = ["Random", "Serialization", "Test"]
 git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
 uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 version = "1.0.2"
@@ -220,12 +220,12 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

 [[SpecialFunctions]]
 deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
-git-tree-sha1 = "c35c9c76008babf4d658060fc64aeb369a41e7bd"
+git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
 uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
-version = "0.7.1"
+version = "0.7.2"

 [[StaticArrays]]
-deps = ["InteractiveUtils", "LinearAlgebra", "Pkg", "Random", "Statistics", "Test"]
+deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
 git-tree-sha1 = "ebc5c2a27d91d5ec611a9861168182e2168effd3"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
 version = "0.9.2"
@@ -245,7 +245,7 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
 uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [[TranscodingStreams]]
-deps = ["DelimitedFiles", "Pkg", "Random", "Test"]
+deps = ["Pkg", "Random", "Test"]
 git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
 uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
 version = "0.8.1"
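The hunks above are from the project's package manifest (recognizable by the `[[Package]]` / `git-tree-sha1` entries; the file name was lost in extraction). They follow two mechanical patterns: stdlibs such as `Pkg` and `REPL` drop out of upstream `deps` lists, and pinned versions move forward (Adapt 0.4.0 to 0.4.1, CodecZlib 0.5.0 to 0.5.1, ForwardDiff 0.10.0 to 0.10.1, Media 0.4.1 to 0.5.0, SpecialFunctions 0.7.1 to 0.7.2). Hunks like these are normally generated by the package manager rather than written by hand, roughly:

    # Illustrative only, not part of the commit: resolving newer versions
    # rewrites the manifest's version/git-tree-sha1 entries.
    using Pkg
    Pkg.update()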
--- a/REQUIRE
+++ b/REQUIRE
@@ -3,7 +3,7 @@ Juno
 MacroTools 0.3.3
 NNlib
 Requires
-Adapt
+Adapt 0.4
 CodecZlib
 Colors
 ZipFile
@@ -2,9 +2,20 @@ module CUDA

 using ..CuArrays

+if !applicable(CuArray{UInt8}, undef, 1)
+  (T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
+end
+
 if CuArrays.libcudnn != nothing
-  include("curnn.jl")
-  include("cudnn.jl")
+  if isdefined(CuArrays, :libcudnn_handle)
+    handle() = CuArrays.libcudnn_handle[]
+  else
+    handle() = CuArrays.CUDNN.handle()
+  end
+  include("curnn.jl")
+  include("cudnn.jl")
 else
   @warn("CUDNN is not installed, some functionality will not be available.")
 end

 end
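Two things happen in this hunk (from the `module CUDA` glue, per the hunk context): a feature-detection shim defines the Julia 1.0 `undef`-style constructor in terms of the old positional one when the installed CuArrays does not yet provide it, and `handle()` picks whichever cuDNN handle accessor that CuArrays version exposes. A minimal CPU sketch of the same `applicable`-based shim, using a hypothetical `Buf` type so it runs without a GPU:

    # `Buf` stands in for CuArray; its "library" only ships the old
    # positional constructor.
    struct Buf{T}
        data::Vector{T}
        Buf{T}(n::Integer) where {T} = new(Vector{T}(undef, n))
    end

    # If the `undef` constructor is missing, define it via the old one,
    # exactly as the hunk does for CuArray.
    if !applicable(Buf{UInt8}, undef, 1)
        (T::Type{<:Buf})(::UndefInitializer, sz...) = T(sz...)
    end

    Buf{UInt8}(undef, 4)  # works whichever constructor the "library" provides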
@@ -1,6 +1,7 @@
 using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
-  cudnnBatchNormMode_t, cudnnHandle_t, handle, cudnnDataType, TensorDesc, FilterDesc
+  cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
+import ..Flux: data
 using LinearAlgebra

 mutable struct DropoutDesc
   ptr::Ptr{Nothing}
@@ -14,7 +15,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   s = Csize_t[0]
   @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
   @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
-  states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
+  states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
   desc = DropoutDesc(d[], states)
   @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
               desc,handle(),ρ,states,length(states),seed)
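Note also that `handle` is no longer imported from CuArrays.CUDNN here; the unqualified `handle()` calls now resolve to the shim defined in the parent CUDA module above. The `CuArray{UInt8}(undef, s[])` change is the Julia 0.7/1.0 uninitialized-array idiom applied throughout this commit. For reference, with plain arrays so it runs anywhere:

    # Julia >= 0.7: uninitialized arrays are requested explicitly with `undef`.
    v = Vector{UInt8}(undef, 16)   # was Vector{UInt8}(16) on Julia 0.6
    fill!(v, 0x00)                 # contents are arbitrary until written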
@@ -1,6 +1,5 @@
 using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
-  cudnnBatchNormMode_t, cudnnHandle_t, handle, cudnnDataType, TensorDesc, FilterDesc
-
+  cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
 using LinearAlgebra

 const RNN_RELU = 0 # Stock RNN with ReLu activation
@@ -80,12 +79,12 @@ function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
   return Int(size[])
 end

-const workspace = [CuVector{UInt8}(1)]
+const workspace = [CuVector{UInt8}(undef, 1)]

 getworkspace(bytes) =
   length(workspace[]) ≥ bytes ?
     workspace[] :
-    (workspace[] = CuVector{UInt8}(bytes))
+    (workspace[] = CuVector{UInt8}(undef, bytes))

 getworkspace(r::RNNDesc, seqlen, xdesc) =
   getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
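The `workspace` global here is a grow-on-demand scratch buffer shared across cuDNN calls: it reallocates only when a request exceeds the current capacity. A CPU sketch of the same pattern (hypothetical, with `Vector` and a `Ref` standing in for the `CuVector` cell):

    # One shared scratch buffer; grows monotonically, never shrinks.
    const workspace = Ref(Vector{UInt8}(undef, 1))

    getworkspace(bytes) =
        length(workspace[]) >= bytes ? workspace[] :
        (workspace[] = Vector{UInt8}(undef, bytes))

    ws = getworkspace(1024)   # first call reallocates
    ws === getworkspace(512)  # true: smaller requests reuse the same buffer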
@@ -147,7 +146,7 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, t
   ydesc = xDesc(y)
   workspace = getworkspace(rnn, seqLength, xdesc)
   reserve = train == Val{true} ?
-    CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
+    CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
     nothing
   co = c == nothing ? c : similar(c)
   cudnnRNNForward(rnn, seqLength,
@@ -232,9 +231,6 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
     dst[I...] = src[reverse(I)...]
     return
   end
   blk, thr = cudims(dst)
   @cuda blocks=blk threads=thr kernel(dst, src)
   return dst
 end

-CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
@@ -26,7 +26,6 @@ end

 children(c::Chain) = c.layers
 mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
-adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)

 (c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)

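`adapt(T, c::Chain)` is dropped here; the Adapt integration now goes through the `adapt_structure` overloads added in the one-hot and Tracker hunks below. The surviving call overload folds the input through the layers in order; a quick usage sketch, assuming a working Flux install:

    using Flux

    m = Chain(Dense(10, 5, relu), Dense(5, 2))
    x = rand(Float32, 10)
    m(x) == m.layers[2](m.layers[1](x))   # true: foldl threads x through each layer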
@@ -114,3 +113,11 @@ end
 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", length(l.α), ")")
 end
+
+# Try to avoid hitting generic matmul in some simple cases
+# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
+(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
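These two methods are the commit's BLAS fast path: if the input eltype already matches the weight eltype `T` (`Float32` or `Float64`), `invoke` jumps straight to the generic `(a::Dense)(x::AbstractArray)` method; any other `Real` input is first converted with `T.(x)` so the matrix multiply hits BLAS rather than Julia's generic matmul. A usage sketch (the Float32 weight default comes from the utils.jl hunk below):

    using Flux

    d = Dense(4, 2)        # glorot_uniform gives Float32 weights after this commit
    x = rand(Float64, 4)   # eltype mismatch
    y = d(x)               # second method fires: d(Float32.(x)), then BLAS gemv
    eltype(y)              # Float32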
@@ -37,7 +37,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;

 @treelike Conv

-function (c::Conv)(x)
+function (c::Conv)(x::AbstractArray)
   # TODO: breaks gpu broadcast :(
   # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
   σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@@ -51,6 +51,12 @@ function Base.show(io::IO, l::Conv)
   print(io, ")")
 end

+(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  invoke(a, Tuple{AbstractArray}, x)
+
+(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
+  a(T.(x))
+
 """
     DepthwiseConv(size, in)
     DepthwiseConv(size, in=>mul)
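Same convert-to-weight-eltype pattern as the `Dense` methods above, specialized on `Conv`'s weight parameter; the generic entry point is also tightened from `(c::Conv)(x)` to `(c::Conv)(x::AbstractArray)`, matching the `Tuple{AbstractArray}` signature that `invoke` targets.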
@@ -28,9 +28,9 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs

 batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)

-import Adapt.adapt
+import Adapt: adapt, adapt_structure

-adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
+adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))

 @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
   import .CuArrays: CuArray, cudaconvert
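The `adapt` overloads move to `adapt_structure`, the extension point introduced by Adapt 0.4 (hence the `Adapt 0.4` lower bound in REQUIRE): packages overload `adapt_structure` for their wrapper types, while callers keep calling `adapt`. An illustrative sketch with a hypothetical `Wrap` type:

    using Adapt

    struct Wrap{T}
        data::T
    end

    # Overload the structure hook; plain `adapt` recurses through it.
    Adapt.adapt_structure(to, w::Wrap) = Wrap(adapt(to, w.data))

    adapt(Array{Float32}, Wrap([1.0, 2.0]))   # Wrap(Float32[1.0, 2.0])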
@@ -5,7 +5,8 @@ using MacroTools: @q, @forward

 import Base: ==

-export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
+export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
+  param, back!

 tracker(x) = nothing

@@ -99,7 +100,8 @@ end

 nobacksies(f, x) = track(nobacksies, f, x)
 nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
-@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
+@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
+@grad nobacksies(f::String, x) = data(x), Δ -> error(f)

 param(x::Number) = TrackedReal(float(x))
 param(xs::AbstractArray) = TrackedArray(float.(xs))
@@ -108,8 +110,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))
 param(x::TrackedReal) = track(identity, x)
 param(x::TrackedArray) = track(identity, x)

-import Adapt.adapt
+import Adapt: adapt, adapt_structure

-adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
+adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))

 end
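Note the `nobacksies` split: the `f::Symbol` method keeps the old "Nested AD not defined for $f" error for operation names, while the new `f::String` method surfaces the given message verbatim, which is what `gradient_` in the next file relies on to print its "use `nest = true`" hint.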
@@ -66,6 +66,15 @@ function back!(x, Δ; once = true)
   return
 end

+function gradient_(f, xs...)
+  xs = param.(xs)
+  l = f(xs...)
+  losscheck(l)
+  back!(l)
+  nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
+             grad.(xs))
+end
+
 # Out-of-place gradients

 struct Params
@@ -162,20 +171,11 @@ function losscheck(x)
   isnan(x) && error("Loss is NaN")
 end

-function gradient(f, args...)
+function gradient_nested(f, args...)
   y, back = forward(f, args...)
   losscheck(y)
   return back(1)
 end

-derivative(f, x) = gradient(f, x)[1]
-
-# Non-nesting versions
-
-function gradient_(f, xs...)
-  xs = param.(xs)
-  l = f(xs...)
-  losscheck(l)
-  back!(l)
-  grad.(xs)
-end
+gradient(f, xs...; nest = false) =
+  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
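Net effect: `gradient_` moves up next to `back!` and now wraps its result in `nobacksies` (with a `String` message, matching the new `@grad` method above); the old nested `gradient` is renamed `gradient_nested`; and `derivative` is gone, replaced by one keyword-switched entry point. A usage sketch:

    using Flux.Tracker   # Tracker now exports `gradient` (see the export change above)

    g = gradient(x -> 3x^2 + 2x + 1, 5)   # default: non-nested, uses back! internally
    g[1]  # == 32, i.e. f'(5); replaces the removed derivative(f, 5)

    gradient(x -> 3x^2 + 2x + 1, 5; nest = true)   # renamed out-of-place path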
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,6 +1,12 @@
 # Arrays
-glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
-glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
+glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
+glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))

+ones(T::Type, dims...) = Base.ones(T, dims...)
+zeros(T::Type, dims...) = Base.zeros(T, dims...)
+
+ones(dims...) = Base.ones(Float32, dims...)
+zeros(dims...) = Base.zeros(Float32, dims...)
+
 unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

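The initializers now produce Float32, and the module gets its own `ones`/`zeros` that default to Float32 while still accepting an explicit eltype. A quick check of the new behavior:

    using Flux

    eltype(Flux.glorot_uniform(3, 4))   # Float32
    Flux.ones(2, 2)                     # 2x2 array of Float32 ones
    Flux.zeros(Float64, 2)              # explicit eltype still available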
@@ -2,7 +2,7 @@ using Flux, Test
 using Flux: maxpool, meanpool

 @testset "Pooling" begin
-  x = randn(10, 10, 3, 2)
+  x = randn(Float32, 10, 10, 3, 2)
   mp = MaxPool((2, 2))
   @test mp(x) == maxpool(x, (2,2))
   mp = MeanPool((2, 2))
@@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
 end

 @testset "CNN" begin
-  r = zeros(28, 28, 1, 5)
+  r = zeros(Float32, 28, 28, 1, 5)
   m = Chain(
     Conv((2, 2), 1=>16, relu),
     MaxPool((2,2)),
@@ -1,6 +1,6 @@
 using Flux
 using Flux.Tracker, Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
+using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
 using NNlib: conv, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
@@ -285,9 +285,9 @@ end
     count += 1
     a * b
   end
-  @test derivative(x -> mul(5, x), 3) == 5
+  @test gradient(x -> mul(5, x), 3)[1] == 5
   @test count == 1
-  @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
+  @test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
   @test count == 3
 end