diff --git a/examples/MNIST.jl b/examples/MNIST.jl
index 9ab749a3..fa3576af 100644
--- a/examples/MNIST.jl
+++ b/examples/MNIST.jl
@@ -14,7 +14,7 @@ m = @Chain(
 model = mxnet(m)
 
 # An example prediction pre-training
-model(data[1][1])
+model(unsqueeze(data[1][1]))
 
 Flux.train!(model, train, test, η = 1e-4)
 
diff --git a/src/cost.jl b/src/cost.jl
index b99a9c40..b08675d5 100644
--- a/src/cost.jl
+++ b/src/cost.jl
@@ -1,8 +1,5 @@
 export mse, mse!
 
-function mse!(Δ, pred, target)
-  map!(-, Δ, pred, target)
-  sumabs2(Δ)/2
-end
+mse(ŷ, y) = sumabs2(ŷ .- y)/2
 
-mse(pred, target) = mse!(similar(pred), pred, target)
+back!(::typeof(mse), Δ, ŷ, y) = Δ*(ŷ .- y)
diff --git a/src/data.jl b/src/data.jl
index abfe2f47..63110cc9 100644
--- a/src/data.jl
+++ b/src/data.jl
@@ -1,5 +1,8 @@
 export onehot, onecold, chunk, partition, batches, sequences
 
+mapt(f, x) = f(x)
+mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
+
 convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
 convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
 
@@ -21,7 +24,11 @@ onehot(label, labels) = onehot(Int, label, labels)
 The inverse of `onehot`; takes an output prediction vector and a list of
 possible values, and produces the appropriate value.
 """
-onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
+onecold(y::AbstractVector, labels = 1:length(y)) =
+  labels[findfirst(y, maximum(y))]
+
+onecold(y::AbstractMatrix, l...) =
+  squeeze(mapslices(y -> onecold(y, l...), y, 2), 2)
 
 using Iterators
 import Iterators: Partition, partition
diff --git a/src/dims/utils.jl b/src/dims/utils.jl
index 7cc00476..4d87a35a 100644
--- a/src/dims/utils.jl
+++ b/src/dims/utils.jl
@@ -1,3 +1,5 @@
+export unsqueeze
+
 unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
 
 Base.squeeze(xs) = squeeze(xs, 1)
diff --git a/src/utils.jl b/src/utils.jl
index bae5736c..e9d0eddc 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -2,22 +2,22 @@ export AArray
 
 const AArray = AbstractArray
 
-mapt(f, x) = f(x)
-mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
-
 initn(dims...) = randn(dims...)/100
 
-function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
+tobatch(xs::Batch) = rawbatch(xs)
+tobatch(xs) = unsqueeze(xs)
+
+function train!(m, train, test = []; epoch = 1, η = 0.1)
   i = 0
-  Δ = zeros(length(train[1][2]))
   for _ in 1:epoch
     @progress for (x, y) in train
+      x, y = tobatch.((x, y))
       i += 1
-      pred = m(x)
-      any(isnan, pred) && error("NaN")
-      err = mse!(Δ, pred, y)
+      ŷ = m(x)
+      any(isnan, ŷ) && error("NaN")
+      Δ = back!(mse, 1, ŷ, y)
       back!(m, Δ, x)
-      i % batch == 0 && update!(m, η)
+      update!(m, η)
       i % 1000 == 0 && @show accuracy(m, test)
     end
   end
@@ -27,7 +27,8 @@ end
 function accuracy(m, data)
   correct = 0
   for (x, y) in data
-    onecold(m(x)) == onecold(y) && (correct += 1)
+    x, y = tobatch.((x, y))
+    correct += sum(onecold(m(x)) .== onecold(y))
  end
   return correct/length(data)
 end
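
A quick sketch of how the reworked pieces above fit together (illustrative only: it assumes the `mse`, `back!`, and `onecold` definitions from this patch plus the Julia 0.5-era `sumabs2`; the input values are made up):

```julia
# Forward/backward pair from src/cost.jl: back! returns the gradient of mse
# w.r.t. ŷ, scaled by the incoming sensitivity Δ (seeded with 1 in train!).
ŷ = [0.2, 0.7, 0.1]    # a prediction (hypothetical values)
y = [0.0, 1.0, 0.0]    # its one-hot target

mse(ŷ, y)              # 0.5 * sum((ŷ .- y).^2) == 0.07
back!(mse, 1, ŷ, y)    # 1 * (ŷ .- y) == [0.2, -0.3, 0.1]

# onecold now also accepts a matrix of row-wise predictions, which is what
# lets accuracy score a whole batch at once:
Y = [0.1 0.8 0.1;
     0.7 0.2 0.1]      # two predictions, one per row
onecold(Y, 0:2)        # [1, 0] — each row's argmax, mapped onto the labels
```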