mnist notebook
This commit is contained in:
parent 0e9a20b2a4
commit 22930b701f
@@ -0,0 +1,719 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neural Networks in Julia"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Walkthrough"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We start by loading the packages:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"using Flux, MNIST, Images"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a quick look at the data to see what we're dealing with. We start by loading it into an appropriate format: a list of pairs, where each pair is an input (the image) and a target output (the one-hot-encoded label)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"const data = [(MNIST.trainfeatures(i), onehot(MNIST.trainlabel(i), 0:9)) for i = 1:60_000]\n",
"const train = data[1:50_000]\n",
"const test = data[50_001:60_000]\n",
"nothing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can easily use Julia's array tools and the Images package to check out the data:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAYAAADG4PRLAAAAAXNSR0IArs4c6QAABOlJREFUeAHtnM0rbl0Yxp3j9W3mIyVJKRHCSIlSDJQyMlFMpJQyIX8BwkTMMGAiZWZsRhlIMvGRlK8Zhiby9RqcybVOnd1qrXvtdT9dp87gt9fe97rP9Tt3q/N49vmVlZX1/fObv5Qm8Ftp32z7TwIUqPyvAgVSoPIElLfPCaRA5Qkob58TSIHKE1DePieQApUnoLx9TiAFKk9AefucQApUnoDy9jmBFKg8AeXtcwIpUHkCytvnBFKg8gSUt88JpEDlCShvnxNIgcoTUN4+J5AClSegvH1OoHKB//nuv6ioCErm5+cD9/f3A7e0tACHhtXVVdjy9vYWOHbgBMZuKKE/CkwIKPZlCozdUEJ/v37Wrd4PHBoagpIdHR3/5KamJliPDW5ubqClzs5O4KenJ+DYgBMYmxHLfijQMrDYbqfA2IxY9mN9Bn5/45H59fUFW5r8+PgI6yYcHh7CpefnZ+DLy0tgW2hsbIRHJicngU2Ynp6GS8vLy8CxAScwNiOW/VCgZWCx3U6BsRmx7Mf6s9Dr62vY4u3tDXh2dhZ4d3cXWBqqqqpgi66uLuAkuL+/T7olqnVOYFQ67JuhQPvMonqCAqPSYd+M9RlYV1dnv4vgEzU1NVDdPHPb2tpg3YS9vT24tL+/Dxw7cAJjN5TQHwUmBBT7MgXGbiihP+vPQhPqeV8uLCyEmj09PcDr6+vAZWVlwEnQ3NwMt5yfnwPHDpzA2A0l9EeBCQHFvkyBsRtK6C/6M3BpaQn+CFNTU8CuYP488vX19Z8lT05OYH1rawv47u4OWBo4gdIJC9enQOGApctToHTCwvWtPwsV7uev8rW1tX9d83nB/B5oUu2+vj64pb6+Htj83uzn5yes+wZOoO9EA9ejwMCB+96OAn0nGrhe9P8ObGhogEhKSkqAbaG8vBweGR4eBt7c3ASurq4GXlxcBM7NzQU+OjoC7u7uBv74+AB2BU6ga4IpP0+BKQtw3Z4CXRNM+fnoz0DXfMz3F+fm5qDkyMgI8MPDA7AJ5nds1tbW4BZz3Xw/8uLiAu53BU6ga4IpP0+BKQtw3Z4CXRNM+fnoPwu1zae9vR0eWVhYAJ6ZmQFOOvPg5h84PT2FS9vb28DmGWh+z7SyshLudwVOoGuCKT9PgSkLcN2eAl0TTPn5jDsDzXfcCwoKIOKrqytgVzg+PoYS7+/vwBUVFcC+gRPoO9HA9SgwcOC+t6NA34kGrpdxZ2BpaSlE2NraCryzswM8Pz8PfHBwAGzC4OAgXBoYGADOyckBlgZOoHTCwvUpUDhg6fIUKJ2wcP2MOwPPzs4gMvN7n729vbBu/rzw5eUF1k0wP8vMzs42bwEeHR0F9g2cQN+JBq5HgYED970dBfpONHC9jPtOTF5eHkS4srICPDY2BuwbNjY2oOTExASw73clOIEQrz6gQH3OoGMKhDj0QcadgaYC892F4uJiuGV8fBzY9t0L8+eB5v/VZv4f47CZB+AEeggxzRIUmGb6HvamQA8hplki48/ANMMNsTcnMETKgntQoGC4IUpTYIiUBfegQMFwQ5SmwBApC+5BgYLhhihNgSFSFtyDAgXDDVGaAkOkLLgHBQqGG6I0BYZIWXAPChQMN0RpCgyRsuAeFCgYbojSFBgiZcE9KFAw3BClKTBEyoJ7UKBguCFKU2CIlAX3+B+3gbSWnpc3pwAAAABJRU5ErkJggg==",
"text/plain": [
"Gray Images.Image with:\n",
" data: 28×28 Array{Float64,2}\n",
" properties:\n",
" timedim: 0\n",
" colorspace: Gray\n",
" colordim: 0\n",
" spatialorder: y x\n",
" pixelspacing: 1.0 1.0"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"convert(Image,reshape(data[31][1],(28,28))'/255)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainlabel(31) |> Int"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAYAAADG4PRLAAAAAXNSR0IArs4c6QAABXpJREFUeAHtncsrdlEUxt0iEkOUGDGTW4mROYoYYMAfIKVIGVAGSrkMGIiJIgplQGSgFCkDxNxMilwmmFD4fPVNntXbe9rf3mdf6lEGv3P2Xmu9z2O12+c950hNSUn5+f3lT6AKpAVaN8v+pwANDPxPgQbSwMAVCLx8diANDFyBwMtnB9LAwBUIvHx2IA0MXIHAy2cH0sDAFQi8fHYgDQxcgcDLZwfSwMAVCLx8diANDFyBwMtnB9LAwBUIvHx2IA0MXIHAy2cHBm5gRuD1K5ff0tICc7Kzs4GjIDU1FYb8/CS/rfbs7AzG393dAesCO1BXQcfzaaBjA3TT00BdBR3PD34NzMrKAglramqAZ2dngauqqoAzMzOBo0B1DTw/P4eQ7e3twPf398CqwA5UVcyz8TTQM0NUy6GBqop5Nv7vpib5RsazgsvLy6GimZkZ4KamJuCoNevx8RHGf3x8AO/u7gK3tbUBFxYWAqenpwNLmJiYgEPj4+PAqsAOVFXMs/E00DNDVMuhgaqKeTbe+zVQ7tuGh4dBws7OTmAJJycncEiuaZubm3BedV82Pz8P8/v6+oCjICNDbyvODoxS2PPzNNBzg6LKo4FRCnl+3rs1sLa2FiQ7PDwEzs/PB357ewNubW0FPj4+BjYNco3e39+HFHKfCCd/IWrfKMdLZgdKRQJjGhiYYbJcGigVCYz1NiH/8WFzcnJg1tLSEnBzczNwXl4e8PPzM3Bvby9w3GseJPuFl5cXOPT6+gpcUFAAbBrYgaYVtRyPBloW3HQ6GmhaUcvxrK+B8p6V7u5u+MhR39/Ja5UXFxcw3zbU19dDyrKyMmAJc3Nz8pAWswO15HM/mQa690CrAhqoJZ/7ybGvgXV1dfAp19bWgKPg4OAAhqysrACrPtsAkw1AT09P0ihyn7i4uJh0vOpJdqCqYp6Np4GeGaJaDg1UVcyz8cbXQPn92Pb2NnzkoqIiYAnyWqbcJ76/v8spTrmysjJp/o2NDTh/c3MDrAvsQF0FHc+ngY4N0E1PA3UVdDxfew1UvYcl6vPu7e3BkMbGRmB5zwmctACDg4OQpbi4GFg+Ez8wMADnTQM70LSiluPRQMuCm05HA00rajme9hrY398PJct7WKLeowKTf2F6ehoOjY6OAtteA+W+VtZze3sL9R0dHQHHDezAuBWOOT4NjFnguMPTwLgVjjm+9rMRX19fUKJc8+SzC/JaoLxH5uHhAeJVV1cDPz09AZsG+d6Z5eVlSNHV1QW8tbUFLK/dwskYgB0Yg6g2Q9JAm2rHkIsGxiCqzZDa+8DV1VWoV94j8vn5CeenpqaAc3Nzga+uroBtr3lynyefwV9fX4f6FhYWgG0DO9C24obz0UDDgtoORwNtK244n/Y+UO7jop5pl/XLe2Dke1d2dnbkFKM8OTkJ8eR7aGT+jo4OGO8a2IGuHdDMTwM1BXQ9nQa6dkAzv/YaGJW/oaEBhpyengJLSEvDv6nv7285RIlV48n/61BaWqqUz/ZgVMt2dubTVoAGakvoNgANdKu/dnbta6FRFVxeXsKQsbEx4IqKCmD5jLx8b4zq84ByDZXfV8p7WOT/mYDiPAR2oIemqJREA1XU8nAsDfTQFJWSYt8HqhSTaKy81jo0NATDSkpKgCXINfX6+hqGjIyMAPv2/CEUlwDYgQlECekQDQzJrQS10sAEooR0yPs1MCQxXdTKDnShusGcNNCgmC5C0UAXqhvMSQMNiukiFA10obrBnDTQoJguQtFAF6obzEkDDYrpIhQNdKG6wZw00KCYLkLRQBeqG8xJAw2K6SIUDXShusGcNNCgmC5C0UAXqhvMSQMNiukiFA10obrBnDTQoJguQtFAF6obzPkHUT/96zaIKDgAAAAASUVORK5CYII=",
"text/plain": [
"Gray Images.Image with:\n",
" data: 28×28 Array{Float64,2}\n",
" properties:\n",
" timedim: 0\n",
" colorspace: Gray\n",
" colordim: 0\n",
" spatialorder: y x\n",
" pixelspacing: 1.0 1.0"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"convert(Image,reshape(data[2][1],(28,28))/255)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainlabel(2) |> Int"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"10-element Array{Int64,1}:\n",
" 0\n",
" 0\n",
" 0\n",
" 1\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Int.(onehot(3, 0:9))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This vector is in the same form as the model will use, where each number represents the probability of a label from 0-9."
]
},
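{
"cell_type": "markdown",
"metadata": {},
"source": [
"(This cell is an added sketch, not part of the original run.) `onecold`, used later in this notebook to read off predictions, inverts `onehot` by returning the label whose entry is largest:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: assumes the same `onecold` used in the prediction cells below.\n",
"onecold(onehot(3, 0:9), 0:9) # should recover the label 3"
]
},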
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, we're ready to set up our model! We'll use a simple logistic-regression-esque setup with two hidden layers. The hidden layers will be made of rectified linear units (128 and 64 respectively), while the output of the network will go through a softmax function in order to turn it into a probability distribution."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"Flux.Chain(Any[Flux.Input{1}((784,)),Flux.Dense(784,128),Flux.relu,Flux.Dense(128,64),Flux.relu,Flux.Dense(64,10),Flux.softmax],10)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = Chain(\n",
" Input(784), # Size of MNIST image sample\n",
" Dense(128), relu,\n",
" Dense( 64), relu,\n",
" Dense( 10), softmax) # Size of output vector"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice how closely this maps to our description above; we simply list each fully-connected layer and its activation in order. Notice also that we didn't have to supply the input dimension for each hidden layer; given the `Input(784)` layer, Flux is smart enough to infer the dimension of each subsequent layer."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A *model* or *layer* is simply a stateful function (the state being the parameters of the network). Since the model is a function we can call it with an image to produce a prediction.\n",
"\n",
"(This goes both ways; above, we treated the simple Julia functions `softmax` and `relu` as models.)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"10-element Array{Float64,1}:\n",
" 0.103625 \n",
" 0.0957417\n",
" 0.110601 \n",
" 0.0953211\n",
" 0.0909375\n",
" 0.095231 \n",
" 0.0911746\n",
" 0.0940216\n",
" 0.108713 \n",
" 0.114634 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m(train[31][1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This prediction basically tells us that the model has no idea what's going on – but that's not too surprising given that we haven't trained it yet. In a moment we'll show it the data we loaded above."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In Flux we can take a complete model and modify its implementation after the fact. In this case, we'll convert the model into one that can run on the MXNet backend."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"MXModel((Flux.softmax)((Flux.Dense(64,10))((Flux.relu)((Flux.Dense(128,64))((Flux.relu)((Flux.Dense(784,128))((Flux.Input{1}((784,)))(Flux.ModelInput(1)))))))))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = mxnet(m, 784) # or `tf(m)` to run on TensorFlow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Last of all we want to train the model on our data."
]
},
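{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training loop below logs `accuracy(m, test)` as it goes. That helper isn't defined in this notebook; a minimal sketch consistent with the data format above (an added assumption, not necessarily the exact helper used) would be:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Assumed helper: fraction of samples whose predicted label matches the target.\n",
"accuracy(m, data) = mean(onecold(m(x), 0:9) == onecold(y, 0:9) for (x, y) in data)"
]
},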
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"accuracy(m,test) = 0.5301\n",
"accuracy(m,test) = 0.712\n",
"accuracy(m,test) = 0.84\n",
"accuracy(m,test) = 0.8686\n",
"accuracy(m,test) = 0.8894\n",
"accuracy(m,test) = 0.8863\n",
"accuracy(m,test) = 0.879\n",
"accuracy(m,test) = 0.8947\n",
"accuracy(m,test) = 0.9055\n",
"accuracy(m,test) = 0.892\n",
"accuracy(m,test) = 0.9029\n",
"accuracy(m,test) = 0.8939\n",
"accuracy(m,test) = 0.9036\n",
"accuracy(m,test) = 0.9119\n",
"accuracy(m,test) = 0.912\n",
"accuracy(m,test) = 0.9062\n",
"accuracy(m,test) = 0.9238\n",
"accuracy(m,test) = 0.9216\n",
"accuracy(m,test) = 0.9305\n",
"accuracy(m,test) = 0.9366\n",
"accuracy(m,test) = 0.921\n",
"accuracy(m,test) = 0.9347\n",
"accuracy(m,test) = 0.9048\n",
"accuracy(m,test) = 0.9342\n",
"accuracy(m,test) = 0.9267\n",
"accuracy(m,test) = 0.9323\n",
"accuracy(m,test) = 0.9309\n",
"accuracy(m,test) = 0.9406\n",
"accuracy(m,test) = 0.9299\n",
"accuracy(m,test) = 0.9361\n",
"accuracy(m,test) = 0.9248\n",
"accuracy(m,test) = 0.9457\n",
"accuracy(m,test) = 0.946\n",
"accuracy(m,test) = 0.9449\n",
"accuracy(m,test) = 0.9393\n",
"accuracy(m,test) = 0.9397\n",
"accuracy(m,test) = 0.9409\n",
"accuracy(m,test) = 0.9403\n",
"accuracy(m,test) = 0.948\n",
"accuracy(m,test) = 0.9064\n",
"accuracy(m,test) = 0.9457\n",
"accuracy(m,test) = 0.9504\n",
"accuracy(m,test) = 0.9482\n",
"accuracy(m,test) = 0.9514\n",
"accuracy(m,test) = 0.9474\n",
"accuracy(m,test) = 0.9364\n",
"accuracy(m,test) = 0.9423\n",
"accuracy(m,test) = 0.9503\n",
"accuracy(m,test) = 0.9452\n",
"accuracy(m,test) = 0.9511\n",
"114.461381 seconds (41.62 M allocations: 11.112 GB, 3.57% gc time)\n"
]
},
{
"data": {
"text/plain": [
"MXModel((Flux.softmax)((Flux.Dense(64,10))((Flux.relu)((Flux.Dense(128,64))((Flux.relu)((Flux.Dense(784,128))((Flux.Input{1}((784,)))(Flux.ModelInput(1)))))))))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@time Flux.train!(model, train, test, epoch = 1, η=0.001)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see the accuracy being printed after every 1000 samples, and it shoots up pretty quickly to around 95%. Not state of the art, but not bad for the simplicity of the model. We might have more luck with our predictions now:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAYAAADG4PRLAAAAAXNSR0IArs4c6QAABOlJREFUeAHtnM0rbl0Yxp3j9W3mIyVJKRHCSIlSDJQyMlFMpJQyIX8BwkTMMGAiZWZsRhlIMvGRlK8Zhiby9RqcybVOnd1qrXvtdT9dp87gt9fe97rP9Tt3q/N49vmVlZX1/fObv5Qm8Ftp32z7TwIUqPyvAgVSoPIElLfPCaRA5Qkob58TSIHKE1DePieQApUnoLx9TiAFKk9AefucQApUnoDy9jmBFKg8AeXtcwIpUHkCytvnBFKg8gSUt88JpEDlCShvnxNIgcoTUN4+J5AClSegvH1OoHKB//nuv6ioCErm5+cD9/f3A7e0tACHhtXVVdjy9vYWOHbgBMZuKKE/CkwIKPZlCozdUEJ/v37Wrd4PHBoagpIdHR3/5KamJliPDW5ubqClzs5O4KenJ+DYgBMYmxHLfijQMrDYbqfA2IxY9mN9Bn5/45H59fUFW5r8+PgI6yYcHh7CpefnZ+DLy0tgW2hsbIRHJicngU2Ynp6GS8vLy8CxAScwNiOW/VCgZWCx3U6BsRmx7Mf6s9Dr62vY4u3tDXh2dhZ4d3cXWBqqqqpgi66uLuAkuL+/T7olqnVOYFQ67JuhQPvMonqCAqPSYd+M9RlYV1dnv4vgEzU1NVDdPHPb2tpg3YS9vT24tL+/Dxw7cAJjN5TQHwUmBBT7MgXGbiihP+vPQhPqeV8uLCyEmj09PcDr6+vAZWVlwEnQ3NwMt5yfnwPHDpzA2A0l9EeBCQHFvkyBsRtK6C/6M3BpaQn+CFNTU8CuYP488vX19Z8lT05OYH1rawv47u4OWBo4gdIJC9enQOGApctToHTCwvWtPwsV7uev8rW1tX9d83nB/B5oUu2+vj64pb6+Htj83uzn5yes+wZOoO9EA9ejwMCB+96OAn0nGrhe9P8ObGhogEhKSkqAbaG8vBweGR4eBt7c3ASurq4GXlxcBM7NzQU+OjoC7u7uBv74+AB2BU6ga4IpP0+BKQtw3Z4CXRNM+fnoz0DXfMz3F+fm5qDkyMgI8MPDA7AJ5nds1tbW4BZz3Xw/8uLiAu53BU6ga4IpP0+BKQtw3Z4CXRNM+fnoPwu1zae9vR0eWVhYAJ6ZmQFOOvPg5h84PT2FS9vb28DmGWh+z7SyshLudwVOoGuCKT9PgSkLcN2eAl0TTPn5jDsDzXfcCwoKIOKrqytgVzg+PoYS7+/vwBUVFcC+gRPoO9HA9SgwcOC+t6NA34kGrpdxZ2BpaSlE2NraCryzswM8Pz8PfHBwAGzC4OAgXBoYGADOyckBlgZOoHTCwvUpUDhg6fIUKJ2wcP2MOwPPzs4gMvN7n729vbBu/rzw5eUF1k0wP8vMzs42bwEeHR0F9g2cQN+JBq5HgYED970dBfpONHC9jPtOTF5eHkS4srICPDY2BuwbNjY2oOTExASw73clOIEQrz6gQH3OoGMKhDj0QcadgaYC892F4uJiuGV8fBzY9t0L8+eB5v/VZv4f47CZB+AEeggxzRIUmGb6HvamQA8hplki48/ANMMNsTcnMETKgntQoGC4IUpTYIiUBfegQMFwQ5SmwBApC+5BgYLhhihNgSFSFtyDAgXDDVGaAkOkLLgHBQqGG6I0BYZIWXAPChQMN0RpCgyRsuAeFCgYbojSFBgiZcE9KFAw3BClKTBEyoJ7UKBguCFKU2CIlAX3+B+3gbSWnpc3pwAAAABJRU5ErkJggg==",
"text/plain": [
"Gray Images.Image with:\n",
" data: 28×28 Array{Float64,2}\n",
" properties:\n",
" timedim: 0\n",
" colorspace: Gray\n",
" colordim: 0\n",
" spatialorder: y x\n",
" pixelspacing: 1.0 1.0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"convert(Image,reshape(data[31][1],(28,28))'/255)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"10×1 Array{Float32,2}:\n",
" 0.0 \n",
" 0.002\n",
" 0.001\n",
" 0.996\n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.001\n",
" 0.001\n",
" 0.0 "
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"map(x->round(x, 3), model(data[31][1]))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onecold(model(data[31][1]), 0:9)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see from the output probability distribution above, our model is about 99.6% confident that the digit is a 3."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Under the hood"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A \"model\" is some function-like object which supports `back!(model, Δ, inputs...)` in order to calculate gradients. Flux makes it really easy to define custom layer types, and for the most part can do a lot of the work for you. For example, we could implement a perceptron layer as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"@model type Perceptron\n",
" W\n",
" b\n",
" x -> σ(W*x + b)\n",
"end\n",
"\n",
"Perceptron(in::Integer, out::Integer) =\n",
" Perceptron(randn(out, in), randn(out))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is just the equation you'd find in a textbook! Because `Perceptron` is a simple Julia type, we can define a handy convenience constructor to initialise the weights as well. (This is almost exactly how `Dense` is defined in Flux itself, by the way – no special casing here.)\n",
"\n",
"The `@model` macro is fairly simple, and just extends our basic type definition with some useful functionality like the backward pass. It's easy to see what's going on:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using MacroTools\n",
"\n",
"@expand(@model type Perceptron\n",
" W\n",
" b\n",
" x -> σ.(W*x + b)\n",
"end) |> longdef |> prettify"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"type Perceptron <: Model\n",
" W\n",
" b\n",
"end\n",
"\n",
"function Perceptron(W::AArray,b::AArray)\n",
" Perceptron(param(W),param(b)) # `param` stores the parameter and its gradient together\n",
"end\n",
"\n",
"function (self::Perceptron)(x)\n",
" σ(state(self.W) * x + state(self.b))\n",
"end\n",
"\n",
"function back!(self::Perceptron,Δ,x)\n",
" owl = (back!(σ,Δ,state(self.W) * x + state(self.b)))[1]\n",
" gerbil = back!(*,owl,state(self.W),x)\n",
" accumulate!(self.W,gerbil[1])\n",
" accumulate!(self.b,owl)\n",
" (gerbil[2],)\n",
"end\n",
"\n",
"function update!(self::Perceptron,η)\n",
" update!(self.W,η)\n",
" update!(self.b,η)\n",
"end\n",
"\n",
"function graph(::Perceptron)\n",
" vertex(σ,vertex(+,vertex(*,constant(Flux.ModelInput(:W)),constant(Flux.ModelInput(1))),constant(Flux.ModelInput(:b))))\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is exactly the sort of code we might have written by hand if we were using something like Torch. The forward pass is recognisable as the function we wrote in the original definition, and `back!` implements the gradient pass. The main odd feature here is the `graph` function; this returns the graph representation of the model, which allows Flux to compile it for other backends.\n",
"\n",
"Models can contain other models, so you could just as well implement the above by reusing Flux's built-in `Dense` layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"@model type Perceptron\n",
" dense::Dense\n",
" x -> σ(dense(x))\n",
"end\n",
"\n",
"Perceptron(in, out) = Perceptron(Dense(in, out))"
]
},
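{
"cell_type": "markdown",
"metadata": {},
"source": [
"(An added usage sketch, assuming the constructor above.) To use the layer we just construct it and call it like a function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"p = Perceptron(784, 10) # randomly initialised weights\n",
"p(rand(784)) # a 10-element vector of sigmoid outputs"
]
},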
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This allows you to build up models of arbitrary complexity in a completely declarative way."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because models are just Julia objects conforming to a protocol, it's easy to implement *completely* custom models, which don't use the `@model` macro. For example, perhaps the function isn't easy to auto-differentiate, or you need a custom GPU kernel. An example of the former kind is the `Chain` model used above, which is defined roughly like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"type Chain <: Model\n",
" layers::Vector{Any}\n",
" Chain(ms...) = new(collect(ms))\n",
"end\n",
"\n",
"(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)\n",
"back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)\n",
"update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)"
]
},
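{
"cell_type": "markdown",
"metadata": {},
"source": [
"(Added sketch against the rough definition above.) The `foldl` in the forward pass simply threads the input through each layer in turn, so calling a chain composes its layers left to right:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"c = Chain(x -> x + 1, x -> 2x) # plain functions stand in for layers\n",
"c(3) # foldl applies the layers in order: 2 * (3 + 1) = 8"
]
},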
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Likewise, the `relu` function is defined as:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"relu(x) = max(0, x)\n",
"back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.5.0-rc3",
"language": "julia",
"name": "julia-0.5"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.5.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}