update notebook

Mike J Innes 2016-10-12 17:07:56 +01:00
parent c9f9665e4e
commit af8001bdb8
1 changed file with 93 additions and 135 deletions


@@ -23,7 +23,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 3,
"metadata": {
"collapsed": false
},
@@ -89,7 +89,7 @@
" pixelspacing: 1.0 1.0"
]
},
"execution_count": 24,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 4,
"metadata": {
"collapsed": false
},
@@ -111,7 +111,7 @@
"3"
]
},
"execution_count": 34,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -122,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 5,
"metadata": {
"collapsed": false
},
@@ -141,7 +141,7 @@
" pixelspacing: 1.0 1.0"
]
},
"execution_count": 32,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -152,7 +152,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 6,
"metadata": {
"collapsed": false
},
@@ -163,7 +163,7 @@
"0"
]
},
"execution_count": 35,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -174,7 +174,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"metadata": {
"collapsed": false
},
@@ -195,7 +195,7 @@
" 0"
]
},
"execution_count": 12,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -234,7 +234,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 9,
"metadata": {
"collapsed": false
},
@@ -242,10 +242,10 @@
{
"data": {
"text/plain": [
"Flux.Chain(Any[Flux.Input{1}((784,)),Flux.Dense(784,128),Flux.relu,Flux.Dense(128,64),Flux.relu,Flux.Dense(64,10),Flux.softmax],10)"
"Chain(Flux.Input{1}((784,)), Dense(784,128), Flux.relu, Dense(128,64), Flux.relu, Dense(64,10), Flux.softmax)"
]
},
"execution_count": 14,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
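The printed `Chain` above is a 784 → 128 → 64 → 10 multilayer perceptron: three affine (`Dense`) layers with `relu` between them and a `softmax` on the output. A dependency-free sketch of the same computation, with the sizes taken from the printed model (`affine`, `mlp` and the weight scale are illustrative helpers, not Flux API):

```julia
relu(z) = max.(z, 0)
softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))
affine(nin, nout) = (W = 0.01 .* randn(nout, nin); b = zeros(nout); x -> W * x .+ b)

layers = (affine(784, 128), relu, affine(128, 64), relu, affine(64, 10), softmax)
mlp(x) = foldl((h, f) -> f(h), layers; init = x)

mlp(rand(784))   # 10 class probabilities that sum to 1, like the Flux model `m`
```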
@@ -276,7 +276,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {
"collapsed": false
},
@@ -284,26 +284,17 @@
{
"data": {
"text/plain": [
"10-element Array{Float64,1}:\n",
" 0.103625 \n",
" 0.0957417\n",
" 0.110601 \n",
" 0.0953211\n",
" 0.0909375\n",
" 0.095231 \n",
" 0.0911746\n",
" 0.0940216\n",
" 0.108713 \n",
" 0.114634 "
"1×10 Array{Float64,2}:\n",
" 0.100097 0.0998791 0.0998511 0.0998956 … 0.100189 0.100018 0.100032"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m(train[31][1])"
"m(train[31][1]')"
]
},
{
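Why the untrained network answers roughly 0.1 for every digit: with small random weights the pre-softmax activations are all close to zero, and softmax of a near-constant vector is almost uniform over the 10 classes. A minimal check, independent of Flux:

```julia
# softmax of near-zero logits is ≈ uniform, like the 0.0998…0.1002 row above
softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))
softmax(0.01 .* randn(10))   # every entry ≈ 0.1
```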
@@ -322,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 13,
"metadata": {
"collapsed": false
},
@@ -330,16 +321,16 @@
{
"data": {
"text/plain": [
"MXModel((Flux.softmax)((Flux.Dense(64,10))((Flux.relu)((Flux.Dense(128,64))((Flux.relu)((Flux.Dense(784,128))((Flux.Input{1}((784,)))(Flux.ModelInput(1)))))))))"
"Flux.TF.Model(Session(Ptr{Void} @0x000000012cbbf290),TensorFlow.Tensor[<Tensor placeholder:1 shape=unknown dtype=Float32>],<Tensor node_15:1 shape=unknown dtype=Float32>,<Tensor gradients/MatMul_grad/MatMul:1 shape=unknown dtype=Float32>)"
]
},
"execution_count": 8,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = mxnet(m, 784) # or `tf(m)` to run on TensorFlow"
"model = tf(m) # or `tf(m)` to run on TensorFlow"
]
},
{
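The `Flux.TF.Model` value shown above is essentially a handle: it stores the TensorFlow `Session`, the input placeholder, the output tensor and a gradient tensor, and calling it feeds an input and fetches the output. As a rough picture of that pattern only (the `BackendModel` type and its fields below are invented for illustration and are not Flux or TensorFlow.jl API):

```julia
struct BackendModel{S,F}
    session::S   # stands in for whatever the backend needs (the TF Session above)
    forward::F   # closure that "runs the graph" for a given input
end
(m::BackendModel)(x) = m.forward(m.session, x)

# Toy "backend" whose session is just a named tuple of weights:
sess = (W = randn(10, 784), b = zeros(10))
compiled = BackendModel(sess, (s, x) -> s.W * x .+ s.b)
compiled(rand(784))   # 10-element output, same calling convention as the Flux model
```

The point of the wrapper is that the compiled object keeps the original model's calling convention, so the rest of the notebook does not care which backend produced it.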
@@ -351,78 +342,45 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 14,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"accuracy(m,test) = 0.5301\n",
"accuracy(m,test) = 0.712\n",
"accuracy(m,test) = 0.84\n",
"accuracy(m,test) = 0.8686\n",
"accuracy(m,test) = 0.8894\n",
"accuracy(m,test) = 0.8863\n",
"accuracy(m,test) = 0.879\n",
"accuracy(m,test) = 0.8947\n",
"accuracy(m,test) = 0.9055\n",
"accuracy(m,test) = 0.892\n",
"accuracy(m,test) = 0.9029\n",
"accuracy(m,test) = 0.8939\n",
"accuracy(m,test) = 0.9036\n",
"accuracy(m,test) = 0.9119\n",
"accuracy(m,test) = 0.912\n",
"accuracy(m,test) = 0.9062\n",
"accuracy(m,test) = 0.9238\n",
"accuracy(m,test) = 0.9216\n",
"accuracy(m,test) = 0.9305\n",
"accuracy(m,test) = 0.9366\n",
"accuracy(m,test) = 0.921\n",
"accuracy(m,test) = 0.9347\n",
"accuracy(m,test) = 0.9048\n",
"accuracy(m,test) = 0.9342\n",
"accuracy(m,test) = 0.9267\n",
"accuracy(m,test) = 0.9323\n",
"accuracy(m,test) = 0.9309\n",
"accuracy(m,test) = 0.9406\n",
"accuracy(m,test) = 0.9299\n",
"accuracy(m,test) = 0.9361\n",
"accuracy(m,test) = 0.9248\n",
"accuracy(m,test) = 0.9457\n",
"accuracy(m,test) = 0.946\n",
"accuracy(m,test) = 0.9449\n",
"accuracy(m,test) = 0.9393\n",
"accuracy(m,test) = 0.9397\n",
"accuracy(m,test) = 0.9409\n",
"accuracy(m,test) = 0.9403\n",
"accuracy(m,test) = 0.948\n",
"accuracy(m,test) = 0.9064\n",
"accuracy(m,test) = 0.9457\n",
"accuracy(m,test) = 0.9504\n",
"accuracy(m,test) = 0.9482\n",
"accuracy(m,test) = 0.9514\n",
"accuracy(m,test) = 0.9474\n",
"accuracy(m,test) = 0.9364\n",
"accuracy(m,test) = 0.9423\n",
"accuracy(m,test) = 0.9503\n",
"accuracy(m,test) = 0.9452\n",
"accuracy(m,test) = 0.9511\n",
"114.461381 seconds (41.62 M allocations: 11.112 GB, 3.57% gc time)\n"
"INFO: Epoch 1\n"
]
},
{
"data": {
"text/plain": [
"MXModel((Flux.softmax)((Flux.Dense(64,10))((Flux.relu)((Flux.Dense(128,64))((Flux.relu)((Flux.Dense(784,128))((Flux.Input{1}((784,)))(Flux.ModelInput(1)))))))))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"y = Float32[0.100094 0.099893 0.0998293 0.099919 0.0999388 0.100123 0.0999624 0.100185 0.100024 0.100031]\n",
"accuracy(m,test) = 0.109\n",
"y = Float32[0.0999111 0.100465 0.0997179 0.0998629 0.100282 0.0994788 0.0999943 0.100667 0.0996533 0.0999675]\n",
"accuracy(m,test) = 0.109\n",
"y = Float32[0.103446 0.10002 0.100398 0.102059 0.0977901 0.0989413 0.0995703 0.0993392 0.100129 0.0983073]\n",
"accuracy(m,test) = 0.1779\n",
"y = Float32[0.0288616 0.136275 0.0976122 0.103702 0.0972153 0.113357 0.0834434 0.0843929 0.153776 0.101365]\n",
"accuracy(m,test) = 0.3236\n",
"y = Float32[1.94426f-6 8.36834f-11 1.50476f-8 5.13133f-8 0.0209049 0.887724 0.000472463 2.70319f-7 0.0901923 0.000703486]\n",
"accuracy(m,test) = 0.8258\n",
"y = Float32[0.00507439 0.0108996 0.0307061 0.364342 0.00413664 0.0342019 0.000766874 0.364772 0.0279412 0.157159]\n",
"accuracy(m,test) = 0.8543\n",
"y = Float32[6.40463f-14 2.77292f-10 2.8507f-9 0.999999 8.78815f-15 5.18649f-7 1.1916f-16 5.82441f-11 7.00438f-8 3.48348f-11]\n",
"accuracy(m,test) = 0.7922\n",
"y = Float32[1.47776f-7 0.997464 0.000132982 0.000113373 4.29683f-5 6.14733f-7 1.82552f-5 7.3389f-6 0.00221002 1.02455f-5]\n",
"accuracy(m,test) = 0.9133\n",
"y = Float32[4.12331f-9 4.38521f-13 2.32062f-8 1.82065f-8 8.9981f-12 2.4413f-6 5.51279f-13 0.999988 9.02321f-11 9.3387f-6]\n",
"accuracy(m,test) = 0.9151\n",
"y = Float32[5.29229f-8 0.000115183 4.59078f-7 0.887777 8.759f-6 0.0215426 7.45297f-8 7.17829f-6 0.0854197 0.00512896]\n",
"accuracy(m,test) = 0.8715\n",
"149.428110 seconds (27.60 M allocations: 2.383 GB, 2.10% gc time)\n"
]
}
],
"source": [
@@ -438,7 +396,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {
"collapsed": false
},
@@ -457,7 +415,7 @@
" pixelspacing: 1.0 1.0"
]
},
"execution_count": 12,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -466,38 +424,6 @@
"convert(Image,reshape(data[31][1],(28,28))'/255)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"10×1 Array{Float32,2}:\n",
" 0.0 \n",
" 0.002\n",
" 0.001\n",
" 0.996\n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.001\n",
" 0.001\n",
" 0.0 "
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"map(x->round(x, 3), model(data[31][1]))"
]
},
{
"cell_type": "code",
"execution_count": 16,
@@ -508,7 +434,17 @@
{
"data": {
"text/plain": [
"3"
"10-element Array{Float32,1}:\n",
" 0.0 \n",
" 0.019\n",
" 0.004\n",
" 0.966\n",
" 0.0 \n",
" 0.001\n",
" 0.0 \n",
" 0.004\n",
" 0.004\n",
" 0.002"
]
},
"execution_count": 16,
@@ -516,6 +452,28 @@
"output_type": "execute_result"
}
],
"source": [
"map(x->round(x, 3), model(data[31][1]))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onecold(model(data[31][1]), 0:9)"
]
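`onecold` is the inverse of a one-hot encoding: given the output probabilities and the label set `0:9`, it returns the label with the largest entry. A rough stand-in, using the rounded probabilities shown above (not Flux's actual implementation):

```julia
onecold_sketch(probs, labels) = labels[argmax(vec(probs))]

probs = [0.0, 0.019, 0.004, 0.966, 0.0, 0.001, 0.0, 0.004, 0.004, 0.002]
onecold_sketch(probs, 0:9)   # → 3, matching the "3" output above
```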
@@ -524,7 +482,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see from the output probability distribution above, our model is 98% confident that the digit is a 3."
"As you can see from the output probability distribution above, our model is 97% confident that the digit is a 3."
]
},
{
@@ -549,7 +507,7 @@
},
"outputs": [],
"source": [
"@model type Perceptron\n",
"@net type Perceptron\n",
" W\n",
" b\n",
" x -> σ(W*x + b)\n",
@@ -565,7 +523,7 @@
"source": [
"This is just the equation you'd find in a textbook! Because `Perceptron` is a simple Julia type, we can define a handy convenience constructor to initialise the weights as well. (This is almost exactly how `Dense` is defined in Flux itself, by the way no special casing here.)\n",
"\n",
"The `@model` macro is fairly simple, and just extends our basic type definition with some useful functionality like the backward pass. It's easy to see what's going on:"
"The `@net` macro is fairly simple, and just extends our basic type definition with some useful functionality like the backward pass. It's easy to see what's going on:"
]
},
{
@@ -578,7 +536,7 @@
"source": [
"using MacroTools\n",
"\n",
"@expand(@model type Perceptron\n",
"@expand(@net type Perceptron\n",
" W\n",
" b\n",
" x -> σ.(W*x + b)\n",
@@ -641,7 +599,7 @@
},
"outputs": [],
"source": [
"@model type Perceptron\n",
"@net type Perceptron\n",
" dense::Dense\n",
" x -> σ(dense(x))\n",
"end\n",
@@ -660,7 +618,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Because models are just Julia objects conforming to a protocol, it's easy to implement *completely* custom models, which don't use the `@model` macro. For example, perhaps the function isn't easy to auto-differentiate, or you need a custom GPU kernel. An example of the former kind is the `Chain` model used above, which is defined roughly like so:"
"Because models are just Julia objects conforming to a protocol, it's easy to implement *completely* custom models, which don't use the `@net` macro. For example, perhaps the function isn't easy to auto-differentiate, or you need a custom GPU kernel. An example of the former kind is the `Chain` model used above, which is defined roughly like so:"
]
},
{
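The `Chain` definition the cell refers to sits outside this diff, but the protocol it relies on is small: a model is any callable object that maps input to output, and `Chain` just threads its input through a stored tuple of layers. A minimal Chain-like sketch, not Flux's actual code:

```julia
struct MyChain{T<:Tuple}
    layers::T
end
MyChain(layers...) = MyChain(layers)

# Calling the chain threads the input through each layer in turn
(c::MyChain)(x) = foldl((h, f) -> f(h), c.layers; init = x)

MyChain(x -> 2 .* x, x -> x .+ 1)([1.0, 2.0])   # → [3.0, 5.0]
```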
@@ -711,7 +669,7 @@
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.5.0"
"version": "0.5.1"
}
},
"nbformat": 4,