From a57f66e58a7179bbc0cd37b44e3c410efd2393fd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 14 Nov 2018 15:34:45 +0000 Subject: [PATCH] adapt updates --- src/layers/basic.jl | 1 - src/onehot.jl | 4 ++-- src/tracker/Tracker.jl | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0c2d3715..308d7b00 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -26,7 +26,6 @@ end children(c::Chain) = c.layers mapchildren(f, c::Chain) = Chain(f.(c.layers)...) -adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...) (c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x) diff --git a/src/onehot.jl b/src/onehot.jl index 5d902c77..b6cee63d 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -28,9 +28,9 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs batch(xs::AbstractArray{<:OneHotVector}) = OneHotMatrix(length(first(xs)), xs) -import Adapt.adapt +import Adapt: adapt, adapt_structure -adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) +adapt_structure(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) @init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin import .CuArrays: CuArray, cudaconvert diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 94f9a94c..14201297 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -108,8 +108,8 @@ param(xs::AbstractArray) = TrackedArray(float.(xs)) param(x::TrackedReal) = track(identity, x) param(x::TrackedArray) = track(identity, x) -import Adapt.adapt +import Adapt: adapt, adapt_structure -adapt(T, xs::TrackedArray) = param(adapt(T, data(xs))) +adapt_structure(T, xs::TrackedArray) = param(adapt(T, data(xs))) end