This commit is contained in:
Mike Innes 2018-03-06 19:58:47 +00:00
parent 36baa7ec2c
commit d21c313ea7
2 changed files with 8 additions and 4 deletions

View File

@ -8,6 +8,7 @@ VecShape{T} = Shape{T,1}
MatShape{T} = Shape{T,2}
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
Base.size(s::Shape) = s.dims
Base.size(s::Shape, n) = s.dims[n]
@ -22,7 +23,7 @@ function Base.show(io::IO, s::Shape{T}) where T
print(io, ")")
end
shape(x) = typeof(x)
shape(x) = x
shape(x::Shape) = x
shape(x::Tuple) = shape.(x)
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)

View File

@ -53,9 +53,12 @@ struct Compiled{F,T<:Tuple}
params::T
end
(c::Compiled)(args...) =
Tracker.track(Tracker.Call(c, args...),
c.func(Tracker.data.(c.params), args...))
# TODO when we support derivatives
# (c::Compiled)(args...) =
# Tracker.track(Tracker.Call(c, args...),
# c.func(Tracker.data.(c.params), args...))
(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")