diff --git a/src/jit/lib.jl b/src/jit/lib.jl index cc89fa00..5a9fe776 100644 --- a/src/jit/lib.jl +++ b/src/jit/lib.jl @@ -1,5 +1,8 @@ # Primitive definitions +shape(::typeof(*), A::MatShape{T}, B::VecShape{T}) where T = + Shape{T}(size(A,1)) + inplace!(::typeof(*), C::AbstractArray, A::AbstractMatrix, B::AbstractArray) = A_mul_B!(C, A, B) diff --git a/src/jit/trace.jl b/src/jit/trace.jl index 5557741e..b33076d3 100644 --- a/src/jit/trace.jl +++ b/src/jit/trace.jl @@ -2,7 +2,8 @@ using ..Flux.Tracker: Tracked, Broadcasted, param, tracker, istracked, isleaf using DataFlow -using DataFlow: inputnode, constant +using DataFlow: Call, Lambda, iscall, isconstant, prewalk, vertex, syntax, + inputnode, constant vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) @@ -23,3 +24,23 @@ function trace(f, args...) inputs = param.(args) graph(f(inputs...), inputs...) end + +# Graph manipulation + +function cacheall(v, buf = () -> UInt8[]) + prewalk(v) do v + iscall(v) && isconstant(v[1]) || return v + f = v[1].value.value + return vertex(Call(), constant(Cached(f, buf())), v[2:end]...) + end +end + +function eval_func(v, n) + v = vertex(Lambda(n, v)) + v |> syntax |> eval +end + +function compile(f, args...) + v = trace(f, args...) + eval_func(cacheall(v), length(args)) +end