split out core
This commit is contained in:
parent
9a460e12f2
commit
2717ace397
@ -20,9 +20,12 @@ export @net, unroll, unroll1, @shapes,
|
|||||||
include("Batches/Batches.jl")
|
include("Batches/Batches.jl")
|
||||||
using .Batches
|
using .Batches
|
||||||
|
|
||||||
|
include("core.jl")
|
||||||
|
import .FluxCore: back!, update!, graph
|
||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
|
||||||
include("model.jl")
|
include("params.jl")
|
||||||
|
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
include("compiler/loops.jl")
|
include("compiler/loops.jl")
|
||||||
|
@ -25,3 +25,13 @@ function interpmodel(m, args...)
|
|||||||
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
|
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
|
||||||
@ithrow interp(ctx, m, args...)
|
@ithrow interp(ctx, m, args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Anonymous models
|
||||||
|
|
||||||
|
struct Capacitor
|
||||||
|
graph::IVertex{Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||||
|
|
||||||
|
graph(cap::Capacitor) = cap.graph
|
||||||
|
@ -1,3 +1,8 @@
|
|||||||
|
# This code is in a submodule with the intention that it will be split into an
|
||||||
|
# interface package.
|
||||||
|
|
||||||
|
module FluxCore
|
||||||
|
|
||||||
"""
|
"""
|
||||||
back!(model, ΔY, X...) => ΔX
|
back!(model, ΔY, X...) => ΔX
|
||||||
|
|
||||||
@ -27,57 +32,4 @@ methods as necessary.
|
|||||||
"""
|
"""
|
||||||
graph(m) = nothing
|
graph(m) = nothing
|
||||||
|
|
||||||
# Model parameters
|
|
||||||
|
|
||||||
# TODO: should be AbstractArray?
|
|
||||||
"""
|
|
||||||
A `Param` object stores a parameter array along with its gradient.
|
|
||||||
When converting to backends like TensorFlow, identical `Param`s will
|
|
||||||
result in identical variable objects.
|
|
||||||
"""
|
|
||||||
struct Param{T}
|
|
||||||
x::T
|
|
||||||
Δx::T
|
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
|
||||||
param(x::T) => ::Param{T}
|
|
||||||
|
|
||||||
Convenience method for creating a `Param` object for a given array.
|
|
||||||
"""
|
|
||||||
param(x) = Param(x, zero(x))
|
|
||||||
|
|
||||||
state(p::Param) = p.x
|
|
||||||
|
|
||||||
"""
|
|
||||||
update!(p::Param)
|
|
||||||
|
|
||||||
Apply the accumulated updates to the value of the parameter.
|
|
||||||
"""
|
|
||||||
function update!(p::Param, η)
|
|
||||||
p.x .-= p.Δx .* η
|
|
||||||
p.Δx[:] = 0
|
|
||||||
return p
|
|
||||||
end
|
|
||||||
|
|
||||||
state(x) = x
|
|
||||||
|
|
||||||
Base.size(p::Param) = size(p.x)
|
|
||||||
Base.size(p::Param, n) = size(p.x, n)
|
|
||||||
|
|
||||||
function Base.show(io::IO, p::Param)
|
|
||||||
print(io, "Param", size(p.x))
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.copy!(xs, p::Param) = copy!(xs, p.x)
|
|
||||||
Base.copy!(p::Param, xs) = copy!(p.x, xs)
|
|
||||||
|
|
||||||
# Anonymous models
|
|
||||||
|
|
||||||
struct Capacitor
|
|
||||||
graph::IVertex{Any}
|
|
||||||
end
|
|
||||||
|
|
||||||
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
|
||||||
|
|
||||||
graph(cap::Capacitor) = cap.graph
|
|
41
src/params.jl
Normal file
41
src/params.jl
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
"""
|
||||||
|
A `Param` object stores a parameter array along with its gradient.
|
||||||
|
When converting to backends like TensorFlow, identical `Param`s will
|
||||||
|
result in identical variable objects.
|
||||||
|
"""
|
||||||
|
struct Param{T}
|
||||||
|
x::T
|
||||||
|
Δx::T
|
||||||
|
end
|
||||||
|
|
||||||
|
"""
|
||||||
|
param(x::T) => ::Param{T}
|
||||||
|
|
||||||
|
Convenience method for creating a `Param` object for a given array.
|
||||||
|
"""
|
||||||
|
param(x) = Param(x, zero(x))
|
||||||
|
|
||||||
|
state(p::Param) = p.x
|
||||||
|
|
||||||
|
"""
|
||||||
|
update!(p::Param)
|
||||||
|
|
||||||
|
Apply the accumulated updates to the value of the parameter.
|
||||||
|
"""
|
||||||
|
function update!(p::Param, η)
|
||||||
|
p.x .-= p.Δx .* η
|
||||||
|
p.Δx[:] = 0
|
||||||
|
return p
|
||||||
|
end
|
||||||
|
|
||||||
|
state(x) = x
|
||||||
|
|
||||||
|
Base.size(p::Param) = size(p.x)
|
||||||
|
Base.size(p::Param, n) = size(p.x, n)
|
||||||
|
|
||||||
|
function Base.show(io::IO, p::Param)
|
||||||
|
print(io, "Param", size(p.x))
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.copy!(xs, p::Param) = copy!(xs, p.x)
|
||||||
|
Base.copy!(p::Param, xs) = copy!(p.x, xs)
|
Loading…
Reference in New Issue
Block a user