181 lines
4.3 KiB
Plaintext
181 lines
4.3 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.nn as nn"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torchvision"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.functional as F"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LeNet(nn.Module):\n",
|
|
" def __init__(self):\n",
|
|
" super(LeNet, self).__init__()\n",
|
|
" # 1 input image channel, 6 output channels, 3x3 square conv kernel\n",
|
|
" self.conv1 = nn.Conv2d(1, 6, 3)\n",
|
|
" self.conv2 = nn.Conv2d(6, 16, 3)\n",
|
|
" self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension\n",
|
|
" self.fc2 = nn.Linear(120, 84)\n",
|
|
" self.fc3 = nn.Linear(84, 10)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
|
|
" x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n",
|
|
" x = x.view(-1, int(x.nelement() / x.shape[0]))\n",
|
|
" x = F.relu(self.fc1(x))\n",
|
|
" x = F.relu(self.fc2(x))\n",
|
|
" x = self.fc3(x)\n",
|
|
" return x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = LeNet().to(device=device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[('weight', Parameter containing:\n",
|
|
"tensor([[[[ 0.0854, 0.1762, -0.0814],\n",
|
|
" [-0.0859, -0.2473, -0.3291],\n",
|
|
" [-0.2670, -0.0962, -0.2002]]],\n",
|
|
"\n",
|
|
"\n",
|
|
" [[[-0.2353, 0.2884, -0.0267],\n",
|
|
" [-0.2879, 0.0165, -0.2282],\n",
|
|
" [-0.2628, -0.0222, 0.2331]]],\n",
|
|
"\n",
|
|
"\n",
|
|
" [[[ 0.1070, 0.2683, 0.2301],\n",
|
|
" [ 0.1637, 0.2600, 0.1026],\n",
|
|
" [ 0.3013, -0.2336, 0.0121]]],\n",
|
|
"\n",
|
|
"\n",
|
|
" [[[-0.2488, 0.2017, -0.1913],\n",
|
|
" [-0.0723, 0.0911, -0.2306],\n",
|
|
" [-0.0578, 0.3193, 0.1617]]],\n",
|
|
"\n",
|
|
"\n",
|
|
" [[[-0.0321, 0.1383, 0.2564],\n",
|
|
" [-0.1842, -0.3065, 0.1342],\n",
|
|
" [ 0.0815, -0.0757, 0.2290]]],\n",
|
|
"\n",
|
|
"\n",
|
|
" [[[ 0.1121, -0.2876, -0.1017],\n",
|
|
" [ 0.2958, 0.3092, -0.2187],\n",
|
|
" [ 0.3065, 0.1924, 0.2755]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:\n",
|
|
"tensor([-0.1582, -0.0629, -0.1216, -0.3196, 0.2949, 0.2759], device='cuda:0',\n",
|
|
" requires_grad=True))]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"module = model.conv1\n",
|
|
"print(list(module.named_parameters()))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(list(module.named_buffers()))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.nn.utils.prune as prune"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"prune.random_unstructured"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|