work more nicely with TF batching

This commit is contained in:
Mike J Innes 2016-10-04 21:10:50 +01:00
parent 8335ab8134
commit 2609d47ce9
2 changed files with 15 additions and 7 deletions

View File

@ -17,7 +17,7 @@ cvalue(v::Vertex) = cvalue(value(v))
graph(x::Tensor) = x graph(x::Tensor) = x
# TODO: detect variable reuse # TODO: detect variable reuse
graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x') graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
function graph(model::Model, args...) function graph(model::Model, args...)
g = Flux.graph(model) g = Flux.graph(model)
@ -30,13 +30,20 @@ function graph(model::Model, args...)
end |> value end |> value
end end
graph(::typeof(*), args...) = *(reverse(args)...) graph(::typeof(*), args...) = *(args...)
graph(::typeof(+), args...) = +(args...) graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x) graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x) graph(::typeof(relu), x) = nn.relu(x)
graph(::Input, x) = x graph(::Input, x) = x
# Treat the first dimension as the batch index
# TODO: custom data type for this
batch(x) = reshape(x, (1,size(x)...))
batch(xs...) = vcat(map(batch, xs)...)
unbatch(xs) = reshape(xs, size(xs)[2:end])
type Model type Model
session::Session session::Session
inputs::Vector{Tensor} inputs::Vector{Tensor}
@ -56,7 +63,7 @@ end
function (m::Model)(args...) function (m::Model)(args...)
@assert length(args) == length(m.inputs) @assert length(args) == length(m.inputs)
run(m.session, m.graph, Dict(zip(m.inputs, map(transpose, args))))' unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
end end
function Flux.back!(m::Model, Δ, args...) function Flux.back!(m::Model, Δ, args...)
@ -79,7 +86,8 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
for e in 1:epoch for e in 1:epoch
info("Epoch $e\n") info("Epoch $e\n")
@progress for (x, y) in train @progress for (x, y) in train
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y')) y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
if i % 5000 == 0 if i % 5000 == 0
@show y @show y
@show accuracy(m, test) @show accuracy(m, test)

View File

@ -5,14 +5,14 @@ export Dense
@model type Dense @model type Dense
W W
b b
x -> W*x + b x -> x*W + b
end end
Dense(in::Integer, out::Integer; init = initn) = Dense(in::Integer, out::Integer; init = initn) =
Dense(init(out, in), init(out)) Dense(init(in, out), init(1, out))
Base.show(io::IO, d::Dense) = Base.show(io::IO, d::Dense) =
print(io, "Dense($(size(d.W.x,2)),$(size(d.W.x,1)))") print(io, "Dense($(size(d.W.x,1)),$(size(d.W.x,2)))")
@model type Sigmoid @model type Sigmoid
layer::Model layer::Model