make tf model running easier
This commit is contained in:
parent
770b74fe5f
commit
8db503eafa
@ -53,7 +53,7 @@ function runmodel(m::Model, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function (m::Model)(args::Batch...)
|
function (m::Model)(args::Batch...)
|
||||||
runmodel(m, args...)
|
runmodel(m, map(x -> convertel(Float32, x), args)...)
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::Model)(args...)
|
function (m::Model)(args...)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
module TF
|
module TF
|
||||||
|
|
||||||
using ..Flux, DataFlow, TensorFlow, Juno
|
using ..Flux, DataFlow, TensorFlow, Juno
|
||||||
import Flux: accuracy, rebatch
|
import Flux: accuracy, rebatch, convertel
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
export onehot, onecold, chunk, partition, batches, sequences
|
export onehot, onecold, chunk, partition, batches, sequences
|
||||||
|
|
||||||
|
convertel(T::Type, xs::AbstractArray) = map(x->convert(T, x), xs)
|
||||||
|
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||||
|
|
||||||
"""
|
"""
|
||||||
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
xs = rand(Float32, 20)
|
xs = rand(20)
|
||||||
|
|
||||||
d = Affine(20, 10)
|
d = Affine(20, 10)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user