track -> param
This commit is contained in:
parent
cbaf661145
commit
f55b8cd20e
@ -32,7 +32,7 @@ struct Dense{F,S,T}
|
|||||||
end
|
end
|
||||||
|
|
||||||
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||||
Dense(σ, track(init(out, in)), track(init(out)))
|
Dense(σ, param(init(out, in)), param(init(out)))
|
||||||
|
|
||||||
Optimise.children(d::Dense) = (d.W, d.b)
|
Optimise.children(d::Dense) = (d.W, d.b)
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ struct RNNCell{D,V}
|
|||||||
end
|
end
|
||||||
|
|
||||||
RNNCell(in::Integer, out::Integer; init = initn) =
|
RNNCell(in::Integer, out::Integer; init = initn) =
|
||||||
RNNCell(Dense(in+out, out, init = initn), track(initn(out)))
|
RNNCell(Dense(in+out, out, init = initn), param(initn(out)))
|
||||||
|
|
||||||
function (m::RNNCell)(h, x)
|
function (m::RNNCell)(h, x)
|
||||||
h = m.d(combine(x, h))
|
h = m.d(combine(x, h))
|
||||||
@ -100,7 +100,7 @@ end
|
|||||||
function LSTMCell(in, out; init = initn)
|
function LSTMCell(in, out; init = initn)
|
||||||
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
||||||
Dense(in+out, out, tanh, init = initn),
|
Dense(in+out, out, tanh, init = initn),
|
||||||
track(initn(out)), track(initn(out)))
|
param(initn(out)), param(initn(out)))
|
||||||
cell.forget.b.data .= 1
|
cell.forget.b.data .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
|
@ -2,7 +2,7 @@ module Tracker
|
|||||||
|
|
||||||
using Base: RefValue
|
using Base: RefValue
|
||||||
|
|
||||||
export TrackedArray, track, back!
|
export TrackedArray, param, back!
|
||||||
|
|
||||||
data(x) = x
|
data(x) = x
|
||||||
istracked(x) = false
|
istracked(x) = false
|
||||||
@ -37,7 +37,7 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
|||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
|
||||||
|
|
||||||
track(xs) = TrackedArray(AbstractFloat.(xs))
|
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||||
istracked(x::TrackedArray) = true
|
istracked(x::TrackedArray) = true
|
||||||
data(x::TrackedArray) = x.data
|
data(x::TrackedArray) = x.data
|
||||||
grad(x::TrackedArray) = x.grad[]
|
grad(x::TrackedArray) = x.grad[]
|
||||||
@ -58,7 +58,7 @@ Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}}
|
|||||||
|
|
||||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||||
if repr
|
if repr
|
||||||
print(io, "track(")
|
print(io, "param(")
|
||||||
Base.showarray(io, data(X), true)
|
Base.showarray(io, data(X), true)
|
||||||
print(io, ")")
|
print(io, ")")
|
||||||
else
|
else
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
function gradient(f, xs::AbstractArray...)
|
function gradient(f, xs::AbstractArray...)
|
||||||
xs = track.(xs)
|
xs = param.(xs)
|
||||||
back!(f(xs...))
|
back!(f(xs...))
|
||||||
grad.(xs)
|
grad.(xs)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user