new struct syntax
This commit is contained in:
parent
0cf99dbbdf
commit
2adc3cd18e
@ -1,6 +1,6 @@
|
|||||||
using Flux: runrawbatched
|
using Flux: runrawbatched
|
||||||
|
|
||||||
type AlterParam
|
struct AlterParam
|
||||||
param
|
param
|
||||||
load
|
load
|
||||||
store
|
store
|
||||||
@ -15,7 +15,7 @@ function copyargs!(as, bs)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
type Graph
|
struct Graph
|
||||||
output
|
output
|
||||||
params::Dict{Symbol,Any}
|
params::Dict{Symbol,Any}
|
||||||
stacks::Dict{Any,Any}
|
stacks::Dict{Any,Any}
|
||||||
@ -31,7 +31,7 @@ end
|
|||||||
|
|
||||||
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
|
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
|
||||||
|
|
||||||
type Exec
|
struct Exec
|
||||||
graph::Graph
|
graph::Graph
|
||||||
exec::mx.Executor
|
exec::mx.Executor
|
||||||
args::Dict{Symbol,MXArray}
|
args::Dict{Symbol,MXArray}
|
||||||
@ -84,7 +84,7 @@ end
|
|||||||
|
|
||||||
# TODO: if `last` changes, update params appropriately
|
# TODO: if `last` changes, update params appropriately
|
||||||
|
|
||||||
type Model <: Flux.Model
|
mutable struct Model <: Flux.Model
|
||||||
model::Any
|
model::Any
|
||||||
graph::Graph
|
graph::Graph
|
||||||
execs::Dict{Tuple,Exec}
|
execs::Dict{Tuple,Exec}
|
||||||
@ -119,7 +119,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
|||||||
|
|
||||||
# MX FeedForward interface
|
# MX FeedForward interface
|
||||||
|
|
||||||
type SoftmaxOutput
|
struct SoftmaxOutput
|
||||||
name::Symbol
|
name::Symbol
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ using MXNet
|
|||||||
|
|
||||||
reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)
|
reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)
|
||||||
|
|
||||||
immutable MXArray{N}
|
struct MXArray{N}
|
||||||
data::mx.NDArray
|
data::mx.NDArray
|
||||||
scratch::Array{Float32,N}
|
scratch::Array{Float32,N}
|
||||||
end
|
end
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
type Model
|
struct Model
|
||||||
model::Any
|
model::Any
|
||||||
session::Session
|
session::Session
|
||||||
params::Dict{Flux.Param,Tensor}
|
params::Dict{Flux.Param,Tensor}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# TODO: refactor, some of this is more general than just the TF backend
|
# TODO: refactor, some of this is more general than just the TF backend
|
||||||
|
|
||||||
type SeqModel
|
struct SeqModel
|
||||||
m::Model
|
m::Model
|
||||||
state::Any
|
state::Any
|
||||||
end
|
end
|
||||||
|
@ -5,7 +5,7 @@ import Flux: accuracy, rebatch, convertel
|
|||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
||||||
type Op
|
struct Op
|
||||||
f
|
f
|
||||||
shape
|
shape
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
export unroll, unroll1
|
export unroll, unroll1
|
||||||
|
|
||||||
type Offset
|
struct Offset
|
||||||
name::Symbol
|
name::Symbol
|
||||||
n::Int
|
n::Int
|
||||||
default::Nullable{Param}
|
default::Nullable{Param}
|
||||||
|
@ -2,7 +2,7 @@ using DataFlow.Interpreter
|
|||||||
|
|
||||||
export @shapes
|
export @shapes
|
||||||
|
|
||||||
type Hint
|
struct Hint
|
||||||
typ
|
typ
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
export Batch, batchone
|
export Batch, batchone
|
||||||
|
|
||||||
immutable Batch{T,S} <: AbstractVector{T}
|
struct Batch{T,S} <: AbstractVector{T}
|
||||||
data::CatMat{T,S}
|
data::CatMat{T,S}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import Base: eltype, size, getindex, setindex!, convert
|
|||||||
|
|
||||||
export CatMat, rawbatch
|
export CatMat, rawbatch
|
||||||
|
|
||||||
immutable CatMat{T,S} <: AbstractVector{T}
|
struct CatMat{T,S} <: AbstractVector{T}
|
||||||
data::S
|
data::S
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
export seq, Seq, BatchSeq
|
export seq, Seq, BatchSeq
|
||||||
|
|
||||||
immutable Seq{T,S} <: AbstractVector{T}
|
struct Seq{T,S} <: AbstractVector{T}
|
||||||
data::CatMat{T,S}
|
data::CatMat{T,S}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -11,12 +11,12 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
|
|||||||
|
|
||||||
# Shim for kicking off shape inference
|
# Shim for kicking off shape inference
|
||||||
|
|
||||||
type ShapeError <: Exception
|
struct ShapeError <: Exception
|
||||||
layer
|
layer
|
||||||
shape
|
shape
|
||||||
end
|
end
|
||||||
|
|
||||||
type Input{N} <: Model
|
struct Input{N} <: Model
|
||||||
dims::Dims{N}
|
dims::Dims{N}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ back!(::Input, Δ, x) = Δ
|
|||||||
|
|
||||||
# Initialise placeholder
|
# Initialise placeholder
|
||||||
|
|
||||||
type Init{F}
|
struct Init{F}
|
||||||
f::F
|
f::F
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
export Conv2D, MaxPool, AvgPool, Reshape
|
export Conv2D, MaxPool, AvgPool, Reshape
|
||||||
|
|
||||||
type Conv2D <: Model
|
struct Conv2D <: Model
|
||||||
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
|
filter::Param{Array{Float64,4}} # [height, width, inchans, outchans]
|
||||||
stride::Dims{2}
|
stride::Dims{2}
|
||||||
end
|
end
|
||||||
@ -16,7 +16,7 @@ shape(c::Conv2D, in::Dims{3}) =
|
|||||||
|
|
||||||
for Pool in :[MaxPool, AvgPool].args
|
for Pool in :[MaxPool, AvgPool].args
|
||||||
@eval begin
|
@eval begin
|
||||||
type $Pool <: Model
|
struct $Pool <: Model
|
||||||
size::Dims{2}
|
size::Dims{2}
|
||||||
stride::Dims{2}
|
stride::Dims{2}
|
||||||
end
|
end
|
||||||
@ -34,7 +34,7 @@ for Pool in :[MaxPool, AvgPool].args
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
immutable Reshape{N}
|
struct Reshape{N}
|
||||||
dims::Dims{N}
|
dims::Dims{N}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ A `Param` object stores a parameter array along with an accumulated delta to
|
|||||||
that array. When converting to backends like TensorFlow, identical `Param`s will
|
that array. When converting to backends like TensorFlow, identical `Param`s will
|
||||||
result in identical variable objects, making model reuse trivial.
|
result in identical variable objects, making model reuse trivial.
|
||||||
"""
|
"""
|
||||||
type Param{T}
|
struct Param{T}
|
||||||
x::T
|
x::T
|
||||||
Δx::T
|
Δx::T
|
||||||
end
|
end
|
||||||
@ -107,7 +107,7 @@ Base.copy!(p::Param, xs) = copy!(p.x, xs)
|
|||||||
|
|
||||||
export Capacitor
|
export Capacitor
|
||||||
|
|
||||||
type Capacitor <: Model
|
struct Capacitor <: Model
|
||||||
graph::IVertex{Any}
|
graph::IVertex{Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user