remove old torch-esque code
This commit is contained in:
parent
676a10e78b
commit
9986a1c163
|
@ -7,7 +7,6 @@ using MacroTools, Lazy, Flow
|
|||
export Model, back!, update!
|
||||
|
||||
abstract Model
|
||||
abstract Activation <: Model
|
||||
|
||||
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
||||
update!(m::Model, η) = m
|
||||
|
|
|
@ -1,28 +1 @@
|
|||
export Sigmoid
|
||||
|
||||
σ(x) = 1/(1+exp(-x))
|
||||
σ′(x) = σ(x)*(1-σ(x))
|
||||
|
||||
∇₁(::typeof(σ)) = σ′
|
||||
|
||||
type Sigmoid <: Activation
|
||||
in::Vector{Float32}
|
||||
out::Vector{Float32}
|
||||
∇in::Vector{Float32}
|
||||
end
|
||||
|
||||
Sigmoid(size::Integer) = Sigmoid(zeros(size), zeros(size), zeros(size))
|
||||
|
||||
function (l::Sigmoid)(x)
|
||||
l.in = x
|
||||
map!(σ, l.out, x)
|
||||
end
|
||||
|
||||
function back!(l::Sigmoid, ∇)
|
||||
map!(σ′, l.∇in, l.in)
|
||||
map!(*, l.∇in, l.∇in, ∇)
|
||||
end
|
||||
|
||||
shape(l::Sigmoid) = length(l.in)
|
||||
|
||||
Sigmoid() = Init(in -> Sigmoid(in[1]))
|
||||
abstract Activation <: Model
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
export mse, mse!
|
||||
|
||||
function mse!(∇, pred, target)
|
||||
map!(-, ∇, pred, target)
|
||||
sumabs2(∇)/2
|
||||
end
|
||||
|
||||
mse(pred, target) = mse(similar(pred), pred, target)
|
|
@ -1,41 +0,0 @@
|
|||
export Dense
|
||||
|
||||
type Dense <: Model
|
||||
W::Matrix{Float32}
|
||||
b::Vector{Float32}
|
||||
∇W::Matrix{Float32}
|
||||
∇b::Vector{Float32}
|
||||
|
||||
in::Vector{Float32}
|
||||
out::Vector{Float32}
|
||||
∇in::Vector{Float32}
|
||||
end
|
||||
|
||||
Dense(in::Integer, out::Integer) =
|
||||
Dense(randn(out, in), randn(out),
|
||||
zeros(out, in), zeros(out),
|
||||
zeros(in), zeros(out), zeros(in))
|
||||
|
||||
Dense(out::Integer) = Init(in -> Dense(in[1], out))
|
||||
|
||||
function (l::Dense)(x)
|
||||
l.in = x
|
||||
A_mul_B!(l.out, l.W, x)
|
||||
map!(+, l.out, l.out, l.b)
|
||||
end
|
||||
|
||||
function back!(l::Dense, ∇)
|
||||
map!(+, l.∇b, l.∇b, ∇)
|
||||
# l.∇W += ∇ * l.in'
|
||||
BLAS.gemm!('N', 'T', eltype(∇)(1), ∇, l.in, eltype(∇)(1), l.∇W)
|
||||
At_mul_B!(l.∇in, l.W, ∇)
|
||||
end
|
||||
|
||||
function update!(l::Dense, η)
|
||||
map!((x, ∇x) -> x - η*∇x, l.W, l.W, l.∇W)
|
||||
map!((x, ∇x) -> x - η*∇x, l.b, l.b, l.∇b)
|
||||
fill!(l.∇W, 0)
|
||||
fill!(l.∇b, 0)
|
||||
end
|
||||
|
||||
shape(d::Dense) = size(d.b)
|
|
@ -10,7 +10,7 @@ type Input{N} <: Model
|
|||
dims::Dims{N}
|
||||
end
|
||||
|
||||
Input(i) = Input(dims(i))
|
||||
Input(i...) = Input(dims(i...))
|
||||
|
||||
(::Input)(x) = x
|
||||
back!(::Input, ∇) = ∇
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
export Sequence
|
||||
|
||||
type Sequence
|
||||
type Sequence <: Model
|
||||
layers::Vector{Model}
|
||||
end
|
||||
|
||||
|
|
23
src/utils.jl
23
src/utils.jl
|
@ -2,26 +2,3 @@ export onehot, onecold
|
|||
|
||||
onehot(label, labels) = [i == label for i in labels]
|
||||
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
|
||||
|
||||
function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||
i = 0
|
||||
∇ = zeros(length(train[1][2]))
|
||||
for _ in 1:epoch
|
||||
for (x, y) in shuffle!(train)
|
||||
i += 1
|
||||
err = mse!(∇, m(x), y)
|
||||
back!(m, ∇)
|
||||
i % batch == 0 && update!(m, η/batch)
|
||||
end
|
||||
@show accuracy(m, test)
|
||||
end
|
||||
return m
|
||||
end
|
||||
|
||||
function accuracy(m::Model, data)
|
||||
correct = 0
|
||||
for (x, y) in data
|
||||
onecold(m(x)) == onecold(y) && (correct += 1)
|
||||
end
|
||||
return correct/length(data)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue