tweaks
This commit is contained in:
parent
36baa7ec2c
commit
d21c313ea7
@ -8,6 +8,7 @@ VecShape{T} = Shape{T,1}
|
|||||||
MatShape{T} = Shape{T,2}
|
MatShape{T} = Shape{T,2}
|
||||||
|
|
||||||
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
|
Shape{T}(dims::Vararg{Integer,N}) where {T,N} = Shape{T,N}(dims)
|
||||||
|
Shape{T}(dims::NTuple{N,Integer}) where {T,N} = Shape{T,N}(dims)
|
||||||
|
|
||||||
Base.size(s::Shape) = s.dims
|
Base.size(s::Shape) = s.dims
|
||||||
Base.size(s::Shape, n) = s.dims[n]
|
Base.size(s::Shape, n) = s.dims[n]
|
||||||
@ -22,7 +23,7 @@ function Base.show(io::IO, s::Shape{T}) where T
|
|||||||
print(io, ")")
|
print(io, ")")
|
||||||
end
|
end
|
||||||
|
|
||||||
shape(x) = typeof(x)
|
shape(x) = x
|
||||||
shape(x::Shape) = x
|
shape(x::Shape) = x
|
||||||
shape(x::Tuple) = shape.(x)
|
shape(x::Tuple) = shape.(x)
|
||||||
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
shape(x::AbstractArray) = Shape{eltype(x)}(size(x)...)
|
||||||
|
@ -53,9 +53,12 @@ struct Compiled{F,T<:Tuple}
|
|||||||
params::T
|
params::T
|
||||||
end
|
end
|
||||||
|
|
||||||
(c::Compiled)(args...) =
|
# TODO when we support derivatives
|
||||||
Tracker.track(Tracker.Call(c, args...),
|
# (c::Compiled)(args...) =
|
||||||
c.func(Tracker.data.(c.params), args...))
|
# Tracker.track(Tracker.Call(c, args...),
|
||||||
|
# c.func(Tracker.data.(c.params), args...))
|
||||||
|
|
||||||
|
(c::Compiled)(args...) = c.func(Tracker.data.(c.params), Tracker.data.(args)...)
|
||||||
|
|
||||||
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
Base.show(io::IO, c::Compiled) = print(io, "Compiled(", c.model, ")")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user