add weight and bias kwargs

This commit is contained in:
Dhairya Gandhi 2019-10-06 04:25:23 +05:30
parent 1fe321781b
commit 55ef7c1aba
3 changed files with 19 additions and 19 deletions

View File

@ -32,19 +32,32 @@ struct Conv{N,M,F,A,V}
dilation::NTuple{N,Int}
end
function Conv(w::AbstractArray{T,N}, b::Union{Nothing, ZeroType, AbstractVector{T}}, σ = identity;
"""
Conv(weight::AbstractArray, bias::AbstractArray)
Conv(weight::AbstractArray, bias::AbstractArray, relu)
Constructs the convolutional layer with user defined weight and bias arrays.
All other behaviours of the Conv layer apply with regard to data order and
forward pass.
Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
function Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
b = b isa Nothing ? ZeroType((size(w, ndims(w)), )) : b
return Conv(σ, w, b, stride, pad, dilation)
end
convweight(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init = glorot_uniform) = init(k..., ch...)
const convbias = zeros
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, use_bias = true) where N
b = use_bias ? zeros(ch[2]) : ZeroType((ch[2],))
Conv(init(k..., ch...), b, σ,
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
weight = convweight(k, ch, init = init), bias = convbias(ch[2])) where N
Conv(weight, bias, σ,
stride = stride, pad = pad, dilation = dilation)
end

View File

@ -139,15 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
end
end
import Base: +, reshape, size
struct ZeroType{T} <: Number
size::T
end
+(a::Number, ::ZeroType) = a
+(::ZeroType, a::Number) = a
size(xs::ZeroType) = xs.size
reshape(::ZeroType, args...) = ZeroType(args)
@adjoint reshape(xs::ZeroType, dims...) = ZeroType(dims), Δ -> (ZeroType(size(xs)), map(_ -> nothing, dims)...)
"""
@jit ...

View File

@ -28,11 +28,7 @@ end
op = bias(ip)
@test sum(op) == prod(size(op))
bias = Conv(ones(Float32, 2, 2, 1, 3), Flux.ZeroType((3,)))
op = bias(ip)
@test sum(op) === 0.f0
bias = Conv(ones(Float32, 2, 2, 1, 3), nothing)
bias = Conv((2,2), 1=>3, bias = zero(3))
op = bias(ip)
@test sum(op) === 0.f0
end