Flux.jl/src/tracker/Tracker.jl

101 lines
2.8 KiB
Julia
Raw Normal View History

2017-08-19 10:06:19 +00:00
module Tracker
2017-11-09 14:53:26 +00:00
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
2017-08-19 15:20:53 +00:00
2017-08-19 09:14:50 +00:00
"""
    data(x)

Return the untracked value underlying `x`. This fallback hands back any
non-tracked value unchanged; `TrackedArray` has its own method.
"""
function data(x)
    return x
end
2017-08-19 16:40:07 +00:00
"""
    istracked(x) -> Bool

Return whether `x` is a tracked value. This fallback always answers `false`;
`TrackedArray` overrides it.
"""
function istracked(x)
    return false
end
2017-08-19 09:14:50 +00:00
2017-08-18 15:50:27 +00:00
"""
    Call(f, args...)

Record of the function `f` together with the positional arguments it was
applied to (kept as a tuple). A `TrackedArray` stores one of these as the
call that produced it; `Call(nothing)` — no function, no arguments — is the
marker used for leaves.
"""
struct Call{F,As<:Tuple}
    func::F   # recorded function (`nothing` for a leaf)
    args::As  # recorded positional arguments, as a tuple
end

function Call(f, args...)
    # Parameterize explicitly from the runtime types of `f` and the splatted
    # argument tuple.
    FT = typeof(f)
    ArgsT = typeof(args)
    return Call{FT,ArgsT}(f, args)
end
2017-08-19 09:14:50 +00:00
# Re-run a recorded call: apply `c.func` to the untracked (`data`) form of
# each recorded argument.
function (c::Call)()
    raw = map(data, c.args)
    return c.func(raw...)
end
2017-10-18 21:54:58 +00:00
"""
    TrackedArray{T,N,A} <: AbstractArray{T,N}

Array wrapper that remembers the `Call` which produced it.

Fields:
- `ref`  — `UInt32` set to 0 by both inner constructors; its use is not
  visible here (presumably the backward pass in back.jl — TODO confirm).
- `f`    — the `Call` that produced this array (`Call(nothing)` for a leaf).
- `data` — the wrapped, untracked array of type `A`.
- `grad` — gradient buffer, also of type `A`; the two-argument constructor
  leaves it undefined.
"""
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
    ref::UInt32
    f::Call
    data::A
    grad::A
    function TrackedArray{T,N,A}(c::Call, x::A) where {T,N,A}
        return new(0, c, x)
    end
    function TrackedArray{T,N,A}(c::Call, x::A, Δ::A) where {T,N,A}
        return new(0, c, x, Δ)
    end
end
2017-08-19 10:11:25 +00:00
# Dimension-specialised aliases of TrackedArray. Each leaves the element type
# `T` and the storage type `A` free, in that order (so e.g. `TrackedVector{T}`
# fixes only the element type).
const TrackedScalar = TrackedArray{T,0,A} where {T,A}
const TrackedVector = TrackedArray{T,1,A} where {T,A}
const TrackedMatrix = TrackedArray{T,2,A} where {T,A}
const TrackedVecOrMat = Union{TrackedVector{T,A},TrackedMatrix{T,A}} where {T,A}
2017-08-19 09:14:50 +00:00
2017-10-18 21:54:58 +00:00
# Wrap `x` as produced by `c`, deriving the element type and rank from the
# array type `A`; no gradient buffer is attached.
function TrackedArray(c::Call, x::A) where A <: AbstractArray
    T, N = eltype(A), ndims(A)
    return TrackedArray{T,N,A}(c, x)
end
2017-08-19 09:14:50 +00:00
2017-10-18 21:54:58 +00:00
# Wrap `x` as produced by `c` with gradient buffer `Δ`; both must share the
# array type `A`, from which the element type and rank are derived.
function TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray
    T, N = eltype(A), ndims(A)
    return TrackedArray{T,N,A}(c, x, Δ)
end
2017-08-19 09:14:50 +00:00
2017-08-19 10:11:25 +00:00
# Evaluate the recorded call `c` and wrap its result as produced by `c`.
function TrackedArray(c::Call)
    result = c()
    return TrackedArray(c, result)
end
2017-08-19 09:14:50 +00:00
2017-10-18 21:54:58 +00:00
# Wrap a plain array as a leaf (its call is `Call(nothing)`) with a
# zero-filled gradient buffer of the same array type as `x`.
#
# `zero(x)` allocates `similar(x)` filled with zeros, exactly as `zeros(x)` did,
# but `zeros(::AbstractArray)` was deprecated in Julia 0.7 and removed in 1.0,
# while `zero(::AbstractArray)` exists in both.
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zero(x))
2017-08-19 09:14:50 +00:00
2017-11-07 19:34:27 +00:00
# True when `x` was wrapped directly from data (its recorded call is the empty
# `Call(nothing)`) rather than produced by a recorded operation.
function isleaf(x::TrackedArray)
    leafcall = Call(nothing)
    return x.f == leafcall
end
2017-11-30 13:51:31 +00:00
"""
    param(xs)

Wrap `xs` as a leaf `TrackedArray`, converting each element with
`AbstractFloat` (so integers become floating point). A single real number is
first placed in a zero-dimensional array.
"""
param(xs) = TrackedArray(map(AbstractFloat, xs))
param(xs::Real) = param(fill(xs))
2017-08-19 16:40:07 +00:00
# Every TrackedArray is, by definition, tracked.
function istracked(::TrackedArray)
    return true
