diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 716bc574..599776ce 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,5 +1,7 @@ # TODO: broadcasting cat -combine(x, h) = vcat(x, h .* trues(1, size(x, 2))) +combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2))) +combine(x::AbstractVector, h::AbstractVector) = vcat(x, h) +combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h) # Stateful recurrence