better handling for reused params
commit d6eacf3375
parent 740d868ef9
@@ -1,3 +1,4 @@
+import Base: @get!
 import Flow: Constant, postwalk, value, inputs, constant
 import TensorFlow: RawTensor
 
@@ -21,14 +22,14 @@ graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
 
 graph(::Input, x) = x
 
-graph(c::Conv2D, x) =
-  nn.conv2d(x, graph(c.filter), [1,c.stride...,1], "VALID")
-
 graph(p::MaxPool, x) =
   nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
 
 graph(::Flow.Group, xs...) = (xs...,)
 
+graph(params::Associative, c::Conv2D, x) =
+  nn.conv2d(x, graph(params, c.filter), [1,c.stride...,1], "VALID")
+
 type Op
   f
   shape
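The Conv2D method above shows the convention the commit adopts throughout: each layer's graph method gains a leading params::Associative argument and forwards it when lowering its own parameters. A hypothetical fully-connected layer (not part of this diff, shown only to illustrate the calling convention) would follow the same shape:

    # Hypothetical layer, illustrative only: thread `params` through so
    # that nested Flux.Param lookups stay memoized in the same dict.
    graph(params::Associative, a::Affine, x) =
      x * graph(params, a.W) + graph(params, a.b)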
@@ -40,22 +41,29 @@ graph(op::Op, xs...) = op.f(xs...)
 Flux.shape(op::Op, d...) = op.shape(d...)
 
 # TODO: detect variable reuse
-graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
+graph{T<:AArray}(params::Associative, p::Flux.Param{T}) =
+  @get!(params, p, Variable(p.x))
 
-function graph(v::IVertex, args...)
+function graph(params::Associative, v::IVertex, args...)
   # TODO: check number of arguments
   v = spliceinputs(v, map(constant, args)...) |> detuple
   postwalk(v) do v
-    vertex(graph(cvalue(v), cvalue.(inputs(v))...))
+    vertex(graph(params, cvalue(v), cvalue.(inputs(v))...))
   end |> value
 end
 
-function graph(model::Flux.Model, args...)
+function graph(params::Associative, model, args...)
   g = Flux.graph(model)
-  g ≠ nothing || error("No graph for $model")
-  graph(g, args...)
+  g == nothing && return graph(model, args...)
+  graph(params, g, args...)
 end
 
-TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
+function tograph(model, args...)
+  params = Dict{Flux.Param,Tensor}()
+  g = graph(params, model, args...)
+  return params, g
+end
+
+TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)
 
 RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
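The core of the change is the memoization in graph(params, p::Flux.Param): @get! returns the cached entry for p when one exists, and only otherwise evaluates Variable(p.x) and stores it under p. A parameter reused in several places therefore lowers to a single TensorFlow Variable instead of one per use. A minimal sketch of that behaviour, using the definitions from this diff (the weight W and the Flux.Param constructor call are illustrative assumptions):

    # Sketch: the Dict keys on the Flux.Param object itself, so reusing
    # one param yields one TF Variable.
    params = Dict{Flux.Param,Tensor}()
    W = Flux.Param(randn(Float32, 10, 10))  # illustrative weight array
    v1 = @get!(params, W, Variable(W.x))    # first use: creates the Variable
    v2 = @get!(params, W, Variable(W.x))    # reuse: cache hit, nothing new
    @assert v1 === v2 && length(params) == 1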
@@ -1,7 +1,7 @@
 type Model
   model
   session::Session
-  vars::Dict{Flux.Param,Tensor}
+  params::Dict{Flux.Param,Tensor}
   inputs::Vector{Tensor}
   outputs::Vector{Tensor}
   gradients::Vector{Tensor}
@@ -9,11 +9,12 @@ end
 
 function tf(model)
   sess = Session()
-  vars = Dict{Flux.Param,Tensor}()
   input = placeholder(Float32)
-  output = graph(model, input)
+  params, output = tograph(model, input)
   run(sess, initialize_all_variables())
-  Model(model, sess, vars, [input], [output], [gradients(output, input)])
+  Model(model, sess, params,
+        [input], [output],
+        [gradients(output, input)])
 end
 
 batch(x) = Batch((x,))
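Because tf(model) now keeps the param-to-tensor map in Model.params, callers can get from each Flux.Param back to the TensorFlow variable that backs it. One plausible use, sketched under the assumption that run(session, tensor) fetches a variable's current value (names as in this diff):

    # Sketch: copy values out of the session into the Flux params.
    m = tf(model)
    for (p, v) in m.params
      p.x = run(m.session, v)  # read the Variable back into the Param's array
    end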