tracked array restructure

Author: Mike J Innes
Date:   2017-10-18 22:54:58 +01:00
Parent: c8d4844da4
Commit: 5b6a5667ed

5 changed files with 24 additions and 28 deletions


@@ -18,7 +18,7 @@ loss(x, y) # ~ 3
 To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
 
 ```julia
-using Flux.Tracker: param, back!, data, grad
+using Flux.Tracker
 
 W = param(W)
 b = param(b)
@@ -31,10 +31,10 @@ back!(l)
 `loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
 
 ```julia
-grad(W)
+W.grad
 
 # Update the parameter
-W.data .-= 0.1grad(W)
+W.data .-= 0.1(W.grad)
 
 loss(x, y) # ~ 2.5
 ```
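Taken together, the updated documentation describes this workflow; a minimal end-to-end sketch (hypothetical data, assuming the `Flux.Tracker` API as of this commit):

```julia
using Flux.Tracker  # exports param and back!

W = param(rand(2, 5))  # wrap plain arrays as tracked parameters
b = param(rand(2))

predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)

x, y = rand(5), rand(2)
l = loss(x, y)  # a tracked scalar
back!(l)        # populates W.grad and b.grad

W.data .-= 0.1 .* W.grad  # one gradient-descent step on the raw data
loss(x, y)                # should now be lower
```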


@@ -17,14 +17,11 @@ back!(l)
 We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
 
 ```julia
-using Flux.Tracker: data, grad
-
 function update()
   η = 0.1 # Learning Rate
   for p in (W, b)
-    x, Δ = data(p), grad(p)
-    x .-= η .* Δ # Apply the update
-    Δ .= 0 # Clear the gradient
+    p.data .-= η .* p.grad # Apply the update
+    p.grad .= 0 # Clear the gradient
   end
 end
 ```
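In use, `update()` runs once per training step, after the backward pass; a minimal sketch (assuming the `W`, `b`, `loss`, `x` and `y` defined earlier in these docs):

```julia
for i in 1:100
  l = loss(x, y)  # forward pass builds the tracked graph
  back!(l)        # accumulates gradients into W.grad and b.grad
  update()        # apply the step, then zero the gradients
end
```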


@@ -16,6 +16,6 @@ include("train.jl")
 using Flux.Tracker: TrackedArray
 
-Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
+Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
 
 end
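On the optimiser side, the gradient array is now passed to `Param` directly rather than unwrapped from a `Ref`. A hypothetical usage sketch (assuming `Param` pairs a value array with its gradient buffer and is in scope; its definition lives elsewhere in the optimiser code):

```julia
W = param(rand(2, 2))  # leaf parameters are created with a zeroed grad
p = convert(Param, W)  # now simply Param(W.data, W.grad)
```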


@@ -1,7 +1,5 @@
 module Tracker
 
-using Base: RefValue
-
 export TrackedArray, param, back!
 
 data(x) = x
@@ -16,11 +14,13 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
 (c::Call)() = c.func(data.(c.args)...)
 
-struct TrackedArray{T,N,A} <: AbstractArray{T,N}
-  ref::RefValue{UInt32}
+mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
+  ref::UInt32
   f::Call
   data::A
-  grad::RefValue{A}
+  grad::A
+  TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
+  TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
 end
 
 TrackedScalar{T,A} = TrackedArray{T,0,A}
@@ -28,19 +28,20 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
 TrackedMatrix{T,A} = TrackedArray{T,2,A}
 TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
 
-TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
-  TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
+TrackedArray(c::Call, x::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(c, x)
 
-TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
+TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
+  TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
 
 TrackedArray(c::Call) = TrackedArray(c, c())
 
-TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
+TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
 
 param(xs) = TrackedArray(AbstractFloat.(xs))
 
 istracked(x::TrackedArray) = true
 data(x::TrackedArray) = x.data
-grad(x::TrackedArray) = x.grad[]
+grad(x::TrackedArray) = x.grad
 
 # Fallthrough methods
@@ -73,8 +74,6 @@ include("numeric.jl")
 import NNlib.adapt
 
-adapt(T, xs::TrackedArray) =
-  TrackedArray(xs.f, adapt(T, xs.data),
-    RefValue(adapt(T, grad(xs))))
+adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
 
 end
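This is the core of the restructure: because the struct is now `mutable`, `grad` can be a plain array field that the inner constructor `new(0, f, data)` leaves undefined until needed, and `ref` can be a bare `UInt32` mutated in place. A standalone sketch of the undefined-field pattern (the `Lazy` type is illustrative, not part of Flux):

```julia
mutable struct Lazy{A}
  data::A
  grad::A  # may be left undefined by the inner constructor
  Lazy{A}(data::A) where A = new(data)
end

l = Lazy{Vector{Float64}}([1.0, 2.0])
isdefined(l, :grad)    # false: the field was never assigned
l.grad = zero(l.data)  # allocate on first use, as scan does in the next file
isdefined(l, :grad)    # true
```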


@@ -3,11 +3,11 @@ scan(x) = nothing
 scan(c::Call) = foreach(scan, c.args)
 
 function scan(x::TrackedArray)
-  ref = x.ref[] += 1
+  ref = x.ref += 1
   if ref == 1
     scan(x.f)
   else
-    isassigned(x.grad) || (x.grad[] = zeros(x.data))
+    isdefined(x, :grad) || (x.grad = zeros(x.data))
   end
   return
 end
@@ -16,10 +16,10 @@ back(c::Call, Δ) = back(c.func, Δ, c.args...)
 back(::Call{Void}, Δ) = nothing
 
 function back(x::TrackedArray, Δ)
-  ref = x.ref[] -= 1
-  if isassigned(x.grad)
-    x.grad[] .+= Δ
-    ref == 0 && back(x.f, x.grad[])
+  ref = x.ref -= 1
+  if isdefined(x, :grad)
+    x.grad .+= Δ
+    ref == 0 && back(x.f, x.grad)
   else
     ref == 0 && back(x.f, Δ)
   end
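The `ref` counter exists because a node may be used more than once in the graph: `scan` increments it once per use (allocating `grad` for shared nodes), and `back` decrements it, accumulating each incoming `Δ` and only propagating further once the count returns to zero. A hypothetical illustration (assuming tracked `+` and `sum`, as in the Flux of this era):

```julia
a = param([1.0, 2.0])
b = a + a      # `a` is used twice, so scanning leaves a.ref == 2
back!(sum(b))  # both contributions of ones(2) accumulate in a.grad
a.grad         # == [2.0, 2.0], propagated only after the count hit zero
```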