track -> param
This commit is contained in:
parent
cbaf661145
commit
f55b8cd20e
|
@ -32,7 +32,7 @@ struct Dense{F,S,T}
|
|||
end
|
||||
|
||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||
Dense(σ, track(init(out, in)), track(init(out)))
|
||||
Dense(σ, param(init(out, in)), param(init(out)))
|
||||
|
||||
Optimise.children(d::Dense) = (d.W, d.b)
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ struct RNNCell{D,V}
|
|||
end
|
||||
|
||||
RNNCell(in::Integer, out::Integer; init = initn) =
|
||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
||||
RNNCell(Dense(in+out, out, init = initn), param(initn(out)))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
h = m.d(combine(x, h))
|
||||
|
@ -100,7 +100,7 @@ end
|
|||
function LSTMCell(in, out; init = initn)
|
||||
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
||||
Dense(in+out, out, tanh, init = initn),
|
||||
track(initn(out)), track(initn(out)))
|
||||
param(initn(out)), param(initn(out)))
|
||||
cell.forget.b.data .= 1
|
||||
return cell
|
||||
end
|
||||
|
|
|
@ -2,7 +2,7 @@ module Tracker
|
|||
|
||||
using Base: RefValue
|
||||
|
||||
export TrackedArray, track, back!
|
||||
export TrackedArray, param, back!
|
||||
|
||||
data(x) = x
|
||||
istracked(x) = false
|
||||
|
@ -37,7 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
|||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
|
||||
|
||||
track(xs) = TrackedArray(AbstractFloat.(xs))
|
||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.data
|
||||
grad(x::TrackedArray) = x.grad[]
|
||||
|
@ -58,7 +58,7 @@ Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}}
|
|||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "track(")
|
||||
print(io, "param(")
|
||||
Base.showarray(io, data(X), true)
|
||||
print(io, ")")
|
||||
else
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
function gradient(f, xs::AbstractArray...)
|
||||
xs = track.(xs)
|
||||
xs = param.(xs)
|
||||
back!(f(xs...))
|
||||
grad.(xs)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue