add weight and bias kwargs
This commit is contained in:
parent
1fe321781b
commit
55ef7c1aba
|
@ -32,19 +32,32 @@ struct Conv{N,M,F,A,V}
|
|||
dilation::NTuple{N,Int}
|
||||
end
|
||||
|
||||
"""
    Conv(weight::AbstractArray, bias::AbstractArray)
    Conv(weight::AbstractArray, bias::AbstractArray, relu)

Constructs the convolutional layer with user defined weight and bias arrays.
Passing `nothing` (or a `ZeroType`) as `bias` switches the bias off entirely.

All other behaviours of the Conv layer apply with regard to data order and
forward pass.

Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
function Conv(w::AbstractArray{T,N}, b::Union{Nothing, ZeroType, AbstractVector{T}}, σ = identity;
              stride = 1, pad = 0, dilation = 1) where {T,N}
  # Expand scalar hyper-parameters over the N-2 spatial dimensions
  # (pad gets one entry per side, hence 2*(N-2)).
  stride = expand(Val(N-2), stride)
  pad = expand(Val(2*(N-2)), pad)
  dilation = expand(Val(N-2), dilation)
  # `nothing` means "no bias": substitute the additive-identity sentinel,
  # sized like the output-channel (last) dimension of the weight array.
  b = b isa Nothing ? ZeroType((size(w, ndims(w)),)) : b
  return Conv(σ, w, b, stride, pad, dilation)
end
|
||||
|
||||
# Default initialisers for the keyword-based Conv constructor below.
convweight(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init = glorot_uniform) = init(k..., ch...)

# Alias of `zeros`. NOTE(review): call as `convbias(T, n)` when the weight
# eltype matters — `convbias(n)` alone yields a `Vector{Float64}`, which does
# not satisfy `Conv(w::AbstractArray{T}, b::AbstractVector{T})` for Float32
# weights such as those produced by `glorot_uniform`.
const convbias = zeros
||||
"""
    Conv(size::NTuple{N,Integer}, in=>out)
    Conv(size::NTuple{N,Integer}, in=>out, relu)

Standard convolutional layer. `size` is the kernel size and `in=>out` the
channel mapping.

Takes the keyword arguments `init`, `stride`, `pad` and `dilation`, plus
`weight` and `bias` to supply the parameter arrays directly (pass a
`ZeroType` as `bias` to disable the bias).
"""
function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
              init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
              weight = convweight(k, ch, init = init),
              # Bias eltype must match the weight eltype: a bare
              # `convbias(ch[2])` is Float64 and would miss the
              # `(w::AbstractArray{T}, b::AbstractVector{T})` method for
              # Float32 weights.
              bias = convbias(eltype(weight), ch[2])) where N
  Conv(weight, bias, σ,
       stride = stride, pad = pad, dilation = dilation)
end
|
||||
|
||||
|
|
|
@ -139,15 +139,6 @@ function throttle(f, timeout; leading=true, trailing=false)
|
|||
end
|
||||
end
|
||||
|
||||
import Base: +, reshape, size

# Sentinel standing in for an all-zero bias: an additive identity that
# carries the size it is meant to broadcast as, but holds no data.
struct ZeroType{T} <: Number
  size::T
end
+(a::Number, ::ZeroType) = a
+(::ZeroType, a::Number) = a
# Disambiguate `ZeroType + ZeroType`: without this method the two
# definitions above are equally specific and the call raises a
# MethodError (ambiguity) instead of returning a zero.
+(a::ZeroType, ::ZeroType) = a
size(xs::ZeroType) = xs.size
# Reshaping only re-tags the recorded size; there is no storage to move.
reshape(::ZeroType, args...) = ZeroType(args)
# Gradient of reshape: pass the (zero) cotangent back at the original size;
# the dims arguments get no gradient.
@adjoint reshape(xs::ZeroType, dims...) = ZeroType(dims), Δ -> (ZeroType(size(xs)), map(_ -> nothing, dims)...)
|
||||
|
||||
"""
|
||||
@jit ...
|
||||
|
|
|
@ -28,11 +28,7 @@ end
|
|||
op = bias(ip)
@test sum(op) == prod(size(op))

# Explicit ZeroType bias: output must be exactly zero for zero input.
bias = Conv(ones(Float32, 2, 2, 1, 3), Flux.ZeroType((3,)))
op = bias(ip)
@test sum(op) === 0.f0

# Same via the keyword constructor. NOTE(review): the scraped diff had
# `bias = zero(3)`, which is the Int scalar 0, not a bias vector; a
# ZeroType of the output-channel size is what the assertion requires.
bias = Conv((2,2), 1=>3, bias = Flux.ZeroType((3,)))
op = bias(ip)
@test sum(op) === 0.f0
||||
end
|
||||
|
|
Loading…
Reference in New Issue