Merge #943
943: Fixes #900 r=MikeInnes a=dhairyagandhi96 Thoughts on the test? cc @MikeInnes Co-authored-by: Dhairya Gandhi <dhairya@juliacopmuting.com>
This commit is contained in:
commit
fb4a48f970
|
@ -6,7 +6,7 @@ using Base: tail
|
|||
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
@reexport using NNlib
|
||||
using Zygote: Params, @adjoint, gradient, pullback
|
||||
using Zygote: Params, @adjoint, gradient, pullback, @nograd
|
||||
export gradient
|
||||
|
||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||
|
|
|
@ -118,6 +118,9 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
|
|||
)
|
||||
end
|
||||
|
||||
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
|
||||
@nograd conv_transpose_dims
|
||||
|
||||
function (c::ConvTranspose)(x::AbstractArray)
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using Flux, Test
|
||||
using Flux: maxpool, meanpool
|
||||
using Flux: gradient
|
||||
|
||||
@testset "Pooling" begin
|
||||
x = randn(Float32, 10, 10, 3, 2)
|
||||
|
@ -54,6 +55,10 @@ end
|
|||
y = Conv((3,3), 1 => 1)(x)
|
||||
x_hat = ConvTranspose((3, 3), 1 => 1)(y)
|
||||
@test size(x_hat) == size(x)
|
||||
|
||||
m = ConvTranspose((3,3), 1=>1)
|
||||
# Test that the gradient call does not throw: #900
|
||||
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
|
||||
end
|
||||
|
||||
@testset "CrossCor" begin
|
||||
|
|
Loading…
Reference in New Issue