From ddcd576a74484a67320a721a0913612f58c6dad1 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 5 Jun 2017 16:09:06 +0100 Subject: [PATCH] give up and use AbstractArray --- src/backend/mxnet/graph.jl | 2 +- src/backend/tensorflow/graph.jl | 2 +- src/compiler/code.jl | 2 +- src/utils.jl | 2 -- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index ce42cd54..4107b368 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -90,7 +90,7 @@ end register(ctx::Context, node) = node -function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam}) +function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam}) id = gensym() ctx[:params][id] = p return mx.Variable(id) diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl index a3862309..365345f3 100644 --- a/src/backend/tensorflow/graph.jl +++ b/src/backend/tensorflow/graph.jl @@ -62,7 +62,7 @@ end interp(ctx, c::Conv2D, x) = nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID") -param(ctx, p::Flux.Param{<:AArray}) = +param(ctx, p::Flux.Param{<:AbstractArray}) = haskey(ctx[:params], p) ? ctx[:params][p] : (ctx[:params][p] = diff --git a/src/compiler/code.jl b/src/compiler/code.jl index c47be6c4..fb403897 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -34,7 +34,7 @@ function build_type(T, params) end if any(x->isexpr(x, Symbol), params) push!(ex.args, - :($T($(map(x->isexpr(x, Symbol) ? :($x::AArray) : x, params)...)) = + :($T($(map(x->isexpr(x, Symbol) ? :($x::AbstractArray) : x, params)...)) = $T($(map(x->isexpr(x, Symbol) ? :(param($x)) : namify(x), params)...)))) end ex diff --git a/src/utils.jl b/src/utils.jl index d5ae18e8..885b2ccf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,8 +2,6 @@ call(f, xs...) = f(xs...) # Arrays -const AArray = AbstractArray - initn(dims...) = randn(dims...)/100 unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))