tweaks
This commit is contained in:
parent
36baa7ec2c
commit
d21c313ea7
@ -8,6 +8,7 @@ VecShape{T} = Shape{T,1}
|
||||
MatShape{T} = Shape{T,2}
|
||||
|
||||
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
|
||||
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
|
||||
|
||||
Base.size(s::Shape) = s.dims
|
||||
Base.size(s::Shape, n) = s.dims[n]
|
||||
@ -22,7 +23,7 @@ function Base.show(io::IO, s::Shape{T}) where T
|
||||
print(io, ")")
|
||||
end
|
||||
|
||||
shape(x) = typeof(x)
|
||||
shape(x) = x
|
||||
shape(x::Shape) = x
|
||||
shape(x::Tuple) = shape.(x)
|
||||
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
||||
|
@ -53,9 +53,12 @@ struct Compiled{F,T<:Tuple}
|
||||
params::T
|
||||
end
|
||||
|
||||
(c::Compiled)(args...) =
|
||||
Tracker.track(Tracker.Call(c, args...),
|
||||
c.func(Tracker.data.(c.params), args...))
|
||||
# TODO when we support derivatives
|
||||
# (c::Compiled)(args...) =
|
||||
# Tracker.track(Tracker.Call(c, args...),
|
||||
# c.func(Tracker.data.(c.params), args...))
|
||||
|
||||
(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)
|
||||
|
||||
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user