work more nicely with TF batching
commit 2609d47ce9
parent 8335ab8134
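The TF backend now treats dimension 1 as the batch index: new batch/unbatch helpers add and strip that dimension, the model call and training loop feed batched rows instead of transposed columns, and Dense switches to x*W + b with an in×out weight matrix and a 1×out bias.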
@@ -17,7 +17,7 @@ cvalue(v::Vertex) = cvalue(value(v))
 graph(x::Tensor) = x
 
 # TODO: detect variable reuse
-graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x')
+graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
 
 function graph(model::Model, args...)
   g = Flux.graph(model)
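Note: parameters now go into Variable untransposed. With Dense storing W as in×out directly (see the second file below), the transpose at graph-construction time is no longer needed.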
@@ -30,13 +30,20 @@ function graph(model::Model, args...)
   end |> value
 end
 
-graph(::typeof(*), args...) = *(reverse(args)...)
+graph(::typeof(*), args...) = *(args...)
 graph(::typeof(+), args...) = +(args...)
 graph(::typeof(softmax), x) = nn.softmax(x)
 graph(::typeof(relu), x) = nn.relu(x)
 
 graph(::Input, x) = x
 
+# Treat the first dimension as the batch index
+# TODO: custom data type for this
+batch(x) = reshape(x, (1,size(x)...))
+batch(xs...) = vcat(map(batch, xs)...)
+
+unbatch(xs) = reshape(xs, size(xs)[2:end])
+
 type Model
   session::Session
   inputs::Vector{Tensor}
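The helpers above adopt a batch-first layout: dimension 1 indexes samples. A minimal standalone sketch of their behaviour (definitions copied from the diff; the checks are illustrative):

    batch(x) = reshape(x, (1,size(x)...))       # prepend a batch dimension of 1
    batch(xs...) = vcat(map(batch, xs)...)      # stack several samples row-wise
    unbatch(xs) = reshape(xs, size(xs)[2:end])  # drop the leading batch dimension

    x = rand(4)
    size(batch(x))          # (1, 4): one sample, one row
    batch(x) == x'          # true: for a vector, batching equals the old transpose
    size(batch(x, x, x))    # (3, 4): three samples stacked along dimension 1
    unbatch(batch(x)) == x  # true: round-trips a single sample

For a single vector sample, batch(x) produces the same 1×n row the old x' transpose did, which is why the transposes elsewhere can simply be replaced.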
@@ -56,7 +63,7 @@ end
 
 function (m::Model)(args...)
   @assert length(args) == length(m.inputs)
-  run(m.session, m.graph, Dict(zip(m.inputs, map(transpose, args))))'
+  unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
 end
 
 function Flux.back!(m::Model, Δ, args...)
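With those helpers, a call m(x) feeds each argument as a row via batch and strips the leading dimension from the result with unbatch, replacing the old transpose-in, transpose-out pattern.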
@@ -79,7 +86,8 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
   for e in 1:epoch
     info("Epoch $e\n")
     @progress for (x, y) in train
-      y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y'))
+      y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
+                           Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
       if i % 5000 == 0
         @show y
         @show accuracy(m, test)
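The training loop makes the matching change: the feed dictionary now carries batch(x) and batch(y) rather than x' and y'. For vector-valued samples the fed arrays are the same 1×n rows as before, so behaviour is unchanged while the code follows the batch-first convention throughout.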
|
@@ -5,14 +5,14 @@ export Dense
 @model type Dense
   W
   b
-  x -> W*x + b
+  x -> x*W + b
 end
 
 Dense(in::Integer, out::Integer; init = initn) =
-  Dense(init(out, in), init(out))
+  Dense(init(in, out), init(1, out))
 
 Base.show(io::IO, d::Dense) =
-  print(io, "Dense($(size(d.W.x,2)),$(size(d.W.x,1)))")
+  print(io, "Dense($(size(d.W.x,1)),$(size(d.W.x,2)))")
 
 @model type Sigmoid
   layer::Model
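Under the batch-first convention Dense stores its weight as in×out and its bias as a 1×out row, so x*W + b applies directly to row inputs. A standalone shape check (variable names hypothetical; .+ is used so the bias broadcasts over the batch in plain Julia):

    W = randn(10, 5)   # in×out, as produced by init(in, out)
    b = randn(1, 5)    # 1×out, as produced by init(1, out)
    x = randn(3, 10)   # a batch of 3 samples, one per row
    y = x*W .+ b       # 3×5: one output row per sample

The show method's index swap follows from the same storage change: size(d.W.x,1) is now the input width and size(d.W.x,2) the output width.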
|