Fix conflict

This commit is contained in:
Avik Pal 2018-11-14 22:18:57 +05:30
commit dfd680646c
13 changed files with 87 additions and 58 deletions

View File

@ -6,9 +6,9 @@ version = "0.2.0"
[[Adapt]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "a1245c11af6876245c32f82f2067bf67f7da8cee"
git-tree-sha1 = "04d15700419b6949d76be1428ab6e0277ff43b06"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "0.4.0"
version = "0.4.1"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -26,10 +26,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.2"
[[CodecZlib]]
deps = ["BinaryProvider", "Libdl", "Pkg", "Test", "TranscodingStreams"]
git-tree-sha1 = "83cb3d65c37ea1364c2d5bf7bcea41843ba645dc"
deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"]
git-tree-sha1 = "e3df104c84dfc108f0ca203fd7f5bbdc98641ae9"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.5.0"
version = "0.5.1"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random", "Test"]
@ -38,7 +38,7 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.7.5"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Pkg", "Printf", "Reexport", "Test"]
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"]
git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.9.5"
@ -56,7 +56,7 @@ uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "1.3.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "REPL", "Random", "Serialization", "Test"]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "8fc6e166e24fda04b2b648d4260cdad241788c54"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.14.0"
@ -86,16 +86,16 @@ deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FixedPointNumbers]]
deps = ["Pkg", "Test"]
deps = ["Test"]
git-tree-sha1 = "b8045033701c3b10bf2324d7203404be7aef88ba"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.5.3"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Pkg", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "d8f3e0f19d0d546aa92eb1cd67cd3e515768d9f7"
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
git-tree-sha1 = "b91250044374764e7c29af59a774c4b8d6100b6e"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.0"
version = "0.10.1"
[[InteractiveUtils]]
deps = ["LinearAlgebra", "Markdown"]
@ -132,9 +132,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "9f390271c9a43dcbe908a10b5b9632cf58cbab5b"
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.4.1"
version = "0.5.0"
[[Missings]]
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
@ -158,7 +158,7 @@ uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.2"
[[OrderedCollections]]
deps = ["Pkg", "Random", "Serialization", "Test"]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.0.2"
@ -220,12 +220,12 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "c35c9c76008babf4d658060fc64aeb369a41e7bd"
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.1"
version = "0.7.2"
[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Pkg", "Random", "Statistics", "Test"]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "ebc5c2a27d91d5ec611a9861168182e2168effd3"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.9.2"
@ -245,7 +245,7 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TranscodingStreams]]
deps = ["DelimitedFiles", "Pkg", "Random", "Test"]
deps = ["Pkg", "Random", "Test"]
git-tree-sha1 = "a34a2d588e2d2825602bf14a24216d5c8b0921ec"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.8.1"

View File

@ -3,7 +3,7 @@ Juno
MacroTools 0.3.3
NNlib
Requires
Adapt
Adapt 0.4
CodecZlib
Colors
ZipFile

View File

@ -2,9 +2,20 @@ module CUDA
using ..CuArrays
if !applicable(CuArray{UInt8}, undef, 1)
(T::Type{<:CuArray})(::UndefInitializer, sz...) = T(sz...)
end
if CuArrays.libcudnn != nothing
include("curnn.jl")
include("cudnn.jl")
if isdefined(CuArrays, :libcudnn_handle)
handle() = CuArrays.libcudnn_handle[]
else
handle() = CuArrays.CUDNN.handle()
end
include("curnn.jl")
include("cudnn.jl"
else
@warn("CUDNN is not installed, some functionality will not be available.")
end
end

View File

@ -1,6 +1,7 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, handle, cudnnDataType, TensorDesc, FilterDesc
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
import ..Flux: data
using LinearAlgebra
mutable struct DropoutDesc
ptr::Ptr{Nothing}
@ -14,7 +15,7 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
s = Csize_t[0]
@check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Nothing}},), d)
@check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Csize_t}),handle(),s)
states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0?
states = CuArray{UInt8}(undef, s[]) # TODO: can we drop this when ρ=0?
desc = DropoutDesc(d[], states)
@check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},Ptr{Nothing},Cfloat,Ptr{Nothing},Csize_t,Culonglong),
desc,handle(),ρ,states,length(states),seed)

