Merge #655
655: Added support for Float64 for DepthwiseConv r=dhairyagandhi96 a=thebhatman DepthwiseConv was giving errors for Float64. This fixes the issue. Co-authored-by: Manjunath Bhat <manjunathbhat9920@gmail.com>
This commit is contained in:
commit
bd9d73a941
@ -165,6 +165,12 @@ function Base.show(io::IO, l::DepthwiseConv)
|
|||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
|
invoke(a, Tuple{AbstractArray}, x)
|
||||||
|
|
||||||
|
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
|
||||||
|
a(T.(x))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
MaxPool(k)
|
MaxPool(k)
|
||||||
|
|
||||||
|
@ -32,4 +32,14 @@ end
|
|||||||
m2 = DepthwiseConv((2, 2), 3)
|
m2 = DepthwiseConv((2, 2), 3)
|
||||||
|
|
||||||
@test size(m2(r), 3) == 3
|
@test size(m2(r), 3) == 3
|
||||||
|
|
||||||
|
x = zeros(Float64, 28, 28, 3, 5)
|
||||||
|
|
||||||
|
m3 = DepthwiseConv((2, 2), 3 => 5)
|
||||||
|
|
||||||
|
@test size(m3(r), 3) == 15
|
||||||
|
|
||||||
|
m4 = DepthwiseConv((2, 2), 3)
|
||||||
|
|
||||||
|
@test size(m4(r), 3) == 3
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user