Commit: "split out core"
Parent: 9a460e12f2
Commit: 2717ace397
@ -20,9 +20,12 @@ export @net, unroll, unroll1, @shapes,
|
||||
include("Batches/Batches.jl")
|
||||
using .Batches
|
||||
|
||||
include("core.jl")
|
||||
import .FluxCore: back!, update!, graph
|
||||
|
||||
include("utils.jl")
|
||||
|
||||
include("model.jl")
|
||||
include("params.jl")
|
||||
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
|
# Run the model `m` on `args` by interpreting its computation graph.
# The interpreter context is a mux of passes (constants, line info,
# lambdas, argument handling, tuples, then the interpreter proper);
# `@ithrow` rethrows interpreter errors with clean backtraces.
function interpmodel(m, args...)
  ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
  @ithrow interp(ctx, m, args...)
end
|
||||
|
||||
# Anonymous models
|
||||
|
||||
# A `Capacitor` wraps a raw computation graph (an `IVertex`) so that it
# can be stored and called like an anonymous model.
struct Capacitor
  graph::IVertex{Any}
end
|
||||
|
||||
# Calling a `Capacitor` interprets its stored graph on the given inputs.
function (m::Capacitor)(xs...)
  return interpmodel(m, xs...)
end
|
||||
|
||||
# A `Capacitor`'s graph is just the one it wraps.
function graph(cap::Capacitor)
  return cap.graph
end
|
||||
|
@ -1,3 +1,8 @@
|
||||
# This code is in a submodule with the intention that it will be split into an
|
||||
# interface package.
|
||||
|
||||
module FluxCore
|
||||
|
||||
"""
|
||||
back!(model, ΔY, X...) => ΔX
|
||||
|
||||
@ -27,57 +32,4 @@ methods as necessary.
|
||||
"""
|
||||
# Generic fallback: a model with no explicit graph representation
# reports `nothing`.
function graph(m)
  return nothing
end
|
||||
|
||||
# Model parameters
|
||||
|
||||
# TODO: should be AbstractArray?
|
||||
"""
|
||||
A `Param` object stores a parameter array along with its gradient.
|
||||
When converting to backends like TensorFlow, identical `Param`s will
|
||||
result in identical variable objects.
|
||||
"""
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δx::T
|
||||
end
|
||||
|
||||
"""
|
||||
param(x::T) => ::Param{T}
|
||||
|
||||
Convenience method for creating a `Param` object for a given array.
|
||||
"""
|
||||
param(x) = Param(x, zero(x))
|
||||
|
||||
# A `Param`'s state is its current value array (gradient excluded).
function state(p::Param)
  return p.x
end
|
||||
|
||||
"""
|
||||
update!(p::Param)
|
||||
|
||||
Apply the accumulated updates to the value of the parameter.
|
||||
"""
|
||||
function update!(p::Param, η)
|
||||
p.x .-= p.Δx .* η
|
||||
p.Δx[:] = 0
|
||||
return p
|
||||
end
|
||||
|
||||
# Identity fallback: anything that is not a `Param` is its own state.
function state(x)
  return x
end
|
||||
|
||||
# Forward size queries to the underlying value array.
function Base.size(p::Param)
  return size(p.x)
end
function Base.size(p::Param, n)
  return size(p.x, n)
end
|
||||
|
||||
# Display as e.g. `Param(2, 3)` — only the value's dimensions, not its
# (potentially large) contents.
Base.show(io::IO, p::Param) = print(io, "Param", size(p.x))
|
||||
|
||||
# Copying into/out of a `Param` touches only the value, never the gradient.
function Base.copy!(xs, p::Param)
  return copy!(xs, p.x)
end
function Base.copy!(p::Param, xs)
  return copy!(p.x, xs)
end
|
||||
|
||||
# Anonymous models
|
||||
|
||||
# A `Capacitor` wraps a raw computation graph (an `IVertex`) so that it
# can be stored and called like an anonymous model.
struct Capacitor
  graph::IVertex{Any}
end
|
||||
|
||||
# Calling a `Capacitor` interprets its stored graph on the given inputs.
function (m::Capacitor)(xs...)
  return interpmodel(m, xs...)
end
|
||||
|
||||
# A `Capacitor`'s graph is just the one it wraps.
function graph(cap::Capacitor)
  return cap.graph
end
|
41
src/params.jl
Normal file
41
src/params.jl
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
A `Param` object stores a parameter array along with its gradient.
|
||||
When converting to backends like TensorFlow, identical `Param`s will
|
||||
result in identical variable objects.
|
||||
"""
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δx::T
|
||||
end
|
||||
|
||||
"""
|
||||
param(x::T) => ::Param{T}
|
||||
|
||||
Convenience method for creating a `Param` object for a given array.
|
||||
"""
|
||||
param(x) = Param(x, zero(x))
|
||||
|
||||
# A `Param`'s state is its current value array (gradient excluded).
function state(p::Param)
  return p.x
end
|
||||
|
||||
"""
|
||||
update!(p::Param)
|
||||
|
||||
Apply the accumulated updates to the value of the parameter.
|
||||
"""
|
||||
function update!(p::Param, η)
|
||||
p.x .-= p.Δx .* η
|
||||
p.Δx[:] = 0
|
||||
return p
|
||||
end
|
||||
|
||||
# Identity fallback: anything that is not a `Param` is its own state.
function state(x)
  return x
end
|
||||
|
||||
# Forward size queries to the underlying value array.
function Base.size(p::Param)
  return size(p.x)
end
function Base.size(p::Param, n)
  return size(p.x, n)
end
|
||||
|
||||
# Display as e.g. `Param(2, 3)` — only the value's dimensions, not its
# (potentially large) contents.
Base.show(io::IO, p::Param) = print(io, "Param", size(p.x))
|
||||
|
||||
# Copying into/out of a `Param` touches only the value, never the gradient.
function Base.copy!(xs, p::Param)
  return copy!(xs, p.x)
end
function Base.copy!(p::Param, xs)
  return copy!(p.x, xs)
end
|
Loading…
Reference in New Issue
Block a user