model abstract is pretty useless

This commit is contained in:
Mike J Innes 2017-06-05 16:08:23 +01:00
parent 4685d2e672
commit 837173d65b
9 changed files with 17 additions and 32 deletions

View File

@ -81,7 +81,7 @@ end
# TODO: if `last` changes, update params appropriately
mutable struct Model <: Flux.Model
mutable struct Model
model::Any
execs::Dict{Tuple,Exec}
graph::Graph

View File

@ -1,4 +1,5 @@
using Base: @get!
using Flux: Reshape, MaxPool, flatten
using DataFlow: constant, Split
using DataFlow.Interpreter
using DataFlow.Interpreter: stack
@ -86,8 +87,9 @@ function tograph(model, args...; variables = false)
return ctx[:params], ctx[:stacks], out
end
TensorFlow.Tensor(m::Flux.Model, args...) =
tograph(m, args...; variables = true)[3]
# TODO: replace this
# TensorFlow.Tensor(m::Flux.Model, args...) =
# tograph(m, args...; variables = true)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -28,7 +28,7 @@ end
function build_type(T, params)
@esc T
ex = quote
type $T <: Model
type $T
$(params...)
end
end

View File

@ -2,7 +2,7 @@ export unroll, unroll1
# Stateful Models
mutable struct Stateful <: Model
mutable struct Stateful
model
states::Vector{Any}
istate::Vector{Any}

View File

@ -67,7 +67,7 @@ end
export Input
struct Input{N} <: Model
struct Input{N}
dims::Dims{N}
end

View File

@ -29,7 +29,7 @@ end
# SeqModel wrapper layer for convenience
struct SeqModel <: Model
struct SeqModel
model
steps::Int
end

View File

@ -1,6 +1,6 @@
export Chain, @Chain
type Chain <: Model
type Chain
layers::Vector{Any}
Chain(xs...) = new([xs...])
end

View File

@ -1,6 +1,6 @@
export Conv2D
struct Conv2D <: Model
struct Conv2D
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
stride::Dims{2}
end
@ -15,7 +15,7 @@ infer(c::Conv2D, in::Dims{4}) =
for Pool in :[MaxPool, AvgPool].args
@eval begin
struct $Pool <: Model
struct $Pool
size::Dims{2}
stride::Dims{2}
end

View File

@ -1,31 +1,14 @@
# Basic model API
"""
(m::Model)(X...) => Y
A "model" is a function with state. For example, a logistic regression is the
function
x -> σ(x * W + b)
where `W` and `b` are a trainable matrix and vector of weights repectively. The
`Model` abstract type is used loosely; in general the concept of a model is
closer to a protocol, and models don't need to inherit from this type. Normal
Julia functions are models with 0 parameters, for example.
"""
abstract type Model end
"""
back!(m::Model, ΔY, X...) => ΔX
back!(model, ΔY, X...) => ΔX
Backpropagate the gradient `ΔY` through the model `m`, accumulating the
gradients of any parameters. Returns the gradient of the input `X`. Gradients
may be arrays or tuples of arrays (for multiple inputs/outputs).
"""
back!(m::Model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
"""
update!(m::Model, η) => m
update!(model, η) => m
Update the parameters of the model `m` using the accumulated gradients from
`back!`, using the learning rate `η`.
@ -33,7 +16,7 @@ Update the parameters of the model `m` using the accumulated gradients from
update!(m, η) = m
"""
graph(m::Model) => ::IVertex{Any} | nothing
graph(model) => ::IVertex{Any} | nothing
Returns the graph representation of the model, if any. Most models are built
from lower-level components and can simply implement this method to get most of
@ -91,7 +74,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
# Anonymous models
struct Capacitor <: Model
struct Capacitor
graph::IVertex{Any}
end