beginnings of gpu support
This commit is contained in:
parent
120a6db2bb
commit
a60a754d68
@ -4,11 +4,11 @@ module Flux
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
using Juno
|
||||
using Juno, Requires
|
||||
using Lazy: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM,
|
||||
SGD, params
|
||||
SGD, params, mapparams
|
||||
|
||||
using NNlib
|
||||
export σ, relu, softmax
|
||||
|
@ -56,7 +56,10 @@ Dense(in::Integer, out::Integer, σ = identity; init = initn) =
|
||||
|
||||
treelike(Dense)
|
||||
|
||||
(a::Dense)(x) = a.σ.(a.W*x .+ a.b)
|
||||
function (a::Dense)(x)
|
||||
W, b, σ = a.W, a.b, a.σ
|
||||
σ.(W*x .+ b)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, l::Dense)
|
||||
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
|
||||
|
@ -9,8 +9,8 @@ Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
|
||||
|
||||
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
|
||||
|
||||
struct OneHotMatrix <: AbstractMatrix{Bool}
|
||||
data::Vector{OneHotVector}
|
||||
struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
|
||||
data::A
|
||||
end
|
||||
|
||||
Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
|
||||
@ -21,6 +21,10 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
|
||||
|
||||
@require CuArrays begin
|
||||
CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(CuArrays.cu(xs.data))
|
||||
end
|
||||
|
||||
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
|
||||
onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
|
||||
|
||||
|
13
src/tree.jl
13
src/tree.jl
@ -8,13 +8,18 @@ function treelike(T, fs = fieldnames(T))
|
||||
end
|
||||
end
|
||||
|
||||
using DataFlow: OSet
|
||||
# TODO: prewalk/postwalk with correct caching
|
||||
# This is only correct in general for idempotent functions
|
||||
|
||||
params(ps, p::AbstractArray) = push!(ps, p)
|
||||
params(ps, m) = foreach(m -> params(ps, m), children(m))
|
||||
mapparams(f, x::AbstractArray) = f(x)
|
||||
mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
|
||||
|
||||
forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
|
||||
|
||||
using DataFlow: OSet
|
||||
|
||||
function params(m)
|
||||
ps = OSet()
|
||||
params(ps, m)
|
||||
forparams(p -> push!(ps, p), m)
|
||||
return collect(ps)
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user