update notebook
This commit is contained in:
parent
c9f9665e4e
commit
af8001bdb8
|
@ -23,7 +23,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"scrolled": true
|
||||
|
@ -70,7 +70,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -89,7 +89,7 @@
|
|||
" pixelspacing: 1.0 1.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -100,7 +100,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -111,7 +111,7 @@
|
|||
"3"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -122,7 +122,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -141,7 +141,7 @@
|
|||
" pixelspacing: 1.0 1.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 32,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -152,7 +152,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -163,7 +163,7 @@
|
|||
"0"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -174,7 +174,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -195,7 +195,7 @@
|
|||
" 0"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -234,7 +234,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -242,10 +242,10 @@
|
|||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Flux.Chain(Any[Flux.Input{1}((784,)),Flux.Dense(784,128),Flux.relu,Flux.Dense(128,64),Flux.relu,Flux.Dense(64,10),Flux.softmax],10)"
|
||||
"Chain(Flux.Input{1}((784,)), Dense(784,128), Flux.relu, Dense(128,64), Flux.relu, Dense(64,10), Flux.softmax)"
|
||||
]
|
||||
},
|
||||
"execution_count": 14,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -276,7 +276,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -284,26 +284,17 @@
|
|||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"10-element Array{Float64,1}:\n",
|
||||
" 0.103625 \n",
|
||||
" 0.0957417\n",
|
||||
" 0.110601 \n",
|
||||
" 0.0953211\n",
|
||||
" 0.0909375\n",
|
||||
" 0.095231 \n",
|
||||
" 0.0911746\n",
|
||||
" 0.0940216\n",
|
||||
" 0.108713 \n",
|
||||
" 0.114634 "
|
||||
"1×10 Array{Float64,2}:\n",
|
||||
" 0.100097 0.0998791 0.0998511 0.0998956 … 0.100189 0.100018 0.100032"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"m(train[31][1])"
|
||||
"m(train[31][1]')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -322,7 +313,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -330,16 +321,16 @@
|
|||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"MXModel((Flux.softmax)((Flux.Dense(64,10))((Flux.relu)((Flux.Dense(128,64))((Flux.relu)((Flux.Dense(784,128))((Flux.Input{1}((784,)))(Flux.ModelInput(1)))))))))"
|
||||
"Flux.TF.Model(Session(Ptr{Void} @0x000000012cbbf290),TensorFlow.Tensor[<Tensor placeholder:1 shape=unknown dtype=Float32>],<Tensor node_15:1 shape=unknown dtype=Float32>,<Tensor gradients/MatMul_grad/MatMul:1 shape=unknown dtype=Float32>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model = mxnet(m, 784) # or `tf(m)` to run on TensorFlow"
|
||||
"model = tf(m) # or `mxnet(m, 784)` to run on MXNet"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -351,78 +342,45 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"accuracy(m,test) = 0.5301\n",
|
||||
"accuracy(m,test) = 0.712\n",
|
||||
"accuracy(m,test) = 0.84\n",
|
||||
"accuracy(m,test) = 0.8686\n",
|
||||
"accuracy(m,test) = 0.8894\n",
|
||||
"accuracy(m,test) = 0.8863\n",
|
||||
"accuracy(m,test) = 0.879\n",
|
||||
"accuracy(m,test) = 0.8947\n",
|
||||
"accuracy(m,test) = 0.9055\n",
|
||||
"accuracy(m,test) = 0.892\n",
|
||||
"accuracy(m,test) = 0.9029\n",
|
||||
"accuracy(m,test) = 0.8939\n",
|
||||
"accuracy(m,test) = 0.9036\n",
|
||||
"accuracy(m,test) = 0.9119\n",
|
||||
"accuracy(m,test) = 0.912\n",
|
||||
"accuracy(m,test) = 0.9062\n",
|
||||
"accuracy(m,test) = 0.9238\n",
|
||||
"accuracy(m,test) = 0.9216\n",
|
||||
"accuracy(m,test) = 0.9305\n",
|
||||
"accuracy(m,test) = 0.9366\n",
|
||||
"accuracy(m,test) = 0.921\n",
|
||||
"accuracy(m,test) = 0.9347\n",
|
||||
"accuracy(m,test) = 0.9048\n",
|
||||
"accuracy(m,test) = 0.9342\n",
|
||||
"accuracy(m,test) = 0.9267\n",
|
||||
"accuracy(m,test) = 0.9323\n",
|
||||
"accuracy(m,test) = 0.9309\n",
|
||||
"accuracy(m,test) = 0.9406\n",
|
||||
"accuracy(m,test) = 0.9299\n",
|
||||
"accuracy(m,test) = 0.9361\n",
|
||||
"accuracy(m,test) = 0.9248\n",
|
||||
"accuracy(m,test) = 0.9457\n",
|
||||
"accuracy(m,test) = 0.946\n",
|
||||
"accuracy(m,test) = 0.9449\n",
|
||||
"accuracy(m,test) = 0.9393\n",
|
||||
"accuracy(m,test) = 0.9397\n",
|
||||
"accuracy(m,test) = 0.9409\n",
|
||||
"accuracy(m,test) = 0.9403\n",
|
||||
"accuracy(m,test) = 0.948\n",
|
||||
"accuracy(m,test) = 0.9064\n",
|
||||
"accuracy(m,test) = 0.9457\n",
|
||||
"accuracy(m,test) = 0.9504\n",
|
||||
"accuracy(m,test) = 0.9482\n",
|
||||
"accuracy(m,test) = 0.9514\n",
|
||||
"accuracy(m,test) = 0.9474\n",
|
||||
"accuracy(m,test) = 0.9364\n",
|
||||
"accuracy(m,test) = 0.9423\n",
|
||||
"accuracy(m,test) = 0.9503\n",
|
||||
"accuracy(m,test) = 0.9452\n",
|
||||
"accuracy(m,test) = 0.9511\n",
|
||||
"114.461381 seconds (41.62 M allocations: 11.112 GB, 3.57% gc time)\n"
|
||||
"INFO: Epoch 1\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"MXModel((Flux.softmax)((Flux.Dense(64,10))((Flux.relu)((Flux.Dense(128,64))((Flux.relu)((Flux.Dense(784,128))((Flux.Input{1}((784,)))(Flux.ModelInput(1)))))))))"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"y = Float32[0.100094 0.099893 0.0998293 0.099919 0.0999388 0.100123 0.0999624 0.100185 0.100024 0.100031]\n",
|
||||
"accuracy(m,test) = 0.109\n",
|
||||
"y = Float32[0.0999111 0.100465 0.0997179 0.0998629 0.100282 0.0994788 0.0999943 0.100667 0.0996533 0.0999675]\n",
|
||||
"accuracy(m,test) = 0.109\n",
|
||||
"y = Float32[0.103446 0.10002 0.100398 0.102059 0.0977901 0.0989413 0.0995703 0.0993392 0.100129 0.0983073]\n",
|
||||
"accuracy(m,test) = 0.1779\n",
|
||||
"y = Float32[0.0288616 0.136275 0.0976122 0.103702 0.0972153 0.113357 0.0834434 0.0843929 0.153776 0.101365]\n",
|
||||
"accuracy(m,test) = 0.3236\n",
|
||||
"y = Float32[1.94426f-6 8.36834f-11 1.50476f-8 5.13133f-8 0.0209049 0.887724 0.000472463 2.70319f-7 0.0901923 0.000703486]\n",
|
||||
"accuracy(m,test) = 0.8258\n",
|
||||
"y = Float32[0.00507439 0.0108996 0.0307061 0.364342 0.00413664 0.0342019 0.000766874 0.364772 0.0279412 0.157159]\n",
|
||||
"accuracy(m,test) = 0.8543\n",
|
||||
"y = Float32[6.40463f-14 2.77292f-10 2.8507f-9 0.999999 8.78815f-15 5.18649f-7 1.1916f-16 5.82441f-11 7.00438f-8 3.48348f-11]\n",
|
||||
"accuracy(m,test) = 0.7922\n",
|
||||
"y = Float32[1.47776f-7 0.997464 0.000132982 0.000113373 4.29683f-5 6.14733f-7 1.82552f-5 7.3389f-6 0.00221002 1.02455f-5]\n",
|
||||
"accuracy(m,test) = 0.9133\n",
|
||||
"y = Float32[4.12331f-9 4.38521f-13 2.32062f-8 1.82065f-8 8.9981f-12 2.4413f-6 5.51279f-13 0.999988 9.02321f-11 9.3387f-6]\n",
|
||||
"accuracy(m,test) = 0.9151\n",
|
||||
"y = Float32[5.29229f-8 0.000115183 4.59078f-7 0.887777 8.759f-6 0.0215426 7.45297f-8 7.17829f-6 0.0854197 0.00512896]\n",
|
||||
"accuracy(m,test) = 0.8715\n",
|
||||
"149.428110 seconds (27.60 M allocations: 2.383 GB, 2.10% gc time)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
|
@ -438,7 +396,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -457,7 +415,7 @@
|
|||
" pixelspacing: 1.0 1.0"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -466,38 +424,6 @@
|
|||
"convert(Image,reshape(data[31][1],(28,28))'/255)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"10×1 Array{Float32,2}:\n",
|
||||
" 0.0 \n",
|
||||
" 0.002\n",
|
||||
" 0.001\n",
|
||||
" 0.996\n",
|
||||
" 0.0 \n",
|
||||
" 0.0 \n",
|
||||
" 0.0 \n",
|
||||
" 0.001\n",
|
||||
" 0.001\n",
|
||||
" 0.0 "
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"map(x->round(x, 3), model(data[31][1]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
|
@ -508,7 +434,17 @@
|
|||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"3"
|
||||
"10-element Array{Float32,1}:\n",
|
||||
" 0.0 \n",
|
||||
" 0.019\n",
|
||||
" 0.004\n",
|
||||
" 0.966\n",
|
||||
" 0.0 \n",
|
||||
" 0.001\n",
|
||||
" 0.0 \n",
|
||||
" 0.004\n",
|
||||
" 0.004\n",
|
||||
" 0.002"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
|
@ -516,6 +452,28 @@
|
|||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"map(x->round(x, 3), model(data[31][1]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"3"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"onecold(model(data[31][1]), 0:9)"
|
||||
]
|
||||
|
@ -524,7 +482,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As you can see from the output probability distribution above, our model is 98% confident that the digit is a 3."
|
||||
"As you can see from the output probability distribution above, our model is 97% confident that the digit is a 3."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -549,7 +507,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@model type Perceptron\n",
|
||||
"@net type Perceptron\n",
|
||||
" W\n",
|
||||
" b\n",
|
||||
" x -> σ(W*x + b)\n",
|
||||
|
@ -565,7 +523,7 @@
|
|||
"source": [
|
||||
"This is just the equation you'd find in a textbook! Because `Perceptron` is a simple Julia type, we can define a handy convenience constructor to initialise the weights as well. (This is almost exactly how `Dense` is defined in Flux itself, by the way – no special casing here.)\n",
|
||||
"\n",
|
||||
"The `@model` macro is fairly simple, and just extends our basic type definition with some useful functionality like the backward pass. It's easy to see what's going on:"
|
||||
"The `@net` macro is fairly simple, and just extends our basic type definition with some useful functionality like the backward pass. It's easy to see what's going on:"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -578,7 +536,7 @@
|
|||
"source": [
|
||||
"using MacroTools\n",
|
||||
"\n",
|
||||
"@expand(@model type Perceptron\n",
|
||||
"@expand(@net type Perceptron\n",
|
||||
" W\n",
|
||||
" b\n",
|
||||
" x -> σ.(W*x + b)\n",
|
||||
|
@ -641,7 +599,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@model type Perceptron\n",
|
||||
"@net type Perceptron\n",
|
||||
" dense::Dense\n",
|
||||
" x -> σ(dense(x))\n",
|
||||
"end\n",
|
||||
|
@ -660,7 +618,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Because models are just Julia objects conforming to a protocol, it's easy to implement *completely* custom models, which don't use the `@model` macro. For example, perhaps the function isn't easy to auto-differentiate, or you need a custom GPU kernel. An example of the former kind is the `Chain` model used above, which is defined roughly like so:"
|
||||
"Because models are just Julia objects conforming to a protocol, it's easy to implement *completely* custom models, which don't use the `@net` macro. For example, perhaps the function isn't easy to auto-differentiate, or you need a custom GPU kernel. An example of the former kind is the `Chain` model used above, which is defined roughly like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -711,7 +669,7 @@
|
|||
"file_extension": ".jl",
|
||||
"mimetype": "application/julia",
|
||||
"name": "julia",
|
||||
"version": "0.5.0"
|
||||
"version": "0.5.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
Loading…
Reference in New Issue