diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 975f71ce..9fe7e8d2 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -81,7 +81,7 @@ end
 
 # TODO: if `last` changes, update params appropriately
 
-mutable struct Model <: Flux.Model
+mutable struct Model
   model::Any
   execs::Dict{Tuple,Exec}
   graph::Graph
diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl
index 5482ea64..a3862309 100644
--- a/src/backend/tensorflow/graph.jl
+++ b/src/backend/tensorflow/graph.jl
@@ -1,4 +1,5 @@
 using Base: @get!
+using Flux: Reshape, MaxPool, flatten
 using DataFlow: constant, Split
 using DataFlow.Interpreter
 using DataFlow.Interpreter: stack
@@ -86,8 +87,9 @@ function tograph(model, args...; variables = false)
   return ctx[:params], ctx[:stacks], out
 end
 
-TensorFlow.Tensor(m::Flux.Model, args...) =
-  tograph(m, args...; variables = true)[3]
+# TODO: replace this
+# TensorFlow.Tensor(m::Flux.Model, args...) =
+#   tograph(m, args...; variables = true)[3]
 
 RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
 
diff --git a/src/compiler/code.jl b/src/compiler/code.jl
index f9192a90..c47be6c4 100644
--- a/src/compiler/code.jl
+++ b/src/compiler/code.jl
@@ -28,7 +28,7 @@ end
 function build_type(T, params)
   @esc T
   ex = quote
-    type $T <: Model
+    type $T
       $(params...)
     end
   end
diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl
index 576aeb4f..2e025be8 100644
--- a/src/compiler/loops.jl
+++ b/src/compiler/loops.jl
@@ -2,7 +2,7 @@ export unroll, unroll1
 
 # Stateful Models
 
-mutable struct Stateful <: Model
+mutable struct Stateful
   model
   states::Vector{Any}
   istate::Vector{Any}
diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl
index 7de80e61..414170af 100644
--- a/src/compiler/shape.jl
+++ b/src/compiler/shape.jl
@@ -67,7 +67,7 @@ end
 
 export Input
 
-struct Input{N} <: Model
+struct Input{N}
   dims::Dims{N}
 end
 
diff --git a/src/dims/seq.jl b/src/dims/seq.jl
index 384eef38..206527a3 100644
--- a/src/dims/seq.jl
+++ b/src/dims/seq.jl
@@ -29,7 +29,7 @@ end
 
 # SeqModel wrapper layer for convenience
 
-struct SeqModel <: Model
+struct SeqModel
   model
   steps::Int
 end
diff --git a/src/layers/control.jl b/src/layers/control.jl
index 839bb58c..ec12cc3c 100644
--- a/src/layers/control.jl
+++ b/src/layers/control.jl
@@ -1,6 +1,6 @@
 export Chain, @Chain
 
-type Chain <: Model
+type Chain
   layers::Vector{Any}
   Chain(xs...) = new([xs...])
 end
diff --git a/src/layers/shims.jl b/src/layers/shims.jl
index 877ac6c5..26c20f0c 100644
--- a/src/layers/shims.jl
+++ b/src/layers/shims.jl
@@ -1,6 +1,6 @@
 export Conv2D
 
-struct Conv2D <: Model
+struct Conv2D
   filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
   stride::Dims{2}
 end
@@ -15,7 +15,7 @@ infer(c::Conv2D, in::Dims{4}) =
 
 for Pool in :[MaxPool, AvgPool].args
   @eval begin
-    struct $Pool <: Model
+    struct $Pool
       size::Dims{2}
       stride::Dims{2}
     end
diff --git a/src/model.jl b/src/model.jl
index ffb74689..d87d58be 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,31 +1,14 @@
-# Basic model API
-
 """
-    (m::Model)(X...) => Y
-
-A "model" is a function with state. For example, a logistic regression is the
-function
-
-    x -> σ(x * W + b)
-
-where `W` and `b` are a trainable matrix and vector of weights repectively. The
-`Model` abstract type is used loosely; in general the concept of a model is
-closer to a protocol, and models don't need to inherit from this type. Normal
-Julia functions are models with 0 parameters, for example.
-"""
-abstract type Model end
-
-"""
-    back!(m::Model, ΔY, X...) => ΔX
+    back!(model, ΔY, X...) => ΔX
 
 Backpropagate the gradient `ΔY` through the model `m`, accumulating the
 gradients of any parameters. Returns the gradient of the input `X`. Gradients
 may be arrays or tuples of arrays (for multiple inputs/outputs).
 """
-back!(m::Model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
+back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
 
 """
-    update!(m::Model, η) => m
+    update!(model, η) => m
 
 Update the parameters of the model `m` using the accumulated gradients from
 `back!`, using the learning rate `η`.
@@ -33,7 +16,7 @@ Update the parameters of the model `m` using the accumulated gradients from
 update!(m, η) = m
 
 """
-    graph(m::Model) => ::IVertex{Any} | nothing
+    graph(model) => ::IVertex{Any} | nothing
 
 Returns the graph representation of the model, if any. Most models are built
 from lower-level components and can simply implement this method to get most of
@@ -91,7 +74,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
 
 # Anonymous models
 
-struct Capacitor <: Model
+struct Capacitor
   graph::IVertex{Any}
 end
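
For reference, a minimal sketch of what the model protocol looks like once the
abstract `Model` supertype is gone: any plain struct (or function) can act as a
model by being callable and implementing `back!`/`update!` as needed. The
`Affine` layer below is a hypothetical illustration, not code from this patch.

# Hypothetical example, not part of the patch: with no `Model` supertype,
# a model is just an ordinary struct holding its state.
struct Affine
  W::Matrix{Float64}
  b::Vector{Float64}
end

# Forward pass: a model is a function with state, per the old docstring.
(m::Affine)(x) = x * m.W .+ m.b'

# Optional protocol methods from src/model.jl; implement only what is needed.
# back!(m::Affine, ΔY, x) = ...   # accumulate parameter gradients, return Δx
# update!(m::Affine, η)   = m     # apply accumulated gradients at rate η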