end
2017-09-07 01:21:35 +00:00
# Unwrap a TrackedArray to the plain array it stores.
function data(x::TrackedArray)
    return x.data
end
2017-10-18 21:54:58 +00:00
# Return the gradient buffer stored in `x`. Reading it throws an
# UndefRefError if `x` was built without one.
function grad(x::TrackedArray)
    return x.grad
end
2017-08-19 16:40:07 +00:00
2017-08-19 15:02:19 +00:00
# Fallthrough methods: answer `size` / `ndims` queries from the wrapped data,
# forwarding any extra arguments (e.g. a dimension index) unchanged.
for fname in (:size, :ndims)
    @eval @inline Base.$fname(x::TrackedArray, args...) = Base.$fname(data(x), args...)
end
2017-08-19 10:00:55 +00:00
2017-08-19 15:02:19 +00:00
# `similar` delegates to the wrapped data, so the result is a plain,
# untracked array.
function Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...)
    inner = data(x)
    return similar(inner, dims...)
end

function Base.similar(x::TrackedArray, T::Type)
    inner = data(x)
    return similar(inner, T)
end
2017-10-12 08:31:38 +00:00
# `value` unwraps a possibly-tracked value for comparisons: untracked values
# pass through, a TrackedArray yields its data array, and a zero-dimensional
# TrackedScalar yields its single element.
# TODO: decide whether to keep both `data` and `value`; they differ only for
# TrackedScalar, where `value` also unwraps the element.
function value(x)
    return x
end

function value(x::TrackedArray)
    return data(x)
end

function value(x::TrackedScalar)
    arr = data(x)
    return arr[]
end
2017-10-25 00:35:27 +00:00
2017-10-26 11:06:29 +00:00
# Equality compares underlying values (see `value`), so a TrackedArray equals
# a plain array — or, for a TrackedScalar, a plain scalar — holding the same
# values. The mixed methods are ambiguous for two TrackedArrays, hence the
# third method.
Base.:(==)(x::TrackedArray, y) = value(x) == y
Base.:(==)(y, x::TrackedArray) = y == value(x)
# Both operands must be unwrapped and compared; comparing `x` with itself
# would make any two TrackedArrays "equal".
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
2017-10-25 00:35:27 +00:00
2017-10-26 11:06:29 +00:00
# Ordering for zero-dimensional tracked values compares their scalar values.
# The (TrackedScalar, TrackedScalar) method resolves the ambiguity between the
# two mixed methods.
function Base.isless(x::TrackedScalar, y)
    return isless(value(x), y)
end

function Base.isless(x, y::TrackedScalar)
    return isless(x, value(y))
end

function Base.isless(x::TrackedScalar, y::TrackedScalar)
    lhs, rhs = value(x), value(y)
    return isless(lhs, rhs)
end
2017-10-12 08:31:38 +00:00
# Approximate comparison of zero-dimensional tracked values via their scalar
# values; keyword arguments (atol, rtol, ...) are forwarded unchanged.
# Mirrors `isless` so a plain number on either side, or two TrackedScalars,
# compare by value. The left operand of the second method is restricted to
# `Number` so it cannot become ambiguous with Base's
# `isapprox(::AbstractArray, ::AbstractArray)`.
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(value(x), y; kws...)
Base.isapprox(x::Number, y::TrackedScalar; kws...) = isapprox(x, value(y); kws...)
Base.isapprox(x::TrackedScalar, y::TrackedScalar; kws...) =
    isapprox(value(x), value(y); kws...)
2017-10-23 09:41:08 +00:00
2017-09-03 06:12:54 +00:00
# Abbreviate the printed type of a TrackedArray to its storage type only.
function Base.show(io::IO, ::Type{TrackedArray{E,D,A}}) where {E,D,A<:AbstractArray{E,D}}
    print(io, "TrackedArray{…,$A}")
end
2017-08-19 10:11:25 +00:00
# Display a TrackedArray by delegating to the wrapped data. With `repr` the
# output is `param(<data repr>)`; otherwise the data is shown with a "Tracked "
# prefix whenever a header is requested.
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
    if !repr
        header && print(io, "Tracked ")
        return Base.showarray(io, data(X), false, header = header)
    end
    print(io, "param(")
    Base.showarray(io, data(X), true)
    return print(io, ")")
end
2017-08-19 10:06:19 +00:00
2017-10-26 10:15:14 +00:00
# In-place writes into a TrackedArray are refused outright, with an
# ErrorException explaining that `setindex!` can't be differentiated.
function Base.setindex!(xs::TrackedArray, v, i...)
    throw(ErrorException("Can't differentiate `setindex!`"))
end
2017-09-07 01:21:35 +00:00
include("back.jl")
2017-08-19 15:02:19 +00:00
include("lib.jl")
2017-08-23 00:43:45 +00:00
include("numeric.jl")
2017-08-19 15:02:19 +00:00
2017-10-04 17:55:56 +00:00
import NNlib.adapt
2017-08-24 16:00:48 +00:00
2017-10-18 21:54:58 +00:00
# Convert a TrackedArray's storage with `adapt(T, ·)` (imported from NNlib),
# keeping the recorded call. The gradient buffer is adapted only if it has
# been assigned: the two-argument constructor leaves `grad` undefined, and
# reading an undefined field would throw an UndefRefError.
function adapt(T, xs::TrackedArray)
    d = adapt(T, xs.data)
    isdefined(xs, :grad) || return TrackedArray(xs.f, d)
    return TrackedArray(xs.f, d, adapt(T, xs.grad))
end
2017-08-24 16:00:48 +00:00
2017-08-19 10:06:19 +00:00
end