188 lines
3.7 KiB
Plaintext
188 lines
3.7 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Auto grad\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import torchvision"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([5., 7., 9.], grad_fn=<AddBackward0>)\n",
|
||
|
"<AddBackward0 object at 0x11f573ad0>\n",
|
||
|
"tensor(21., grad_fn=<SumBackward0>)\n",
|
||
|
"<SumBackward0 object at 0x11f595650>\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# The tensor object keeps track of how it was created if requieres_grad is True \n",
|
||
|
"x = torch.tensor([1.,2.,3],requires_grad=True)\n",
|
||
|
"y = torch.tensor([4.,5.,6],requires_grad=True)\n",
|
||
|
"\n",
|
||
|
"z = x + y\n",
|
||
|
"print(z)\n",
|
||
|
"\n",
|
||
|
"print(z.grad_fn)\n",
|
||
|
"s = z.sum()\n",
|
||
|
"\n",
|
||
|
"print(s)\n",
|
||
|
"\n",
|
||
|
"print(s.grad_fn)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([2., 2., 2.])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# To back propagate\n",
|
||
|
"s.backward()\n",
|
||
|
"print(x.grad)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"False False\n",
|
||
|
"None\n",
|
||
|
"<AddBackward0 object at 0x11f088650>\n",
|
||
|
"True\n",
|
||
|
"None\n",
|
||
|
"True\n",
|
||
|
"True\n",
|
||
|
"False\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.randn(2,2)\n",
|
||
|
"y = torch.randn(2,2)\n",
|
||
|
"print(x.requires_grad,y.requires_grad)\n",
|
||
|
"\n",
|
||
|
"z = x + y\n",
|
||
|
"\n",
|
||
|
"print(z.grad_fn)\n",
|
||
|
"\n",
|
||
|
"x.requires_grad_()\n",
|
||
|
"y.requires_grad_()\n",
|
||
|
"\n",
|
||
|
"z= x + y\n",
|
||
|
"\n",
|
||
|
"print(z.grad_fn)\n",
|
||
|
"\n",
|
||
|
"print(z.requires_grad)\n",
|
||
|
"\n",
|
||
|
"new_z = z.detach()\n",
|
||
|
"\n",
|
||
|
"print(new_z.grad_fn)\n",
|
||
|
"\n",
|
||
|
"print(x.requires_grad)\n",
|
||
|
"print((x+10).requires_grad)\n",
|
||
|
"\n",
|
||
|
"with torch.no_grad():\n",
|
||
|
" print((x+10).requires_grad)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor([[1., 1.],\n",
|
||
|
" [1., 1.]], requires_grad=True)\n",
|
||
|
"tensor([[3., 3.],\n",
|
||
|
" [3., 3.]], grad_fn=<AddBackward0>)\n",
|
||
|
"<AddBackward0 object at 0x11ee83b90>\n",
|
||
|
"tensor([[27., 27.],\n",
|
||
|
" [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)\n",
|
||
|
"tensor([[4.5000, 4.5000],\n",
|
||
|
" [4.5000, 4.5000]])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x = torch.ones(2,2,requires_grad=True)\n",
|
||
|
"print(x)\n",
|
||
|
"y = x + 2\n",
|
||
|
"print(y)\n",
|
||
|
"print(y.grad_fn)\n",
|
||
|
"\n",
|
||
|
"z = y*y*3\n",
|
||
|
"\n",
|
||
|
"out = z.mean()\n",
|
||
|
"\n",
|
||
|
"print(z,out)\n",
|
||
|
"\n",
|
||
|
"out.backward()\n",
|
||
|
"print(x.grad)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|