remove diff code for now
This commit is contained in:
parent
f3555a9c57
commit
de6c3ef07e
@ -13,7 +13,6 @@ include("model.jl")
|
|||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
|
||||||
include("compiler/diff.jl")
|
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
include("compiler/loops.jl")
|
include("compiler/loops.jl")
|
||||||
include("compiler/interp.jl")
|
include("compiler/interp.jl")
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
addΔ(a, b) = vertex(:+, a, b)
|
|
||||||
|
|
||||||
# Special case a couple of operators to clean up output code
|
|
||||||
const symbolic = Dict()
|
|
||||||
|
|
||||||
symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
|
|
||||||
|
|
||||||
function ∇v(v::Vertex, Δ)
|
|
||||||
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
|
||||||
Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
|
|
||||||
map(i -> @flow(getindex($Δ, $i)), 1:DataFlow.nin(v))
|
|
||||||
end
|
|
||||||
|
|
||||||
function invert(v::IVertex, Δ = :Δ, out = d())
|
|
||||||
@assert !iscyclic(v)
|
|
||||||
if isconstant(v)
|
|
||||||
@assert !haskey(out, value(v))
|
|
||||||
out[value(v).value] = constant(Δ)
|
|
||||||
else
|
|
||||||
Δ′s = ∇v(v, Δ)
|
|
||||||
for (v′, Δ′) in zip(inputs(v), Δ′s)
|
|
||||||
invert(v′, Δ′, out)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return out
|
|
||||||
end
|
|
||||||
|
|
||||||
back!(::typeof(+), Δ, args...) = map(_ -> Δ, args)
|
|
||||||
|
|
||||||
back!(::typeof(*), Δ, a, b) = Δ*b', a'*Δ
|
|
Loading…
Reference in New Issue
Block a user