beginnings of gpu support

This commit is contained in:
Mike J Innes 2017-09-27 21:58:34 +01:00
parent 120a6db2bb
commit a60a754d68
4 changed files with 21 additions and 9 deletions

View File

@ -4,11 +4,11 @@ module Flux
# Zero Flux Given # Zero Flux Given
using Juno using Juno, Requires
using Lazy: @forward using Lazy: @forward
export Chain, Dense, RNN, LSTM, export Chain, Dense, RNN, LSTM,
SGD, params SGD, params, mapparams
using NNlib using NNlib
export σ, relu, softmax export σ, relu, softmax

View File

@ -56,7 +56,10 @@ Dense(in::Integer, out::Integer, σ = identity; init = initn) =
treelike(Dense) treelike(Dense)
(a::Dense)(x) = a.σ.(a.W*x .+ a.b) function (a::Dense)(x)
W, b, σ = a.W, a.b, a.σ
σ.(W*x .+ b)
end
function Base.show(io::IO, l::Dense) function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1)) print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))

View File

@ -9,8 +9,8 @@ Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix] Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
struct OneHotMatrix <: AbstractMatrix{Bool} struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
data::Vector{OneHotVector} data::A
end end
Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data)) Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
@ -21,6 +21,10 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...]) Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
@require CuArrays begin
CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(CuArrays.cu(xs.data))
end
onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels)) onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls]) onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])

View File

@ -8,13 +8,18 @@ function treelike(T, fs = fieldnames(T))
end end
end end
using DataFlow: OSet # TODO: prewalk/postwalk with correct caching
# This is only correct in general for idempotent functions
params(ps, p::AbstractArray) = push!(ps, p) mapparams(f, x::AbstractArray) = f(x)
params(ps, m) = foreach(m -> params(ps, m), children(m)) mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
using DataFlow: OSet
function params(m) function params(m)
ps = OSet() ps = OSet()
params(ps, m) forparams(p -> push!(ps, p), m)
return collect(ps) return collect(ps)
end end