make tf model running easier
This commit is contained in:
parent
770b74fe5f
commit
8db503eafa
|
@ -53,7 +53,7 @@ function runmodel(m::Model, args...)
|
|||
end
|
||||
|
||||
function (m::Model)(args::Batch...)
|
||||
runmodel(m, args...)
|
||||
runmodel(m, map(x -> convertel(Float32, x), args)...)
|
||||
end
|
||||
|
||||
function (m::Model)(args...)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
module TF
|
||||
|
||||
using ..Flux, DataFlow, TensorFlow, Juno
|
||||
import Flux: accuracy, rebatch
|
||||
import Flux: accuracy, rebatch, convertel
|
||||
|
||||
export tf
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
export onehot, onecold, chunk, partition, batches, sequences
|
||||
|
||||
convertel(T::Type, xs::AbstractArray) = map(x->convert(T, x), xs)
|
||||
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||
|
||||
"""
|
||||
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
xs = rand(Float32, 20)
|
||||
xs = rand(20)
|
||||
|
||||
d = Affine(20, 10)
|
||||
|
||||
|
|
Loading…
Reference in New Issue