tracked array restructure
parent c8d4844da4
commit 5b6a5667ed
@@ -18,7 +18,7 @@ loss(x, y) # ~ 3
 To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
 
 ```julia
-using Flux.Tracker: param, back!, data, grad
+using Flux.Tracker
 
 W = param(W)
 b = param(b)
@@ -31,10 +31,10 @@ back!(l)
 `loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
 
 ```julia
-grad(W)
+W.grad
 
 # Update the parameter
-W.data .-= 0.1grad(W)
+W.data .-= 0.1(W.grad)
 
 loss(x, y) # ~ 2.5
 ```
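Read together, the two documentation hunks above describe the full workflow: wrap arrays with `param`, compute a tracked loss, call `back!`, then step along the gradient held in the `.grad` field. A minimal end-to-end sketch of that workflow under the post-restructure API (the `predict`/`loss` definitions and array sizes are assumptions based on the surrounding guide, not part of this diff):

```julia
using Flux.Tracker  # exports param and back!

W = param(rand(2, 5))   # mark W as a trainable, tracked parameter
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
l = loss(x, y)          # a tracked value that records the operations applied
back!(l)                # fills in W.grad and b.grad

W.data .-= 0.1(W.grad)  # one gradient-descent step on the raw array
loss(x, y)              # should now be smaller
```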
@@ -17,14 +17,11 @@ back!(l)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 
 ```julia
-using Flux.Tracker: data, grad
-
 function update()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    x, Δ = data(p), grad(p)
-    x .-= η .* Δ # Apply the update
-    Δ .= 0 # Clear the gradient
+    p.data .-= η .* p.grad # Apply the update
+    p.grad .= 0 # Clear the gradient
   end
 end
 ```
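For context, a hypothetical training loop built on this `update` function might look as follows (the loop count and the `x`, `y` data are illustrative assumptions):

```julia
for i = 1:10
  l = loss(x, y)  # forward pass produces a tracked loss
  back!(l)        # accumulate gradients into W.grad and b.grad
  update()        # apply the step and zero the gradients for the next pass
end
```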
@@ -16,6 +16,6 @@ include("train.jl")
 
 using Flux.Tracker: TrackedArray
 
-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
+Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
 
 end
@@ -1,7 +1,5 @@
 module Tracker
 
-using Base: RefValue
-
 export TrackedArray, param, back!
 
 data(x) = x
@@ -16,11 +14,13 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
 
 (c::Call)() = c.func(data.(c.args)...)
 
-struct TrackedArray{T,N,A} <: AbstractArray{T,N}
-  ref::RefValue{UInt32}
+mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
+  ref::UInt32
   f::Call
   data::A
-  grad::RefValue{A}
+  grad::A
+  TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
+  TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
 end
 
 TrackedScalar{T,A} = TrackedArray{T,0,A}
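The motivation for the `mutable struct` change: a mutable struct's trailing fields may be left uninitialized by `new`, so `grad` can be allocated lazily and tested with `isdefined`, replacing the old `RefValue`/`isassigned` indirection. A standalone sketch of that pattern with a hypothetical `Tracked` type (not Flux's API):

```julia
mutable struct Tracked{A}
  data::A
  grad::A                                  # may be left undefined
  Tracked{A}(data::A) where A = new(data)  # grad not yet allocated
end

t = Tracked{Vector{Float64}}([1.0, 2.0])
isdefined(t, :grad)    # false: no gradient buffer yet
t.grad = zero(t.data)  # allocate on demand, as the diff's `zeros(x.data)` does
isdefined(t, :grad)    # true
```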
@@ -28,19 +28,20 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
 TrackedMatrix{T,A} = TrackedArray{T,2,A}
 TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
 
-TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
-  TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
+TrackedArray(c::Call, x::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(c, x)
 
-TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
+TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
 
 TrackedArray(c::Call) = TrackedArray(c, c())
 
-TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
+TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
 
 param(xs) = TrackedArray(AbstractFloat.(xs))
 istracked(x::TrackedArray) = true
 data(x::TrackedArray) = x.data
-grad(x::TrackedArray) = x.grad[]
+grad(x::TrackedArray) = x.grad
 
 # Fallthrough methods
@@ -73,8 +74,6 @@ include("numeric.jl")
 
 import NNlib.adapt
 
-adapt(T, xs::TrackedArray) =
-  TrackedArray(xs.f, adapt(T, xs.data),
-               RefValue(adapt(T, grad(xs))))
+adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
 
 end
@@ -3,11 +3,11 @@ scan(x) = nothing
 scan(c::Call) = foreach(scan, c.args)
 
 function scan(x::TrackedArray)
-  ref = x.ref[] += 1
+  ref = x.ref += 1
   if ref == 1
     scan(x.f)
   else
-    isassigned(x.grad) || (x.grad[] = zeros(x.data))
+    isdefined(x, :grad) || (x.grad = zeros(x.data))
   end
   return
 end
@@ -16,10 +16,10 @@ back(c::Call, Δ) = back(c.func, Δ, c.args...)
 back(::Call{Void}, Δ) = nothing
 
 function back(x::TrackedArray, Δ)
-  ref = x.ref[] -= 1
-  if isassigned(x.grad)
-    x.grad[] .+= Δ
-    ref == 0 && back(x.f, x.grad[])
+  ref = x.ref -= 1
+  if isdefined(x, :grad)
+    x.grad .+= Δ
+    ref == 0 && back(x.f, x.grad)
   else
     ref == 0 && back(x.f, Δ)
   end
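Together, `scan` and `back` implement reference-counted backpropagation: `scan` walks the graph forward, incrementing `ref` for each consumer of a node and lazily allocating `grad` only for nodes that are shared; `back` decrements `ref`, accumulates incoming gradients, and propagates upstream only once every consumer has reported in, so a shared node pushes its total gradient exactly once. A toy illustration of the counting logic with a hypothetical `Node` type (not Flux's API):

```julia
mutable struct Node
  ref::UInt32
  grad::Float64
  Node() = new(0, 0.0)
end

scan!(n::Node) = (n.ref += 1)  # forward pass: count each consumer

function back!(n::Node, Δ)
  n.grad += Δ                  # accumulate this consumer's contribution
  n.ref -= 1
  n.ref == 0 && println("propagating total gradient $(n.grad) upstream")
end

n = Node()
scan!(n); scan!(n)  # the node feeds two downstream operations
back!(n, 1.0)       # first consumer: just accumulate
back!(n, 2.0)       # second consumer: count hits zero, propagate 3.0
```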