training tweaks
This commit is contained in:
parent
1526b13691
commit
e7f26370d7
@ -1,6 +1,6 @@
|
||||
module Optimise
|
||||
|
||||
using ..Tracker: TrackedArray, data, grad, back!
|
||||
using ..Tracker: TrackedArray, grad, back!
|
||||
|
||||
export sgd, update!, params, train!
|
||||
|
||||
|
@ -8,5 +8,4 @@ function train!(m, data, opt; epoch = 1)
|
||||
update!(opt)
|
||||
end
|
||||
end
|
||||
return m
|
||||
end
|
||||
|
@ -30,7 +30,7 @@ onecold(y::AbstractVector, labels = 1:length(y)) =
|
||||
labels[findfirst(y, maximum(y))]
|
||||
|
||||
onecold(y::AbstractMatrix, l...) =
|
||||
squeeze(mapslices(y -> onecold(y, l...), y, 2), 2)
|
||||
squeeze(mapslices(y -> onecold(y, l...), y, 1), 1)
|
||||
|
||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user