new struct syntax

This commit is contained in:
Mike J Innes 2017-03-14 17:56:03 +00:00
parent 0cf99dbbdf
commit 2adc3cd18e
13 changed files with 22 additions and 22 deletions

View File

@ -1,6 +1,6 @@
using Flux: runrawbatched using Flux: runrawbatched
type AlterParam struct AlterParam
param param
load load
store store
@ -15,7 +15,7 @@ function copyargs!(as, bs)
end end
end end
type Graph struct Graph
output output
params::Dict{Symbol,Any} params::Dict{Symbol,Any}
stacks::Dict{Any,Any} stacks::Dict{Any,Any}
@ -31,7 +31,7 @@ end
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d) ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
type Exec struct Exec
graph::Graph graph::Graph
exec::mx.Executor exec::mx.Executor
args::Dict{Symbol,MXArray} args::Dict{Symbol,MXArray}
@ -84,7 +84,7 @@ end
# TODO: if `last` changes, update params appropriately # TODO: if `last` changes, update params appropriately
type Model <: Flux.Model mutable struct Model <: Flux.Model
model::Any model::Any
graph::Graph graph::Graph
execs::Dict{Tuple,Exec} execs::Dict{Tuple,Exec}
@ -119,7 +119,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
# MX FeedForward interface # MX FeedForward interface
type SoftmaxOutput struct SoftmaxOutput
name::Symbol name::Symbol
end end

View File

@ -5,7 +5,7 @@ using MXNet
reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1) reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)
immutable MXArray{N} struct MXArray{N}
data::mx.NDArray data::mx.NDArray
scratch::Array{Float32,N} scratch::Array{Float32,N}
end end

View File

@ -1,4 +1,4 @@
type Model struct Model
model::Any model::Any
session::Session session::Session
params::Dict{Flux.Param,Tensor} params::Dict{Flux.Param,Tensor}

View File

@ -1,6 +1,6 @@
# TODO: refactor, some of this is more general than just the TF backend # TODO: refactor, some of this is more general than just the TF backend
type SeqModel struct SeqModel
m::Model m::Model
state::Any state::Any
end end

View File

@ -5,7 +5,7 @@ import Flux: accuracy, rebatch, convertel
export tf export tf
type Op struct Op
f f
shape shape
end end

View File

@ -1,6 +1,6 @@
export unroll, unroll1 export unroll, unroll1
type Offset struct Offset
name::Symbol name::Symbol
n::Int n::Int
default::Nullable{Param} default::Nullable{Param}

View File

@ -2,7 +2,7 @@ using DataFlow.Interpreter
export @shapes export @shapes
type Hint struct Hint
typ typ
end end

View File

@ -1,6 +1,6 @@
export Batch, batchone export Batch, batchone
immutable Batch{T,S} <: AbstractVector{T} struct Batch{T,S} <: AbstractVector{T}
data::CatMat{T,S} data::CatMat{T,S}
end end

View File

@ -2,7 +2,7 @@ import Base: eltype, size, getindex, setindex!, convert
export CatMat, rawbatch export CatMat, rawbatch
immutable CatMat{T,S} <: AbstractVector{T} struct CatMat{T,S} <: AbstractVector{T}
data::S data::S
end end

View File

@ -1,6 +1,6 @@
export seq, Seq, BatchSeq export seq, Seq, BatchSeq
immutable Seq{T,S} <: AbstractVector{T} struct Seq{T,S} <: AbstractVector{T}
data::CatMat{T,S} data::CatMat{T,S}
end end

View File

@ -11,12 +11,12 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
# Shim for kicking off shape inference # Shim for kicking off shape inference
type ShapeError <: Exception struct ShapeError <: Exception
layer layer
shape shape
end end
type Input{N} <: Model struct Input{N} <: Model
dims::Dims{N} dims::Dims{N}
end end
@ -27,7 +27,7 @@ back!(::Input, Δ, x) = Δ
# Initialise placeholder # Initialise placeholder
type Init{F} struct Init{F}
f::F f::F
end end

View File

@ -1,6 +1,6 @@
export Conv2D, MaxPool, AvgPool, Reshape export Conv2D, MaxPool, AvgPool, Reshape
type Conv2D <: Model struct Conv2D <: Model
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans] filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
stride::Dims{2} stride::Dims{2}
end end
@ -16,7 +16,7 @@ shape(c::Conv2D, in::Dims{3}) =
for Pool in :[MaxPool, AvgPool].args for Pool in :[MaxPool, AvgPool].args
@eval begin @eval begin
type $Pool <: Model struct $Pool <: Model
size::Dims{2} size::Dims{2}
stride::Dims{2} stride::Dims{2}
end end
@ -34,7 +34,7 @@ for Pool in :[MaxPool, AvgPool].args
end end
end end
immutable Reshape{N} struct Reshape{N}
dims::Dims{N} dims::Dims{N}
end end

View File

@ -54,7 +54,7 @@ A `Param` object stores a parameter array along with an accumulated delta to
that array. When converting to backends like TensorFlow, identical `Param`s will that array. When converting to backends like TensorFlow, identical `Param`s will
result in identical variable objects, making model reuse trivial. result in identical variable objects, making model reuse trivial.
""" """
type Param{T} struct Param{T}
x::T x::T
Δx::T Δx::T
end end
@ -107,7 +107,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
export Capacitor export Capacitor
type Capacitor <: Model struct Capacitor <: Model
graph::IVertex{Any} graph::IVertex{Any}
end end