model abstract is pretty useless
This commit is contained in:
parent
4685d2e672
commit
837173d65b
@ -81,7 +81,7 @@ end
|
||||
|
||||
# TODO: if `last` changes, update params appropriately
|
||||
|
||||
mutable struct Model <: Flux.Model
|
||||
mutable struct Model
|
||||
model::Any
|
||||
execs::Dict{Tuple,Exec}
|
||||
graph::Graph
|
||||
|
@ -1,4 +1,5 @@
|
||||
using Base: @get!
|
||||
using Flux: Reshape, MaxPool, flatten
|
||||
using DataFlow: constant, Split
|
||||
using DataFlow.Interpreter
|
||||
using DataFlow.Interpreter: stack
|
||||
@ -86,8 +87,9 @@ function tograph(model, args...; variables = false)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
end
|
||||
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) =
|
||||
tograph(m, args...; variables = true)[3]
|
||||
# TODO: replace this
|
||||
# TensorFlow.Tensor(m::Flux.Model, args...) =
|
||||
# tograph(m, args...; variables = true)[3]
|
||||
|
||||
function RawTensor(data::Union{Batch,Seq})
    # Convert `data` with `rawbatch`, then forward the result to `RawTensor`.
    raw = rawbatch(data)
    return RawTensor(raw)
end
|
||||
|
||||
|
@ -28,7 +28,7 @@ end
|
||||
function build_type(T, params)
|
||||
@esc T
|
||||
ex = quote
|
||||
type $T <: Model
|
||||
type $T
|
||||
$(params...)
|
||||
end
|
||||
end
|
||||
|
@ -2,7 +2,7 @@ export unroll, unroll1
|
||||
|
||||
# Stateful Models
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
mutable struct Stateful
|
||||
model
|
||||
states::Vector{Any}
|
||||
istate::Vector{Any}
|
||||
|
@ -67,7 +67,7 @@ end
|
||||
|
||||
export Input
|
||||
|
||||
struct Input{N} <: Model
|
||||
struct Input{N}
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
|
@ -29,7 +29,7 @@ end
|
||||
|
||||
# SeqModel wrapper layer for convenience
|
||||
|
||||
struct SeqModel <: Model
|
||||
struct SeqModel
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
@ -1,6 +1,6 @@
|
||||
export Chain, @Chain
|
||||
|
||||
type Chain <: Model
|
||||
type Chain
|
||||
layers::Vector{Any}
|
||||
Chain(xs...) = new([xs...])
|
||||
end
|
||||
|
@ -1,6 +1,6 @@
|
||||
export Conv2D
|
||||
|
||||
struct Conv2D <: Model
|
||||
struct Conv2D
|
||||
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
|
||||
stride::Dims{2}
|
||||
end
|
||||
@ -15,7 +15,7 @@ infer(c::Conv2D, in::Dims{4}) =
|
||||
|
||||
for Pool in :[MaxPool, AvgPool].args
|
||||
@eval begin
|
||||
struct $Pool <: Model
|
||||
struct $Pool
|
||||
size::Dims{2}
|
||||
stride::Dims{2}
|
||||
end
|
||||
|
27
src/model.jl
27
src/model.jl
@ -1,31 +1,14 @@
|
||||
# Basic model API
|
||||
|
||||
"""
|
||||
(m::Model)(X...) => Y
|
||||
|
||||
A "model" is a function with state. For example, a logistic regression is the
|
||||
function
|
||||
|
||||
x -> σ(x * W + b)
|
||||
|
||||
where `W` and `b` are a trainable matrix and vector of weights respectively. The
|
||||
`Model` abstract type is used loosely; in general the concept of a model is
|
||||
closer to a protocol, and models don't need to inherit from this type. Normal
|
||||
Julia functions are models with 0 parameters, for example.
|
||||
"""
|
||||
abstract type Model end
|
||||
|
||||
"""
|
||||
back!(m::Model, ΔY, X...) => ΔX
|
||||
back!(model, ΔY, X...) => ΔX
|
||||
|
||||
Backpropagate the gradient `ΔY` through the model, accumulating the
|
||||
gradients of any parameters. Returns the gradient of the input `X`. Gradients
|
||||
may be arrays or tuples of arrays (for multiple inputs/outputs).
|
||||
"""
|
||||
back!(m::Model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
|
||||
function back!(model, Δ, xs...)
    # Generic fallback: this method defines no backward pass and always raises,
    # naming the concrete type of the model it was called with.
    # Interpolate `model` (the actual parameter); `m` is not bound in this method.
    return error("Backprop not implemented for $(typeof(model))")
end
|
||||
|
||||
"""
|
||||
update!(m::Model, η) => m
|
||||
update!(m, η) => m
|
||||
|
||||
Update the parameters of the model `m` using the accumulated gradients from
|
||||
`back!`, using the learning rate `η`.
|
||||
@ -33,7 +16,7 @@ Update the parameters of the model `m` using the accumulated gradients from
|
||||
function update!(m, η)
    # Fallback method: performs no update and returns `m` as-is; `η` is unused.
    return m
end
|
||||
|
||||
"""
|
||||
graph(m::Model) => ::IVertex{Any} | nothing
|
||||
graph(model) => ::IVertex{Any} | nothing
|
||||
|
||||
Returns the graph representation of the model, if any. Most models are built
|
||||
from lower-level components and can simply implement this method to get most of
|
||||
@ -91,7 +74,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
|
||||
|
||||
# Anonymous models
|
||||
|
||||
struct Capacitor <: Model
|
||||
struct Capacitor
|
||||
graph::IVertex{Any}
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user