use broadcasting plus
parent a5bd72753e
commit 90edefe072
@@ -18,7 +18,7 @@ node(x::Tuple) = map(node, x)
 node(x::mx.SymbolicNode) = x
 
 graph(::typeof(tuple), args...) = (args...,)
-graph(::typeof(+), args...) = mx.broadcast_plus(args...)
+graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
 graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
 graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
 graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
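In Julia 0.6, `a .+ b` is sugar for `broadcast(+, a, b)`, which is the semantics `mx.broadcast_plus` provides on the MXNet side, so the graph translation now keys on `.+` rather than `+`. A minimal plain-Julia sketch (no MXNet involved) of the difference:

    A = ones(2, 3)       # e.g. a batch of activations
    b = [1.0 2.0 3.0]    # a 1×3 bias row
    A .+ b               # broadcasts b across each row of A
    # A + b would throw a DimensionMismatch, since the shapes differ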
@@ -19,7 +19,7 @@ graph(::typeof(σ), x) = nn.sigmoid(x)
 graph(::typeof(hcat), xs...) = concat(1, xs)
 graph(::typeof(seq), xs, n) = TensorFlow.unpack(xs, num = n, axis = 1)
 
-for op in (tanh, *, .*, +, -)
+for op in (tanh, *, .*, .+, .-)
   @eval graph(::typeof($op), args...) = $op(args...)
 end
 
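The loop above generates one `graph` method per operator via `@eval`; with `.+` and `.-` in the tuple, the TensorFlow backend simply applies the same broadcasting operator to the backend values, assuming the tensor type supports them. A standalone sketch of that metaprogramming pattern, using a hypothetical `describe` function rather than Flux's `graph`:

    for op in (+, -, max)
        @eval describe(::typeof($op), args...) = (string($op), args)
    end
    describe(max, 1, 2)   # => ("max", (1, 2))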
@@ -52,7 +52,7 @@ function infer(::typeof(*), a::Dims{2}, b::Dims{2})
   (a[1], b[2])
 end
 
-infer(::typeof(+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
+infer(::typeof(.+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
 
 # Shapes macro
 
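`Base.Broadcast.broadcast_shape` applies the usual broadcasting rules to the argument shapes, so shape inference for `.+` matches what the backends actually compute. A quick sketch (not from the test suite):

    Base.Broadcast.broadcast_shape((10, 50), (1, 50))   # => (10, 50)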
@@ -3,7 +3,7 @@ export Affine
 @net type Affine
   W
   b
-  x -> x*W + b
+  x -> x*W .+ b
 end
 
 Affine(in::Integer, out::Integer; init = initn) =
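The change matters because `x` may be a batch: with `x` of size (batch, in), `x*W` is (batch, out) while `b` is a single bias row, so plain `+` only works for a batch of one. A plain-Julia sketch (not the @net-generated layer):

    W = randn(10, 5)    # in × out
    b = randn(1, 5)     # 1 × out bias row
    x = randn(3, 10)    # a batch of 3 inputs
    x*W .+ b            # (3, 5); x*W + b would raise a DimensionMismatch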
@@ -4,7 +4,7 @@ export Recurrent, GatedRecurrent, LSTM
   Wxy; Wyy; by
   y
   function (x)
-    y = tanh( x * Wxy + y{-1} * Wyy + by )
+    y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
   end
 end
 
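The same fix is applied to every recurrence in this file (the GatedRecurrent and LSTM hunks below follow the same pattern): the bias terms are presumably single rows broadcast across the batch, so adding them to the (batch, out) matrix products needs `.+`. A minimal plain-Julia sketch of one Recurrent step, with weight names mirroring the layer but none of the @net machinery:

    Wxy = randn(10, 5); Wyy = randn(5, 5); by = randn(1, 5)
    x = randn(1, 10); yprev = zeros(1, 5)
    y = tanh.(x * Wxy .+ yprev * Wyy .+ by)   # elementwise tanh here; @net maps tanh to the backend's elementwise op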
@@ -17,10 +17,10 @@ Recurrent(in, out; init = initn) =
   Wxh; Wyh; bh
   y
   function (x)
-    reset = σ( x * Wxr + y{-1} * Wyr + br )
-    update = σ( x * Wxu + y{-1} * Wyu + bu )
-    y′ = tanh( x * Wxh + (reset .* y{-1}) * Wyh + bh )
-    y = (1 .- update) .* y′ + update .* y{-1}
+    reset = σ( x * Wxr .+ y{-1} * Wyr .+ br )
+    update = σ( x * Wxu .+ y{-1} * Wyu .+ bu )
+    y′ = tanh( x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh )
+    y = (1 .- update) .* y′ .+ update .* y{-1}
   end
 end
 
@@ -36,12 +36,12 @@ GatedRecurrent(in, out; init = initn) =
   y; state
   function (x)
     # Gates
-    forget = σ( x * Wxf + y{-1} * Wyf + bf )
-    input = σ( x * Wxi + y{-1} * Wyi + bi )
-    output = σ( x * Wxo + y{-1} * Wyo + bo )
+    forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
+    input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
+    output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
     # State update and output
-    state′ = tanh( x * Wxc + y{-1} * Wyc + bc )
-    state = forget .* state{-1} + input .* state′
+    state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
+    state = forget .* state{-1} .+ input .* state′
     y = output .* tanh(state)
   end
 end
@@ -28,7 +28,8 @@ d1 = @net x -> x * d.W + d.b
 @test d(xs) == d1(xs)
 
 let
-  @capture(syntax(d), _Frame(_Line(x_[1] * W_ + b_)))
+  # In 0.6 `.+` evaluates to an anon function, so we must match on that.
+  @capture(syntax(d), _Frame(_Line(bplus_(x_[1] * W_, b_))))
   @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
 end
 
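The test now matches the call head with a wildcard (`bplus_`) because, per the comment in the hunk, a bare `.+` on 0.6 evaluates to an anonymous broadcasting function rather than `+` itself. A tiny standalone sketch of that MacroTools pattern, on a made-up expression:

    using MacroTools
    @capture(:(g(x, y)), f_(a_, b_))   # returns true, binding f == :g, a == :x, b == :y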