tracked array restructure

This commit is contained in:
Mike J Innes 2017-10-18 22:54:58 +01:00
parent c8d4844da4
commit 5b6a5667ed
5 changed files with 24 additions and 28 deletions

View File

@@ -18,7 +18,7 @@ loss(x, y) # ~ 3
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
```julia
-using Flux.Tracker: param, back!, data, grad
+using Flux.Tracker
W = param(W)
b = param(b)
@@ -31,10 +31,10 @@ back!(l)
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
```julia
-grad(W)
+W.grad
# Update the parameter
-W.data .-= 0.1grad(W)
+W.data .-= 0.1(W.grad)
loss(x, y) # ~ 2.5
```
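For orientation, here is a consolidated sketch of the field-access API the docs move to in this commit. The shapes and data below are illustrative placeholders (not part of the commit), and it assumes the tracked `*`, broadcast, and `sum` methods of this era of Flux:

```julia
using Flux
using Flux.Tracker

W = param(rand(2, 5))   # wrap as trainable, tracked parameters
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
back!(loss(x, y))       # fills in W.grad and b.grad

W.data .-= 0.1 .* W.grad   # one gradient-descent step on W
```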

View File

@@ -17,14 +17,11 @@ back!(l)
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
```julia
-using Flux.Tracker: data, grad
function update()
  η = 0.1 # Learning Rate
  for p in (W, b)
-    x, Δ = data(p), grad(p)
-    x .-= η .* Δ # Apply the update
-    Δ .= 0 # Clear the gradient
+    p.data .-= η .* p.grad # Apply the update
+    p.grad .= 0 # Clear the gradient
  end
end
```
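A brief usage sketch: re-running the loss refreshes the gradients, and `update()` then applies and clears them. This assumes the `W`, `b`, `x`, `y`, and `loss` defined earlier in this guide:

```julia
for i = 1:10
  back!(loss(x, y))  # recompute the loss and repopulate p.grad
  update()           # step each parameter and zero its gradient
end
loss(x, y)           # should be lower than before the loop
```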

View File

@@ -16,6 +16,6 @@ include("train.jl")
using Flux.Tracker: TrackedArray
-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
+Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
end

View File

@@ -1,7 +1,5 @@
module Tracker
-using Base: RefValue
export TrackedArray, param, back!
data(x) = x
@@ -16,11 +14,13 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
-struct TrackedArray{T,N,A} <: AbstractArray{T,N}
-  ref::RefValue{UInt32}
+mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
+  ref::UInt32
  f::Call
  data::A
-  grad::RefValue{A}
+  grad::A
+  TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
+  TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
end
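
The move from an immutable `struct` holding `RefValue`s to a `mutable struct` works because a mutable struct's inner constructor may leave trailing fields uninitialized and assign them later, which is exactly what the two inner constructors above do with `grad`. A standalone toy of the pattern (hypothetical `Tracked` type, not Flux's; `zeros(::Array)` is the Julia 0.6 idiom used throughout this commit):

```julia
mutable struct Tracked{A}
  data::A
  grad::A
  Tracked{A}(data::A) where A = new(data)               # grad left undefined
  Tracked{A}(data::A, grad::A) where A = new(data, grad)
end

t = Tracked{Vector{Float64}}([1.0, 2.0])
isdefined(t, :grad)      # false: no gradient storage allocated yet
t.grad = zeros(t.data)   # assign lazily, as scan does below
isdefined(t, :grad)      # true
```
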
TrackedScalar{T,A} = TrackedArray{T,0,A}
@@ -28,19 +28,20 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
-TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
-  TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
+TrackedArray(c::Call, x::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(c, x)
-TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
+TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
TrackedArray(c::Call) = TrackedArray(c, c())
-TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
+TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data
-grad(x::TrackedArray) = x.grad[]
+grad(x::TrackedArray) = x.grad
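
Taken together, these constructors make gradient storage eager for leaves but lazy for derived nodes: `param` goes through `TrackedArray(x::AbstractArray)` and gets a zeroed `grad` up front, while results of tracked operations go through `TrackedArray(c::Call)` and leave `grad` undefined unless `scan` later finds the node is shared. A small illustration (assuming this era's tracked `sum`):

```julia
xs = param([1.0, 2.0])
isdefined(xs, :grad)   # true: leaves start with zeros(x)
s = sum(xs)            # derived node, built via TrackedArray(c::Call)
isdefined(s, :grad)    # false unless scan decides it is needed
```
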
# Fallthrough methods
@@ -73,8 +74,6 @@ include("numeric.jl")
import NNlib.adapt
-adapt(T, xs::TrackedArray) =
-  TrackedArray(xs.f, adapt(T, xs.data),
-    RefValue(adapt(T, grad(xs))))
+adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
end

View File

@@ -3,11 +3,11 @@ scan(x) = nothing
scan(c::Call) = foreach(scan, c.args)
function scan(x::TrackedArray)
-  ref = x.ref[] += 1
+  ref = x.ref += 1
  if ref == 1
    scan(x.f)
  else
-    isassigned(x.grad) || (x.grad[] = zeros(x.data))
+    isdefined(x, :grad) || (x.grad = zeros(x.data))
  end
  return
end
@@ -16,10 +16,10 @@ back(c::Call, Δ) = back(c.func, Δ, c.args...)
back(::Call{Void}, Δ) = nothing
function back(x::TrackedArray, Δ)
-  ref = x.ref[] -= 1
-  if isassigned(x.grad)
-    x.grad[] .+= Δ
-    ref == 0 && back(x.f, x.grad[])
+  ref = x.ref -= 1
+  if isdefined(x, :grad)
+    x.grad .+= Δ
+    ref == 0 && back(x.f, x.grad)
  else
    ref == 0 && back(x.f, Δ)
  end
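
Together, `scan` and `back` implement reference counting over the call graph: `scan` counts each node's uses (allocating `grad` once a node turns out to be shared), and `back` accumulates incoming gradients into `x.grad`, only recursing into `x.f` once the count drops back to zero. A minimal sketch of why the accumulation matters, assuming this era's tracked `+` and `sum`:

```julia
a = param([3.0])
l = sum(a + a)   # `a` feeds into the graph twice
back!(l)
a.grad           # [2.0]: both incoming contributions were accumulated
```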