diff --git a/src/Flux.jl b/src/Flux.jl
index ba9a6327..45e3044e 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -4,11 +4,11 @@ module Flux
 
 # Zero Flux Given
 
-using Juno
+using Juno, Requires
 using Lazy: @forward
 
 export Chain, Dense, RNN, LSTM,
-  SGD, params
+  SGD, params, mapparams
 
 using NNlib
 export σ, relu, softmax
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 37c1b787..0ae5f8fa 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -56,7 +56,10 @@ Dense(in::Integer, out::Integer, σ = identity; init = initn) =
 
 treelike(Dense)
 
-(a::Dense)(x) = a.σ.(a.W*x .+ a.b)
+function (a::Dense)(x)
+  W, b, σ = a.W, a.b, a.σ
+  σ.(W*x .+ b)
+end
 
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
diff --git a/src/onehot.jl b/src/onehot.jl
index 1e147397..48b7ccf5 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -9,8 +9,8 @@ Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
 
 Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
 
-struct OneHotMatrix <: AbstractMatrix{Bool}
-  data::Vector{OneHotVector}
+struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
+  data::A
 end
 
 Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
@@ -21,6 +21,10 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
 
 Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
 
+@require CuArrays begin
+  CuArrays.cu(xs::OneHotMatrix) = OneHotMatrix(CuArrays.cu(xs.data))
+end
+
 onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
 onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
 
diff --git a/src/tree.jl b/src/tree.jl
index 438685d5..bd6b2d73 100644
--- a/src/tree.jl
+++ b/src/tree.jl
@@ -8,13 +8,18 @@ function treelike(T, fs = fieldnames(T))
   end
 end
 
-using DataFlow: OSet
+# TODO: prewalk/postwalk with correct caching
+# This is only correct in general for idempotent functions
 
-params(ps, p::AbstractArray) = push!(ps, p)
-params(ps, m) = foreach(m -> params(ps, m), children(m))
+mapparams(f, x::AbstractArray) = f(x)
+mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
+
+forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
+
+using DataFlow: OSet
 
 function params(m)
   ps = OSet()
-  params(ps, m)
+  forparams(p -> push!(ps, p), m)
   return collect(ps)
 end