Merge #943
943: Fixes #900 r=MikeInnes a=dhairyagandhi96 Thoughts on the test? cc @MikeInnes Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
This commit is contained in:
commit
fb4a48f970
@ -6,7 +6,7 @@ using Base: tail
|
|||||||
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
using Zygote: Params, @adjoint, gradient, pullback
|
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||||
export gradient
|
export gradient
|
||||||
|
|
||||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||||
|
@ -118,6 +118,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
|
||||||
|
@nograd conv_transpose_dims
|
||||||
|
|
||||||
function (c::ConvTranspose)(x::AbstractArray)
|
function (c::ConvTranspose)(x::AbstractArray)
|
||||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
using Flux, Test
|
using Flux, Test
|
||||||
using Flux: maxpool, meanpool
|
using Flux: maxpool, meanpool
|
||||||
|
using Flux: gradient
|
||||||
|
|
||||||
@testset "Pooling" begin
|
@testset "Pooling" begin
|
||||||
x = randn(Float32, 10, 10, 3, 2)
|
x = randn(Float32, 10, 10, 3, 2)
|
||||||
@ -54,6 +55,10 @@ end
|
|||||||
y = Conv((3,3), 1 => 1)(x)
|
y = Conv((3,3), 1 => 1)(x)
|
||||||
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||||
@test size(x_hat) == size(x)
|
@test size(x_hat) == size(x)
|
||||||
|
|
||||||
|
m = ConvTranspose((3,3), 1=>1)
|
||||||
|
# Test that the gradient call does not throw: #900
|
||||||
|
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "CrossCor" begin
|
@testset "CrossCor" begin
|
||||||
|
Loading…
Reference in New Issue
Block a user