adapt updates
This commit is contained in:
parent
a88b7528bf
commit
a57f66e58a
@ -26,7 +26,6 @@ end
|
|||||||
|
|
||||||
children(c::Chain) = c.layers
|
children(c::Chain) = c.layers
|
||||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||||
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
|
|
||||||
|
|
||||||
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
|
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
|
||||||
|
|
||||||
|
@ -28,9 +28,9 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs
|
|||||||
|
|
||||||
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs)
|
||||||
|
|
||||||
import Adapt.adapt
|
import Adapt: adapt, adapt_structure
|
||||||
|
|
||||||
adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
||||||
|
|
||||||
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
|
||||||
import .CuArrays: CuArray, cudaconvert
|
import .CuArrays: CuArray, cudaconvert
|
||||||
|
@ -108,8 +108,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs))
|
|||||||
param(x::TrackedReal) = track(identity, x)
|
param(x::TrackedReal) = track(identity, x)
|
||||||
param(x::TrackedArray) = track(identity, x)
|
param(x::TrackedArray) = track(identity, x)
|
||||||
|
|
||||||
import Adapt.adapt
|
import Adapt: adapt, adapt_structure
|
||||||
|
|
||||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user