new struct syntax
This commit is contained in:
parent
0cf99dbbdf
commit
2adc3cd18e
@ -1,6 +1,6 @@
|
||||
using Flux: runrawbatched
|
||||
|
||||
type AlterParam
|
||||
struct AlterParam
|
||||
param
|
||||
load
|
||||
store
|
||||
@ -15,7 +15,7 @@ function copyargs!(as, bs)
|
||||
end
|
||||
end
|
||||
|
||||
type Graph
|
||||
struct Graph
|
||||
output
|
||||
params::Dict{Symbol,Any}
|
||||
stacks::Dict{Any,Any}
|
||||
@ -31,7 +31,7 @@ end
|
||||
|
||||
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
|
||||
|
||||
type Exec
|
||||
struct Exec
|
||||
graph::Graph
|
||||
exec::mx.Executor
|
||||
args::Dict{Symbol,MXArray}
|
||||
@ -84,7 +84,7 @@ end
|
||||
|
||||
# TODO: if `last` changes, update params appropriately
|
||||
|
||||
type Model <: Flux.Model
|
||||
mutable struct Model <: Flux.Model
|
||||
model::Any
|
||||
graph::Graph
|
||||
execs::Dict{Tuple,Exec}
|
||||
@ -119,7 +119,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||
|
||||
# MX FeedForward interface
|
||||
|
||||
type SoftmaxOutput
|
||||
struct SoftmaxOutput
|
||||
name::Symbol
|
||||
end
|
||||
|
||||
|
@ -5,7 +5,7 @@ using MXNet
|
||||
|
||||
reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)
|
||||
|
||||
immutable MXArray{N}
|
||||
struct MXArray{N}
|
||||
data::mx.NDArray
|
||||
scratch::Array{Float32,N}
|
||||
end
|
||||
|
@ -1,4 +1,4 @@
|
||||
type Model
|
||||
struct Model
|
||||
model::Any
|
||||
session::Session
|
||||
params::Dict{Flux.Param,Tensor}
|
||||
|
@ -1,6 +1,6 @@
|
||||
# TODO: refactor, some of this is more general than just the TF backend
|
||||
|
||||
type SeqModel
|
||||
struct SeqModel
|
||||
m::Model
|
||||
state::Any
|
||||
end
|
||||
|
@ -5,7 +5,7 @@ import Flux: accuracy, rebatch, convertel
|
||||
|
||||
export tf
|
||||
|
||||
type Op
|
||||
struct Op
|
||||
f
|
||||
shape
|
||||
end
|
||||
|
@ -1,6 +1,6 @@
|
||||
export unroll, unroll1
|
||||
|
||||
type Offset
|
||||
struct Offset
|
||||
name::Symbol
|
||||
n::Int
|
||||
default::Nullable{Param}
|
||||
|
@ -2,7 +2,7 @@ using DataFlow.Interpreter
|
||||
|
||||
export @shapes
|
||||
|
||||
type Hint
|
||||
struct Hint
|
||||
typ
|
||||
end
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
export Batch, batchone
|
||||
|
||||
immutable Batch{T,S} <: AbstractVector{T}
|
||||
struct Batch{T,S} <: AbstractVector{T}
|
||||
data::CatMat{T,S}
|
||||
end
|
||||
|
||||
|
@ -2,7 +2,7 @@ import Base: eltype, size, getindex, setindex!, convert
|
||||
|
||||
export CatMat, rawbatch
|
||||
|
||||
immutable CatMat{T,S} <: AbstractVector{T}
|
||||
struct CatMat{T,S} <: AbstractVector{T}
|
||||
data::S
|
||||
end
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
export seq, Seq, BatchSeq
|
||||
|
||||
immutable Seq{T,S} <: AbstractVector{T}
|
||||
struct Seq{T,S} <: AbstractVector{T}
|
||||
data::CatMat{T,S}
|
||||
end
|
||||
|
||||
|
@ -11,12 +11,12 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
|
||||
|
||||
# Shim for kicking off shape inference
|
||||
|
||||
type ShapeError <: Exception
|
||||
struct ShapeError <: Exception
|
||||
layer
|
||||
shape
|
||||
end
|
||||
|
||||
type Input{N} <: Model
|
||||
struct Input{N} <: Model
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
@ -27,7 +27,7 @@ back!(::Input, Δ, x) = Δ
|
||||
|
||||
# Initialise placeholder
|
||||
|
||||
type Init{F}
|
||||
struct Init{F}
|
||||
f::F
|
||||
end
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
export Conv2D, MaxPool, AvgPool, Reshape
|
||||
|
||||
type Conv2D <: Model
|
||||
struct Conv2D <: Model
|
||||
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
|
||||
stride::Dims{2}
|
||||
end
|
||||
@ -16,7 +16,7 @@ shape(c::Conv2D, in::Dims{3}) =
|
||||
|
||||
for Pool in :[MaxPool, AvgPool].args
|
||||
@eval begin
|
||||
type $Pool <: Model
|
||||
struct $Pool <: Model
|
||||
size::Dims{2}
|
||||
stride::Dims{2}
|
||||
end
|
||||
@ -34,7 +34,7 @@ for Pool in :[MaxPool, AvgPool].args
|
||||
end
|
||||
end
|
||||
|
||||
immutable Reshape{N}
|
||||
struct Reshape{N}
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
|
@ -54,7 +54,7 @@ A `Param` object stores a parameter array along with an accumulated delta to
|
||||
that array. When converting to backends like TensorFlow, identical `Param`s will
|
||||
result in identical variable objects, making model reuse trivial.
|
||||
"""
|
||||
type Param{T}
|
||||
struct Param{T}
|
||||
x::T
|
||||
Δx::T
|
||||
end
|
||||
@ -107,7 +107,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
|
||||
|
||||
export Capacitor
|
||||
|
||||
type Capacitor <: Model
|
||||
struct Capacitor <: Model
|
||||
graph::IVertex{Any}
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user