track -> param

This commit is contained in:
Mike J Innes 2017-09-07 15:13:04 -04:00
parent cbaf661145
commit f55b8cd20e
4 changed files with 7 additions and 7 deletions

View File

@ -32,7 +32,7 @@ struct Dense{F,S,T}
end
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
Dense(σ, track(init(out, in)), track(init(out)))
Dense(σ, param(init(out, in)), param(init(out)))
Optimise.children(d::Dense) = (d.W, d.b)

View File

@ -70,7 +70,7 @@ struct RNNCell{D,V}
end
RNNCell(in::Integer, out::Integer; init = initn) =
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
RNNCell(Dense(in+out, out, init = initn), param(initn(out)))
function (m::RNNCell)(h, x)
h = m.d(combine(x, h))
@ -100,7 +100,7 @@ end
function LSTMCell(in, out; init = initn)
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
Dense(in+out, out, tanh, init = initn),
track(initn(out)), track(initn(out)))
param(initn(out)), param(initn(out)))
cell.forget.b.data .= 1
return cell
end

View File

@ -2,7 +2,7 @@ module Tracker
using Base: RefValue
export TrackedArray, track, back!
export TrackedArray, param, back!
data(x) = x
istracked(x) = false
@ -37,7 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
track(xs) = TrackedArray(AbstractFloat.(xs))
param(xs) = TrackedArray(AbstractFloat.(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data
grad(x::TrackedArray) = x.grad[]
@ -58,7 +58,7 @@ Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}}
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "track(")
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else

View File

@ -1,5 +1,5 @@
function gradient(f, xs::AbstractArray...)
xs = track.(xs)
xs = param.(xs)
back!(f(xs...))
grad.(xs)
end