View File

@ -1,6 +1,5 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, handle, cudnnDataType, TensorDesc, FilterDesc
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
@ -80,12 +79,12 @@ function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
return Int(size[])
end
const workspace = [CuVector{UInt8}(1)]
const workspace = [CuVector{UInt8}(undef, 1)]
getworkspace(bytes) =
length(workspace[]) bytes ?
workspace[] :
(workspace[] = CuVector{UInt8}(bytes))
(workspace[] = CuVector{UInt8}(undef, bytes))
getworkspace(r::RNNDesc, seqlen, xdesc) =
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
@ -147,7 +146,7 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, t
ydesc = xDesc(y)
workspace = getworkspace(rnn, seqLength, xdesc)
reserve = train == Val{true} ?
CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
nothing
co = c == nothing ? c : similar(c)
cudnnRNNForward(rnn, seqLength,
@ -232,9 +231,6 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}

View File

@ -26,7 +26,6 @@ end
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
@ -114,3 +113,11 @@ end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Dense{<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))

View File

@ -37,7 +37,7 @@ Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
@treelike Conv
function (c::Conv)(x)
function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
@ -51,6 +51,12 @@ function Base.show(io::IO, l::Conv)
print(io, ")")
end
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)
(a::Conv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))
"""
DepthwiseConv(size, in)
DepthwiseConv(size, in=>mul)

View File

@ -28,9 +28,9 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
import Adapt.adapt
import Adapt: adapt, adapt_structure
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
import .CuArrays: CuArray, cudaconvert

View File

@ -5,7 +5,8 @@ using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
param, back!
tracker(x) = nothing
@ -99,7 +100,8 @@ end
nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
@ -108,8 +110,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import Adapt.adapt
import Adapt: adapt, adapt_structure
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end

View File

@ -66,6 +66,15 @@ function back!(x, Δ; once = true)
return
end
function gradient_(f, xs...)
xs = param.(xs)
l = f(xs...)
losscheck(l)
back!(l)
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
grad.(xs))
end
# Out-of-place gradients
struct Params
@ -162,20 +171,11 @@ function losscheck(x)
isnan(x) && error("Loss is NaN")
end
function gradient(f, args...)
function gradient_nested(f, args...)
y, back = forward(f, args...)
losscheck(y)
return back(1)
end
derivative(f, x) = gradient(f, x)[1]
# Non-nesting versions
function gradient_(f, xs...)
xs = param.(xs)
l = f(xs...)
losscheck(l)
back!(l)
grad.(xs)
end
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)

View File

@ -1,6 +1,12 @@
# Arrays
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0/sum(dims))
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0/sum(dims))
ones(T::Type, dims...) = Base.ones(T, dims...)
zeros(T::Type, dims...) = Base.zeros(T, dims...)
ones(dims...) = Base.ones(Float32, dims...)
zeros(dims...) = Base.zeros(Float32, dims...)
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))

View File

@ -2,7 +2,7 @@ using Flux, Test
using Flux: maxpool, meanpool
@testset "Pooling" begin
x = randn(10, 10, 3, 2)
x = randn(Float32, 10, 10, 3, 2)
mp = MaxPool((2, 2))
@test mp(x) == maxpool(x, (2,2))
mp = MeanPool((2, 2))
@ -10,7 +10,7 @@ using Flux: maxpool, meanpool
end
@testset "CNN" begin
r = zeros(28, 28, 1, 5)
r = zeros(Float32, 28, 28, 1, 5)
m = Chain(
Conv((2, 2), 1=>16, relu),
MaxPool((2,2)),

View File

@ -1,6 +1,6 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
@ -285,9 +285,9 @@ end
count += 1
a * b
end
@test derivative(x -> mul(5, x), 3) == 5
@test gradient(x -> mul(5, x), 3)[1] == 5
@test count == 1
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
@test count == 3
end