no grad dims helper

This commit is contained in:
Dhairya Gandhi 2019-11-24 13:25:02 +05:30
parent 5839e166f6
commit 5f21238d1a
3 changed files with 9 additions and 1 deletions

View File

@ -6,7 +6,7 @@ using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
@reexport using NNlib @reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,

View File

@ -118,6 +118,8 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
) )
end end
@nograd conv_transpose_dims
function (c::ConvTranspose)(x::AbstractArray) function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)

View File

@ -1,5 +1,6 @@
using Flux, Test using Flux, Test
using Flux: maxpool, meanpool using Flux: maxpool, meanpool
using Flux: gradient
@testset "Pooling" begin @testset "Pooling" begin
x = randn(Float32, 10, 10, 3, 2) x = randn(Float32, 10, 10, 3, 2)
@ -54,6 +55,11 @@ end
y = Conv((3,3), 1 => 1)(x) y = Conv((3,3), 1 => 1)(x)
x_hat = ConvTranspose((3, 3), 1 => 1)(y) x_hat = ConvTranspose((3, 3), 1 => 1)(y)
@test size(x_hat) == size(x) @test size(x_hat) == size(x)
m = ConvTranspose((3,3), 2=>1)
x = rand(10,10,2,1)
# Test that the gradient call does not throw: #900
@test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads
end end
@testset "CrossCor" begin @testset "CrossCor" begin