tracked array restructure
This commit is contained in:
parent
c8d4844da4
commit
5b6a5667ed
@ -18,7 +18,7 @@ loss(x, y) # ~ 3
|
|||||||
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
|
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
using Flux.Tracker: param, back!, data, grad
|
using Flux.Tracker
|
||||||
|
|
||||||
W = param(W)
|
W = param(W)
|
||||||
b = param(b)
|
b = param(b)
|
||||||
@ -31,10 +31,10 @@ back!(l)
|
|||||||
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
|
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then calculates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model.
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
grad(W)
|
W.grad
|
||||||
|
|
||||||
# Update the parameter
|
# Update the parameter
|
||||||
W.data .-= 0.1grad(W)
|
W.data .-= 0.1(W.grad)
|
||||||
|
|
||||||
loss(x, y) # ~ 2.5
|
loss(x, y) # ~ 2.5
|
||||||
```
|
```
|
||||||
|
@ -17,14 +17,11 @@ back!(l)
|
|||||||
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
|
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
using Flux.Tracker: data, grad
|
|
||||||
|
|
||||||
function update()
|
function update()
|
||||||
η = 0.1 # Learning Rate
|
η = 0.1 # Learning Rate
|
||||||
for p in (W, b)
|
for p in (W, b)
|
||||||
x, Δ = data(p), grad(p)
|
p.data .-= η .* p.grad # Apply the update
|
||||||
x .-= η .* Δ # Apply the update
|
p.grad .= 0 # Clear the gradient
|
||||||
Δ .= 0 # Clear the gradient
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
```
|
```
|
||||||
|
@ -16,6 +16,6 @@ include("train.jl")
|
|||||||
|
|
||||||
using Flux.Tracker: TrackedArray
|
using Flux.Tracker: TrackedArray
|
||||||
|
|
||||||
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
|
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
module Tracker
|
module Tracker
|
||||||
|
|
||||||
using Base: RefValue
|
|
||||||
|
|
||||||
export TrackedArray, param, back!
|
export TrackedArray, param, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
@ -16,11 +14,13 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||||||
|
|
||||||
(c::Call)() = c.func(data.(c.args)...)
|
(c::Call)() = c.func(data.(c.args)...)
|
||||||
|
|
||||||
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||||
ref::RefValue{UInt32}
|
ref::UInt32
|
||||||
f::Call
|
f::Call
|
||||||
data::A
|
data::A
|
||||||
grad::RefValue{A}
|
grad::A
|
||||||
|
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
|
||||||
|
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||||
@ -28,19 +28,20 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
|
|||||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
|
TrackedArray{eltype(A),ndims(A),A}(c, x)
|
||||||
|
|
||||||
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
|
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||||
|
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
||||||
|
|
||||||
TrackedArray(c::Call) = TrackedArray(c, c())
|
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||||
|
|
||||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||||
istracked(x::TrackedArray) = true
|
istracked(x::TrackedArray) = true
|
||||||
data(x::TrackedArray) = x.data
|
data(x::TrackedArray) = x.data
|
||||||
grad(x::TrackedArray) = x.grad[]
|
grad(x::TrackedArray) = x.grad
|
||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
|
|
||||||
@ -73,8 +74,6 @@ include("numeric.jl")
|
|||||||
|
|
||||||
import NNlib.adapt
|
import NNlib.adapt
|
||||||
|
|
||||||
adapt(T, xs::TrackedArray) =
|
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
||||||
TrackedArray(xs.f, adapt(T, xs.data),
|
|
||||||
RefValue(adapt(T, grad(xs))))
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -3,11 +3,11 @@ scan(x) = nothing
|
|||||||
scan(c::Call) = foreach(scan, c.args)
|
scan(c::Call) = foreach(scan, c.args)
|
||||||
|
|
||||||
function scan(x::TrackedArray)
|
function scan(x::TrackedArray)
|
||||||
ref = x.ref[] += 1
|
ref = x.ref += 1
|
||||||
if ref == 1
|
if ref == 1
|
||||||
scan(x.f)
|
scan(x.f)
|
||||||
else
|
else
|
||||||
isassigned(x.grad) || (x.grad[] = zeros(x.data))
|
isdefined(x, :grad) || (x.grad = zeros(x.data))
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -16,10 +16,10 @@ back(c::Call, Δ) = back(c.func, Δ, c.args...)
|
|||||||
back(::Call{Void}, Δ) = nothing
|
back(::Call{Void}, Δ) = nothing
|
||||||
|
|
||||||
function back(x::TrackedArray, Δ)
|
function back(x::TrackedArray, Δ)
|
||||||
ref = x.ref[] -= 1
|
ref = x.ref -= 1
|
||||||
if isassigned(x.grad)
|
if isdefined(x, :grad)
|
||||||
x.grad[] .+= Δ
|
x.grad .+= Δ
|
||||||
ref == 0 && back(x.f, x.grad[])
|
ref == 0 && back(x.f, x.grad)
|
||||||
else
|
else
|
||||||
ref == 0 && back(x.f, Δ)
|
ref == 0 && back(x.f, Δ)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user