basic compile step
This commit is contained in:
parent
bdb8aae107
commit
5a32976cbf
@ -1,5 +1,8 @@
|
||||
# Primitive definitions
|
||||
|
||||
shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T =
|
||||
Shape{T}(size(A,1))
|
||||
|
||||
inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) =
|
||||
A_mul_B!(C, A, B)
|
||||
|
||||
|
@ -2,7 +2,8 @@
|
||||
|
||||
using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf
|
||||
using DataFlow
|
||||
using DataFlow: inputnode, constant
|
||||
using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax,
|
||||
inputnode, constant
|
||||
|
||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
@ -23,3 +24,23 @@ function trace(f, args...)
|
||||
inputs = param.(args)
|
||||
graph(f(inputs...), inputs...)
|
||||
end
|
||||
|
||||
# Graph manipulation
|
||||
|
||||
function cacheall(v, buf = () -> UInt8[])
|
||||
prewalk(v) do v
|
||||
iscall(v) && isconstant(v[1]) || return v
|
||||
f = v[1].value.value
|
||||
return vertex(Call(), constant(Cached(f, buf())), v[2:end]...)
|
||||
end
|
||||
end
|
||||
|
||||
function eval_func(v, n)
|
||||
v = vertex(Lambda(n, v))
|
||||
v |> syntax |> eval
|
||||
end
|
||||
|
||||
function compile(f, args...)
|
||||
v = trace(f, args...)
|
||||
eval_func(cacheall(v), length(args))
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user