Merge branch 'master' into pull-request/07b0f95d
commit 73a0be3e04

76  README.md
@@ -1,11 +1,81 @@
# Флукс
<p align="center">
<img width="200px" src="https://raw.githubusercontent.com/FluxML/fluxml.github.io/master/flux.png"/>
</p>

[](https://travis-ci.org/FluxML/Flux.jl) [](https://fluxml.github.io/Flux.jl/stable/) [](https://slackinvite.julialang.org/)

Flux is a refreshing approach to machine learning. It provides lightweight abstractions on top of Julia's native GPU and AD support, while remaining fully hackable (right down to the [GPU kernels](https://github.com/FluxML/CuArrays.jl)).
Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.

```julia
julia> Pkg.add("Flux")
```

See the [documentation](http://fluxml.github.io/Flux.jl/stable/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.
See the [documentation](http://fluxml.github.io/Flux.jl/) or the [model zoo](https://github.com/FluxML/model-zoo/) for examples.

## Features

Flux has powerful high-level features, and common architectures can be defined in a few lines.

```julia
model = Chain(
  Dense(768, 128),
  LSTM(128, 256),
  LSTM(256, 128),
  Dense(128, 10),
  softmax)

loss(x, y) = crossentropy(model(x), y)

Flux.train!(loss, data, ADAM(...))
```
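
For concreteness, here is one way the `data` and optimiser arguments above might be supplied. This is a minimal sketch: the toy dataset (`xs`, `ys`, `data`) is invented, and `ADAM(params(model))` assumes the optimiser API of this era, where an optimiser is constructed from the model's parameters.

```julia
using Flux
using Flux: crossentropy, onehot

# Hypothetical toy data: 100 pairs of a length-768 input and a one-hot label.
xs = [rand(768) for _ = 1:100]
ys = [onehot(rand(1:10), 1:10) for _ = 1:100]
data = zip(xs, ys)

# Assumed era API: the optimiser wraps the model's parameters directly.
opt = ADAM(params(model))

Flux.train!(loss, data, opt)
```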

Yet you can easily strip away the layers, and directly write the mathematics for your problem. Flux will seamlessly take gradients of any Julia code, so your model looks just like the paper.

```julia
W = param(randn(2, 10))
b = param(randn(2))

y(x) = σ.(W * x .+ b)
```
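
To make the claim about taking gradients concrete, here is a minimal sketch that differentiates the hand-written model above using the `Tracker.back!` and `Tracker.grad` calls that appear elsewhere in this changeset; the input `x` and `target` are invented for illustration.

```julia
using Flux

x, target = rand(10), rand(2)   # hypothetical input and target
l = sum((y(x) .- target).^2)    # any Julia code that produces a scalar loss

Flux.Tracker.back!(l)            # reverse-mode AD through y(x)
Flux.Tracker.grad(W)             # gradient accumulated for the tracked parameter W
```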

If that's *still* not enough, you can go as deep as you want, even writing your own CUDA kernels with [CUDAnative](https://github.com/JuliaGPU/CUDAnative.jl)! All this can be freely mixed-and-matched in a single model or script, and it all runs interactively via Jupyter or Juno.

```julia
function gpu_add(a, b, c)
  i = (blockIdx().x-1) * blockDim().x + threadIdx().x
  c[i] = a[i] + b[i]
  return nothing
end
```
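
As a rough idea of how such a kernel might be launched: this is a sketch only, the array sizes are invented, and the `@cuda threads=...` form assumes a recent CUDAnative release (older versions used `@cuda (blocks, threads) f(args...)` instead).

```julia
using CuArrays, CUDAnative

a = cu(rand(Float32, 1024))
b = cu(rand(Float32, 1024))
c = similar(a)

# One thread per element; larger arrays would also need a block dimension.
@cuda threads=length(a) gpu_add(a, b, c)
```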

Unusual architectures are no problem in Flux, as you can use all the loops, control flow and even macros that you're used to. Here's a Tree RNN in 4 lines.

```julia
tree() = rand() < 0.5 ? rand(10) : (tree(), tree()) # dummy data

shrink = Dense(20, 10)
combine(a, b) = shrink([a; b])

model(x) = x
model(x::Tuple) = combine(model(x[1]), model(x[2]))

model(tree()) # Sample output
```

Despite this flexibility, Julia's advanced compiler lets us do some powerful optimisations. For example, this definition of `sigmoid` automatically gets fused into a *single* GPU kernel – so it's really fast.

```julia
sigmoid(xs) = 1 ./ (1 .+ exp.(.-xs))
```
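
To put the fusion claim in context, a small usage sketch, assuming CuArrays is installed; `xs` is invented data and `cu` moves it to the GPU:

```julia
using CuArrays

xs = cu(rand(Float32, 10_000))
sigmoid(xs)   # the whole 1 ./ (1 .+ exp.(.-xs)) broadcast runs as one fused GPU kernel
```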

Similarly, Flux is the first dynamic framework to support [compiling to the browser](https://fluxml.github.io/experiments/) and model import via [formats like ONNX](https://github.com/FluxML/ONNX.jl/), both of which are thinly-veiled compiler problems.

For more on our philosophy on machine learning, check out our article [On Machine Learning & Programming Languages](https://julialang.org/blog/2017/12/ml&pl).

## Contributing & Help

For general questions and help, check out Julia's [community forum](https://discourse.julialang.org/c/domain/ML).

Flux development is carried out via our [GitHub issues](https://github.com/FluxML/Flux.jl/issues), so feel free to open feature requests or PRs here.

For more informal discussions we'd love to have you on the [Julia slack](https://slackinvite.julialang.org/), where we hang out on the #machine-learning channel.
@@ -55,7 +55,7 @@ Tracked 5-element CuArray{Float32,1}:

The analogue `cpu` is also available for moving models and data back off of the GPU.

```
```julia
julia> x = rand(10) |> gpu
10-element CuArray{Float32,1}:
0.235164
@@ -67,4 +67,4 @@ julia> x |> cpu
0.235164
⋮
0.192538
```
```
@@ -17,14 +17,6 @@
  url = {http://arxiv.org/abs/1712.03112},
}

@online{CuArrays,
  author = {Mike Innes},
  title = {Generic GPU Kernels},
  year = 2017,
  url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
  urldate = {2018-02-16}
}

@online{MLPL,
  author = {Mike Innes and others},
  title = {On Machine Learning and Programming Languages},
@@ -33,10 +25,26 @@
  urldate = {2018-02-16}
}

@online{Fusion,
  author = {Steven G. Johnson},
  title = {More Dots: Syntactic Loop Fusion in Julia},
  year = 2017,
  url = {https://julialang.org/blog/2017/01/moredots},
  urldate = {2018-02-16}
}

@online{CuArrays,
  author = {Mike Innes and others},
  title = {Generic GPU Kernels},
  year = 2017,
  url = {http://mikeinnes.github.io/2017/08/24/cudanative.html},
  urldate = {2018-02-16}
}

@online{Zoo,
  author = {Mike Innes and others},
  title = {Flux Model Zoo},
  year = 2018,
  url = {https://github.com/FluxML/model-zoo/},
  urldate = {2018-02-16}
}

@online{Minibatch,
  author = {James Bradbury},
  title = {Minibatch.jl},
  year = 2018,
  url = {https://github.com/jekbradbury/Minibatch.jl},
  urldate = {2018-02-16}
}
@@ -24,8 +24,8 @@ bibliography: paper.bib

Flux is a library for machine learning (ML), written using the numerical computing language Julia [@Julia]. The package allows models to be written using Julia's simple mathematical syntax, and applies automatic differentiation (AD) to seamlessly calculate derivatives and train the model. Meanwhile, it makes heavy use of Julia's language and compiler features to carry out code analysis and make optimisations. For example, Julia's GPU compilation support [@besard:2017] can be used to JIT-compile custom GPU kernels for model layers [@CuArrays].

The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. As a result of this approach, it already supports several features not available in any other dynamic framework, such as kernel fusion [@Fusion], memory usage optimisations, importing of models via ONNX, and deployment of models to JavaScript for running in the browser.
The machine learning community has traditionally been divided between "static" and "dynamic" frameworks that are easy to optimise and easy to use, respectively [@MLPL]. Flux blurs the line between these two approaches, combining a highly intuitive programming model with the compiler techniques needed by ML. This enables research into advanced compiler transforms such as batching [@Minibatch] without changing any user code.

Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics.
Flux has been used heavily for natural language processing, but can also support state-of-the-art research models in areas like computer vision, reinforcement learning and robotics. Many examples of such models can be found in the model zoo [@Zoo].

# References
@@ -32,12 +32,10 @@ include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalisation.jl")
include("layers/normalise.jl")

include("data/Data.jl")

include("jit/JIT.jl")

@require CuArrays include("cuda/cuda.jl")

end # module
@@ -4,11 +4,4 @@ using CuArrays

CuArrays.cudnn_available() && include("cudnn.jl")

import ..Flux.JIT: Shape, restructure

function restructure(sh::Shape{T}, buf::CuVector{UInt8}) where T
  buf = buf[1:sizeof(sh)]
  reshape(reinterpret(T, buf), size(sh))
end

end
@@ -1,9 +0,0 @@
module JIT

using MacroTools

include("shapes.jl")
include("trace.jl")
include("lib.jl")

end
@@ -1,40 +0,0 @@
# Primitive definitions

shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
  Shape{T}(size(A,1))

shape(::typeof(*), A::MatShape{T}, B::MatShape{T}) where T =
  Shape{T}(size(A,1),size(B,2))

inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
  A_mul_B!(C, A, B)

shape(::typeof(broadcast), f, xs...) =
  Shape{eltype(xs[1])}(Base.Broadcast.broadcast_shape(size.(xs)...)...)

inplace!(::typeof(broadcast), y, f, xs...) = broadcast!(f, y, xs...)

shape(::typeof(reshape), x::Shape{T}, i...) where T =
  Shape{T}(Base._reshape_uncolon(x, i))

inplace!(::typeof(reshape), y, x, i...) = copy!(y, x)

# NNlib

using NNlib
using ..Tracker: _conv, _maxpool

shape(::typeof(softmax), x) = x
inplace!(::typeof(softmax), y, x) = NNlib.softmax!(y, x)

shape(::typeof(_conv), x::Shape{T}, w::Shape{T}, stride, pad) where T =
  Shape{T}(NNlib.cdims(size(x), size(w), pad, stride))

inplace!(::typeof(_conv), y, x, w, stride, pad) =
  NNlib.conv!(y, x, w, stride = stride, pad = pad)

shape(::typeof(_maxpool), x::Shape{T}, k, pad) where T =
  Shape{T}(NNlib.pdims(size(x), k, pad, k))

inplace!(::typeof(_maxpool), y, x, k, pad) =
  NNlib.maxpool!(y, x, k, pad = pad)
@@ -1,55 +0,0 @@
using ..Tracker: TrackedArray

struct Shape{T,N}
  dims::NTuple{N,Int}
end

VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}

Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)

Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
Base.length(s::Shape) = prod(s.dims)
Base.eltype(s::Shape{T}) where T = T

Base.sizeof(s::Shape{T}) where T = sizeof(T)*prod(size(s))

function Base.show(io::IO, s::Shape{T}) where T
  print(io, "Shape{$T}(")
  join(io, s.dims, ", ")
  print(io, ")")
end

shape(x) = x
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
shape(x::TrackedArray) = shape(x.data)

bytes(s::Shape) = sizeof(s)
bytes(x::Tuple) = sum(bytes.(x))

# Recover structure from byte buffers
# Make sure to hold on to the parent buffer for the lifetime of the data.

function restructure(sh::Shape{T}, buf::Vector{UInt8}) where T
  buf = unsafe_wrap(Array, pointer(buf), sizeof(sh))
  reshape(reinterpret(T, buf), size(sh))
end

# Execution with caches

mutable struct Cached{F,A}
  f::F
  buffer::A
end

function (c::Cached)(args...)
  sh = shape(c.f, shape(args)...)
  bytes(sh) > length(c.buffer) && (c.buffer = similar(c.buffer, bytes(sh)))
  y = restructure(sh, c.buffer)
  inplace!(c.f, y, args...)
end
@@ -1,75 +0,0 @@
# This is hacky; we'll eventually reuse Cassette for better tracing.

using ..Tracker, DataFlow
using ..Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
  inputnode, constant

vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)

graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
  vcall(x.f.func, map(x -> graph(x, inputs...; cache = cache), x.f.args)...)

function graph(x, inputs...; cache = ObjectIdDict())
  haskey(cache, x) && return cache[x]
  i = findfirst(y -> x === y, inputs)
  cache[x] =
    i > 0 ? inputnode(i) :
    istracked(x) && !isleaf(x) ? graph(tracker(x), inputs...; cache = cache) :
    constant(x)
end

function trace(f, args...)
  inputs = param.(args)
  graph(f(inputs...), inputs...)
end

# Graph manipulation

function liftparams(v)
  ps = []
  v = prewalk(DataFlow.bumpinputs(v)) do v
    isconstant(v) && istracked(v.value.value) || return v
    push!(ps, v.value.value)
    DataFlow.vcall(getindex, inputnode(1), length(ps))
  end
  return v, ps
end

function cacheall(v, buf = () -> UInt8[])
  prewalk(v) do v
    iscall(v) && isconstant(v[1]) || return v
    f = v[1].value.value
    return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
  end
end

code(v, n) = syntax(vertex(Lambda(n, v)))

struct Compiled{F,T<:Tuple}
  model
  func::F
  params::T
end

# TODO when we support derivatives
# (c::Compiled)(args...) =
#   Tracker.track(Tracker.Call(c, args...),
#                 c.func(Tracker.data.(c.params), args...))

(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)

Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")

function compile(f, args...)
  v = trace(f, args...)
  v, ps = liftparams(cacheall(v, () -> similar(args[1], UInt8, 1))) # no empty arrays on GPU
  Compiled(f, eval(code(v, length(args)+1)), (ps...,))
end

function source(f, args...)
  v = trace(f, args...)
  v, ps = liftparams(v)
  code(v, length(args)+1) |> prettify
end
@@ -96,6 +96,18 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)

function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
  Δ′ = similar(xs.data)
  S = size(xs.data)
  for (i,v) in enumerate(Δ)
    d1 = divrem(i-1, S[1]*m)
    x = d1[2] % S[1]+1
    y = d1[1] % S[2]+1
    Δ′[x, y] += v
  end
  back(xs, Δ′)
end

function back(::typeof(vcat), Δ, xs...)
  i = Base.tail(map(_ -> :, size(Δ)))
  start = 0
@@ -139,6 +151,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))

back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)

Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))

back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)

Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
@@ -242,21 +261,21 @@ function back(::typeof(_conv), Δ, x, w, stride, pad)
  @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad))
end

_maxpool(x, k, pad) = maxpool(x, k; pad = pad)
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride)

maxpool(x::TrackedArray, k; pad = map(_->0,k)) =
  track(_maxpool, x, k, pad)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
  track(_maxpool, x, k, pad, stride)

back_(::typeof(_maxpool), y, Δ, x, k, pad) =
  back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) =
  back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))

_meanpool(x, k, pad) = meanpool(x, k; pad = pad)
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)

meanpool(x::TrackedArray, k; pad = map(_->0,k)) =
  track(_meanpool, x, k, pad)
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
  track(_meanpool, x, k, pad, stride)

back_(::typeof(_meanpool), y, Δ, x, k, pad) =
  back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad))
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
  back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))

# Broadcasting
@@ -29,7 +29,7 @@ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)

function back(x::Tracked, Δ)
  x.isleaf && (accum!(x.grad, Δ); return)
  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
  ref = x.ref -= 1
  if isdefined(x, :grad)
    x.grad = accum!(x.grad, Δ)
12  test/jit.jl
@@ -1,12 +0,0 @@
using Flux, Base.Test
using Flux.JIT: compile

@testset "JIT" begin

m = Dense(10, 5)
f = compile(m, rand(10))
x = rand(10)

@test Tracker.data(m(x)) == f(x)

end
@@ -10,7 +10,6 @@ include("layers/normalisation.jl")
include("layers/stateless.jl")
include("optimise.jl")
include("data.jl")
include("jit.jl")

if Base.find_in_path("CuArrays") ≠ nothing
  include("cuda/cuda.jl")
@@ -16,6 +16,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))

@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@@ -32,6 +34,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))

@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))

@test gradtest(kron, rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron, rand(5,1), rand(3,1))
@@ -99,4 +104,8 @@ end

@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))

b = param(rand())
Tracker.back!(b)
@test Tracker.grad(b) == 1

end #testset