model abstract is pretty useless

This commit is contained in:
Mike J Innes 2017-06-05 16:08:23 +01:00
parent 4685d2e672
commit 837173d65b
9 changed files with 17 additions and 32 deletions

View File

@ -81,7 +81,7 @@ end
# TODO: if `last` changes, update params appropriately # TODO: if `last` changes, update params appropriately
mutable struct Model <: Flux.Model mutable struct Model
model::Any model::Any
execs::Dict{Tuple,Exec} execs::Dict{Tuple,Exec}
graph::Graph graph::Graph

View File

@ -1,4 +1,5 @@
using Base: @get! using Base: @get!
using Flux: Reshape, MaxPool, flatten
using DataFlow: constant, Split using DataFlow: constant, Split
using DataFlow.Interpreter using DataFlow.Interpreter
using DataFlow.Interpreter: stack using DataFlow.Interpreter: stack
@ -86,8 +87,9 @@ function tograph(model, args...; variables = false)
return ctx[:params], ctx[:stacks], out return ctx[:params], ctx[:stacks], out
end end
TensorFlow.Tensor(m::Flux.Model, args...) = # TODO: replace this
tograph(m, args...; variables = true)[3] # TensorFlow.Tensor(m::Flux.Model, args...) =
# tograph(m, args...; variables = true)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data)) RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -28,7 +28,7 @@ end
function build_type(T, params) function build_type(T, params)
@esc T @esc T
ex = quote ex = quote
type $T <: Model type $T
$(params...) $(params...)
end end
end end

View File

@ -2,7 +2,7 @@ export unroll, unroll1
# Stateful Models # Stateful Models
mutable struct Stateful <: Model mutable struct Stateful
model model
states::Vector{Any} states::Vector{Any}
istate::Vector{Any} istate::Vector{Any}

View File

@ -67,7 +67,7 @@ end
export Input export Input
struct Input{N} <: Model struct Input{N}
dims::Dims{N} dims::Dims{N}
end end

View File

@ -29,7 +29,7 @@ end
# SeqModel wrapper layer for convenience # SeqModel wrapper layer for convenience
struct SeqModel <: Model struct SeqModel
model model
steps::Int steps::Int
end end

View File

@ -1,6 +1,6 @@
export Chain, @Chain export Chain, @Chain
type Chain <: Model type Chain
layers::Vector{Any} layers::Vector{Any}
Chain(xs...) = new([xs...]) Chain(xs...) = new([xs...])
end end

View File

@ -1,6 +1,6 @@
export Conv2D export Conv2D
struct Conv2D <: Model struct Conv2D
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans] filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
stride::Dims{2} stride::Dims{2}
end end
@ -15,7 +15,7 @@ infer(c::Conv2D, in::Dims{4}) =
for Pool in :[MaxPool, AvgPool].args for Pool in :[MaxPool, AvgPool].args
@eval begin @eval begin
struct $Pool <: Model struct $Pool
size::Dims{2} size::Dims{2}
stride::Dims{2} stride::Dims{2}
end end

View File

@ -1,31 +1,14 @@
# Basic model API
""" """
(m::Model)(X...) => Y back!(model, ΔY, X...) => ΔX
A "model" is a function with state. For example, a logistic regression is the
function
x -> σ(x * W + b)
where `W` and `b` are a trainable matrix and vector of weights repectively. The
`Model` abstract type is used loosely; in general the concept of a model is
closer to a protocol, and models don't need to inherit from this type. Normal
Julia functions are models with 0 parameters, for example.
"""
abstract type Model end
"""
back!(m::Model, ΔY, X...) => ΔX
Backpropagate the gradient `ΔY` through the model `m`, accumulating the Backpropagate the gradient `ΔY` through the model `m`, accumulating the
gradients of any parameters. Returns the gradient of the input `X`. Gradients gradients of any parameters. Returns the gradient of the input `X`. Gradients
may be arrays or tuples of arrays (for multiple inputs/outputs). may be arrays or tuples of arrays (for multiple inputs/outputs).
""" """
back!(m::Model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))") back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
""" """
update!(m::Model, η) => m update!(model, η) => m
Update the parameters of the model `m` using the accumulated gradients from Update the parameters of the model `m` using the accumulated gradients from
`back!`, using the learning rate `η`. `back!`, using the learning rate `η`.
@ -33,7 +16,7 @@ Update the parameters of the model `m` using the accumulated gradients from
update!(m, η) = m update!(m, η) = m
""" """
graph(m::Model) => ::IVertex{Any} | nothing graph(model) => ::IVertex{Any} | nothing
Returns the graph representation of the model, if any. Most models are built Returns the graph representation of the model, if any. Most models are built
from lower-level components and can simply implement this method to get most of from lower-level components and can simply implement this method to get most of
@ -91,7 +74,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
# Anonymous models # Anonymous models
struct Capacitor <: Model struct Capacitor
graph::IVertex{Any} graph::IVertex{Any}
end end