diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index 06237ece..0218ceaa 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -49,7 +49,7 @@ interp(ctx, c::Conv2D, x) = interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) = haskey(ctx[:params], p.value) ? ctx[:params][p.value] : - (ctx[:params][p.value] = Variable(p.value.x)) + (ctx[:params][p.value] = Variable(convertel(Float32, p.value.x))) interp(ctx, p::Constant) = p.value diff --git a/src/utils.jl b/src/utils.jl index c685fd4e..3a36b861 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,7 +2,7 @@ export AArray const AArray = AbstractArray -initn(dims...) = randn(Float32, dims...)/10 +initn(dims...) = randn(dims...)/10 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1) i = 0