basic compile step

This commit is contained in:
Mike J Innes 2018-02-27 21:43:41 +00:00
parent bdb8aae107
commit 5a32976cbf
2 changed files with 25 additions and 1 deletions

View File

@ -1,5 +1,8 @@
# Primitive definitions # Primitive definitions
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
Shape{T}(size(A,1))
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) = inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
A_mul_B!(C, A, B) A_mul_B!(C, A, B)

View File

@ -2,7 +2,8 @@
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
using DataFlow using DataFlow
using DataFlow: inputnode, constant using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
@ -23,3 +24,23 @@ function trace(f, args...)
inputs = param.(args) inputs = param.(args)
graph(f(inputs...), inputs...) graph(f(inputs...), inputs...)
end end
# Graph manipulation
function cacheall(v, buf = () -> UInt8[])
prewalk(v) do v
iscall(v) && isconstant(v[1]) || return v
f = v[1].value.value
return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
end
end
function eval_func(v, n)
v = vertex(Lambda(n, v))
v |> syntax |> eval
end
function compile(f, args...)
v = trace(f, args...)
eval_func(cacheall(v), length(args))
end