PyTorch에서 모델 요약을 인쇄하려면 어떻게 해야 합니까?
PyTorch에서 모델 요약을 인쇄하려면 어떻게 해야 합니까?model.summary()
는 Keras에서 다음을 수행합니다.
Model Summary:
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 1, 15, 27) 0
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D) (None, 8, 15, 27) 872 input_1[0][0]
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D) (None, 8, 7, 27) 0 convolution2d_1[0][0]
____________________________________________________________________________________________________
flatten_1 (Flatten) (None, 1512) 0 maxpooling2d_1[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 1) 1513 flatten_1[0][0]
====================================================================================================
Total params: 2,385
Trainable params: 2,385
Non-trainable params: 0
네, pytorch-summary 패키지를 사용하면 정확한 Keras 표현을 얻을 수 있습니다.
VGG16의 예:
from torchvision import models
from torchsummary import summary
vgg = models.vgg16()
summary(vgg, (3, 224, 224))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1, 64, 224, 224] 0
Conv2d-3 [-1, 64, 224, 224] 36,928
ReLU-4 [-1, 64, 224, 224] 0
MaxPool2d-5 [-1, 64, 112, 112] 0
Conv2d-6 [-1, 128, 112, 112] 73,856
ReLU-7 [-1, 128, 112, 112] 0
Conv2d-8 [-1, 128, 112, 112] 147,584
ReLU-9 [-1, 128, 112, 112] 0
MaxPool2d-10 [-1, 128, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 295,168
ReLU-12 [-1, 256, 56, 56] 0
Conv2d-13 [-1, 256, 56, 56] 590,080
ReLU-14 [-1, 256, 56, 56] 0
Conv2d-15 [-1, 256, 56, 56] 590,080
ReLU-16 [-1, 256, 56, 56] 0
MaxPool2d-17 [-1, 256, 28, 28] 0
Conv2d-18 [-1, 512, 28, 28] 1,180,160
ReLU-19 [-1, 512, 28, 28] 0
Conv2d-20 [-1, 512, 28, 28] 2,359,808
ReLU-21 [-1, 512, 28, 28] 0
Conv2d-22 [-1, 512, 28, 28] 2,359,808
ReLU-23 [-1, 512, 28, 28] 0
MaxPool2d-24 [-1, 512, 14, 14] 0
Conv2d-25 [-1, 512, 14, 14] 2,359,808
ReLU-26 [-1, 512, 14, 14] 0
Conv2d-27 [-1, 512, 14, 14] 2,359,808
ReLU-28 [-1, 512, 14, 14] 0
Conv2d-29 [-1, 512, 14, 14] 2,359,808
ReLU-30 [-1, 512, 14, 14] 0
MaxPool2d-31 [-1, 512, 7, 7] 0
Linear-32 [-1, 4096] 102,764,544
ReLU-33 [-1, 4096] 0
Dropout-34 [-1, 4096] 0
Linear-35 [-1, 4096] 16,781,312
ReLU-36 [-1, 4096] 0
Dropout-37 [-1, 4096] 0
Linear-38 [-1, 1000] 4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------
Keras 모델만큼 자세한 모델 정보는 얻을 수 없습니다.요약하면 모델을 인쇄하는 것만으로 관련된 다양한 레이어와 그 사양에 대해 알 수 있습니다.
예:
from torchvision import models
model = models.vgg16()
print(model)
이 경우의 출력은 다음과 같습니다.
VGG (
(features): Sequential (
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU (inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU (inplace)
(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU (inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU (inplace)
(9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU (inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU (inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU (inplace)
(16): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU (inplace)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU (inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU (inplace)
(23): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU (inplace)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU (inplace)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU (inplace)
(30): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
)
(classifier): Sequential (
(0): Dropout (p = 0.5)
(1): Linear (25088 -> 4096)
(2): ReLU (inplace)
(3): Dropout (p = 0.5)
(4): Linear (4096 -> 4096)
(5): ReLU (inplace)
(6): Linear (4096 -> 1000)
)
)
Kashyap이 말한 것처럼, 이제 당신은 그 정보를state_dict
여러 레이어의 가중치를 얻는 방법.그러나 이 레이어 목록을 사용하면 Keras가 모델 요약을 좋아하도록 도우미 기능을 만드는 것이 더 많은 방향을 제공할 수 있습니다.이게 도움이 됐으면 좋겠네요!
토치 요약 유형을 사용하려면:
from torchsummary import summary
없으면 먼저 설치합니다.
pip install torchsummary
그 후 시험해 볼 수 있지만, 어떤 이유로든 모델을 cuda로 설정하지 않으면 동작하지 않습니다.alexnet.cuda
:
from torchsummary import summary
help(summary)
import torchvision.models as models
alexnet = models.alexnet(pretrained=False)
alexnet.cuda()
summary(alexnet, (3, 224, 224))
print(alexnet)
그summary
입력 사이즈를 취득할 필요가 있습니다.배치 사이즈는 -1로 설정되어 있습니다.이는 당사가 제공하는 배치사이즈를 의미합니다.
설정했을 경우summary(alexnet, (3, 224, 224), 32)
즉,bs=32
.
summary(model, input_size, batch_size=-1, device='cuda')
출력:
Help on function summary in module torchsummary.torchsummary:
summary(model, input_size, batch_size=-1, device='cuda')
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [32, 64, 55, 55] 23,296
ReLU-2 [32, 64, 55, 55] 0
MaxPool2d-3 [32, 64, 27, 27] 0
Conv2d-4 [32, 192, 27, 27] 307,392
ReLU-5 [32, 192, 27, 27] 0
MaxPool2d-6 [32, 192, 13, 13] 0
Conv2d-7 [32, 384, 13, 13] 663,936
ReLU-8 [32, 384, 13, 13] 0
Conv2d-9 [32, 256, 13, 13] 884,992
ReLU-10 [32, 256, 13, 13] 0
Conv2d-11 [32, 256, 13, 13] 590,080
ReLU-12 [32, 256, 13, 13] 0
MaxPool2d-13 [32, 256, 6, 6] 0
AdaptiveAvgPool2d-14 [32, 256, 6, 6] 0
Dropout-15 [32, 9216] 0
Linear-16 [32, 4096] 37,752,832
ReLU-17 [32, 4096] 0
Dropout-18 [32, 4096] 0
Linear-19 [32, 4096] 16,781,312
ReLU-20 [32, 4096] 0
Linear-21 [32, 1000] 4,097,000
================================================================
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 18.38
Forward/backward pass size (MB): 268.12
Params size (MB): 233.08
Estimated Total Size (MB): 519.58
----------------------------------------------------------------
AlexNet(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
(1): ReLU(inplace)
(2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU(inplace)
(5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU(inplace)
(8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(9): ReLU(inplace)
(10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
(classifier): Sequential(
(0): Dropout(p=0.5)
(1): Linear(in_features=9216, out_features=4096, bias=True)
(2): ReLU(inplace)
(3): Dropout(p=0.5)
(4): Linear(in_features=4096, out_features=4096, bias=True)
(5): ReLU(inplace)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
(구)torchsummary
) 패키지는 (특정1 입력 2쉐이프에 대해) Keras와 유사한 출력을 생성합니다.
from torchinfo import summary
model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
├─Conv2d (conv1): 1-1 [5, 10, 24, 24] 260
├─Conv2d (conv2): 1-2 [5, 20, 8, 8] 5,020
├─Dropout2d (conv2_drop): 1-3 [5, 20, 8, 8] --
├─Linear (fc1): 1-4 [5, 50] 16,050
├─Linear (fc2): 1-5 [5, 10] 510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 7.69
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================
주의:
-
Torchinfo는 다음과 같은 정보를 제공합니다.
print(your_model)
PyTorch에서, 텐서플로우의 것과 유사하다.model.summary()
... Keras와 달리, PyTorch는 동적 계산 그래프를 가지고 있으며, 여러 호출에 걸쳐 호환되는 입력 형태에 적응할 수 있습니다. 예를 들어 (완전 컨볼루션 네트워크를 위한) 충분히 큰 이미지 크기입니다.
따라서 각 레이어에 고유한 입력/출력 쉐이프 세트를 표시할 수 없습니다.이러한 쉐이프는 입력에 의존하기 때문입니다.또, 상기의 패키지에서는 입력 치수를 지정할 필요가 있습니다.
모델의 가중치와 모수가 표시됩니다(출력 모양은 표시되지 않음).
from torch.nn.modules.module import _addindent
import torch
import numpy as np
def torch_summarize(model, show_weights=True, show_parameters=True):
"""Summarizes torch model by showing trainable parameters and weights."""
tmpstr = model.__class__.__name__ + ' (\n'
for key, module in model._modules.items():
# if it contains layers let call it recursively to get params and weights
if type(module) in [
torch.nn.modules.container.Container,
torch.nn.modules.container.Sequential
]:
modstr = torch_summarize(module)
else:
modstr = module.__repr__()
modstr = _addindent(modstr, 2)
params = sum([np.prod(p.size()) for p in module.parameters()])
weights = tuple([tuple(p.size()) for p in module.parameters()])
tmpstr += ' (' + key + '): ' + modstr
if show_weights:
tmpstr += ', weights={}'.format(weights)
if show_parameters:
tmpstr += ', parameters={}'.format(params)
tmpstr += '\n'
tmpstr = tmpstr + ')'
return tmpstr
# Test
import torchvision.models as models
model = models.alexnet()
print(torch_summarize(model))
# # Output
# AlexNet (
# (features): Sequential (
# (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)), weights=((64, 3, 11, 11), (64,)), parameters=23296
# (1): ReLU (inplace), weights=(), parameters=0
# (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
# (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)), weights=((192, 64, 5, 5), (192,)), parameters=307392
# (4): ReLU (inplace), weights=(), parameters=0
# (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
# (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((384, 192, 3, 3), (384,)), parameters=663936
# (7): ReLU (inplace), weights=(), parameters=0
# (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 384, 3, 3), (256,)), parameters=884992
# (9): ReLU (inplace), weights=(), parameters=0
# (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 256, 3, 3), (256,)), parameters=590080
# (11): ReLU (inplace), weights=(), parameters=0
# (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
# ), weights=((64, 3, 11, 11), (64,), (192, 64, 5, 5), (192,), (384, 192, 3, 3), (384,), (256, 384, 3, 3), (256,), (256, 256, 3, 3), (256,)), parameters=2469696
# (classifier): Sequential (
# (0): Dropout (p = 0.5), weights=(), parameters=0
# (1): Linear (9216 -> 4096), weights=((4096, 9216), (4096,)), parameters=37752832
# (2): ReLU (inplace), weights=(), parameters=0
# (3): Dropout (p = 0.5), weights=(), parameters=0
# (4): Linear (4096 -> 4096), weights=((4096, 4096), (4096,)), parameters=16781312
# (5): ReLU (inplace), weights=(), parameters=0
# (6): Linear (4096 -> 1000), weights=((1000, 4096), (1000,)), parameters=4097000
# ), weights=((4096, 9216), (4096,), (4096, 4096), (4096,), (1000, 4096), (1000,)), parameters=58631144
# )
편집: isaykatsman에는 pytorch PR을 추가하여model.summary()
https://github.com/pytorch/pytorch/pull/3043/files과 같은 것입니다.
가장 기억하기 쉬운 것(Keras만큼 예쁘지 않음):
print(model)
이것도 동작합니다.
repr(model)
파라미터의 수만을 원하는 경우:
sum([param.nelement() for param in model.parameters()])
From: 모델과 유사한 Pytorch 기능이 있습니까?summary()를 keras로 지정합니다.PyTorch.org)
사용할 수 있습니다.
from torchsummary import summary
디바이스를 지정할 수 있습니다.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
네트워크를 생성할 수 있습니다.MNIST 데이터 세트를 사용하는 경우 다음 명령어가 작동하여 요약을 표시합니다.
model = Network().to(device)
summary(model,(1,28,28))
토치 요약을 사용한 모델 요약과 같은 Keras:
from torchsummary import summary
summary(model, input_size=(3, 224, 224))
저는 이 간단한 조각이 더 좋아요.
net = model
modules = [module for module in net.modules()]
params = [param.shape for param in net.parameters()]
# Print Model Summary
print(modules[0])
total_params=0
for i in range(1,len(modules)):
j = 2*i
param = (params[j-2][1]*params[j-2][0])+params[j-1][0]
total_params += param
print("Layer",i,"->\t",end="")
print("Weights:", params[j-2][0],"x",params[j-2][1],
"\tBias: ",params[j-1][0], "\tParameters: ", param)
print("\nTotal Params: ", total_params)
필요한 것은 모두 인쇄합니다.
Net(
(hLayer1): Linear(in_features=1024, out_features=256, bias=True)
(hLayer2): Linear(in_features=256, out_features=128, bias=True)
(hLayer3): Linear(in_features=128, out_features=64, bias=True)
(outLayer): Linear(in_features=64, out_features=10, bias=True)
)
Layer 1 -> Weights: 256 x 1024 Bias: 256 Parameters: 262400
Layer 2 -> Weights: 128 x 256 Bias: 128 Parameters: 32896
Layer 3 -> Weights: 64 x 128 Bias: 64 Parameters: 8256
Layer 4 -> Weights: 10 x 64 Bias: 10 Parameters: 650
Total Parameters: 304202
복잡한 모형 또는 모형의 보다 심층적인 통계량
torchstat 설치
pip install torchstat
통계 가져오기
from torchstat import stat
import torchvision.models as models
model = models.vgg19()
stat(model, (3, 224, 224))
출력 -
module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
0 features.0 3 224 224 64 224 224 1792.0 12.25 173,408,256.0 89,915,392.0 609280.0 12845056.0 21.56% 13454336.0
1 features.1 64 224 224 64 224 224 0.0 12.25 3,211,264.0 3,211,264.0 12845056.0 12845056.0 0.92% 25690112.0
2 features.2 64 224 224 64 224 224 36928.0 12.25 3,699,376,128.0 1,852,899,328.0 12992768.0 12845056.0 4.74% 25837824.0
3 features.3 64 224 224 64 224 224 0.0 12.25 3,211,264.0 3,211,264.0 12845056.0 12845056.0 0.92% 25690112.0
4 features.4 64 224 224 64 112 112 0.0 3.06 2,408,448.0 3,211,264.0 12845056.0 3211264.0 1.22% 16056320.0
5 features.5 64 112 112 128 112 112 73856.0 6.12 1,849,688,064.0 926,449,664.0 3506688.0 6422528.0 4.71% 9929216.0
6 features.6 128 112 112 128 112 112 0.0 6.12 1,605,632.0 1,605,632.0 6422528.0 6422528.0 0.94% 12845056.0
7 features.7 128 112 112 128 112 112 147584.0 6.12 3,699,376,128.0 1,851,293,696.0 7012864.0 6422528.0 4.36% 13435392.0
8 features.8 128 112 112 128 112 112 0.0 6.12 1,605,632.0 1,605,632.0 6422528.0 6422528.0 0.91% 12845056.0
9 features.9 128 112 112 128 56 56 0.0 1.53 1,204,224.0 1,605,632.0 6422528.0 1605632.0 1.51% 8028160.0
10 features.10 128 56 56 256 56 56 295168.0 3.06 1,849,688,064.0 925,646,848.0 2786304.0 3211264.0 3.57% 5997568.0
11 features.11 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.90% 6422528.0
12 features.12 256 56 56 256 56 56 590080.0 3.06 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 4.30% 8782848.0
13 features.13 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.90% 6422528.0
14 features.14 256 56 56 256 56 56 590080.0 3.06 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 4.38% 8782848.0
15 features.15 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.94% 6422528.0
16 features.16 256 56 56 256 56 56 590080.0 3.06 3,699,376,128.0 1,850,490,880.0 5571584.0 3211264.0 4.33% 8782848.0
17 features.17 256 56 56 256 56 56 0.0 3.06 802,816.0 802,816.0 3211264.0 3211264.0 0.90% 6422528.0
18 features.18 256 56 56 256 28 28 0.0 0.77 602,112.0 802,816.0 3211264.0 802816.0 1.44% 4014080.0
19 features.19 256 28 28 512 28 28 1180160.0 1.53 1,849,688,064.0 925,245,440.0 5523456.0 1605632.0 3.60% 7129088.0
20 features.20 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 0.92% 3211264.0
21 features.21 512 28 28 512 28 28 2359808.0 1.53 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 4.45% 12650496.0
22 features.22 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 0.94% 3211264.0
23 features.23 512 28 28 512 28 28 2359808.0 1.53 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 4.39% 12650496.0
24 features.24 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 0.90% 3211264.0
25 features.25 512 28 28 512 28 28 2359808.0 1.53 3,699,376,128.0 1,850,089,472.0 11044864.0 1605632.0 4.34% 12650496.0
26 features.26 512 28 28 512 28 28 0.0 1.53 401,408.0 401,408.0 1605632.0 1605632.0 0.90% 3211264.0
27 features.27 512 28 28 512 14 14 0.0 0.38 301,056.0 401,408.0 1605632.0 401408.0 0.96% 2007040.0
28 features.28 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 0.99% 10242048.0
29 features.29 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
30 features.30 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 0.11% 10242048.0
31 features.31 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
32 features.32 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 0.11% 10242048.0
33 features.33 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
34 features.34 512 14 14 512 14 14 2359808.0 0.38 924,844,032.0 462,522,368.0 9840640.0 401408.0 0.11% 10242048.0
35 features.35 512 14 14 512 14 14 0.0 0.38 100,352.0 100,352.0 401408.0 401408.0 0.00% 802816.0
36 features.36 512 14 14 512 7 7 0.0 0.10 75,264.0 100,352.0 401408.0 100352.0 0.01% 501760.0
37 avgpool 512 7 7 512 7 7 0.0 0.10 0.0 0.0 0.0 0.0 0.49% 0.0
38 classifier.0 25088 4096 102764544.0 0.02 205,516,800.0 102,760,448.0 411158528.0 16384.0 11.27% 411174912.0
39 classifier.1 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
40 classifier.2 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.01% 0.0
41 classifier.3 4096 4096 16781312.0 0.02 33,550,336.0 16,777,216.0 67141632.0 16384.0 1.08% 67158016.0
42 classifier.4 4096 4096 0.0 0.02 4,096.0 4,096.0 16384.0 16384.0 0.00% 32768.0
43 classifier.5 4096 4096 0.0 0.02 0.0 0.0 0.0 0.0 0.00% 0.0
44 classifier.6 4096 1000 4097000.0 0.00 8,191,000.0 4,096,000.0 16404384.0 4000.0 0.93% 16408384.0
total 143667240.0 119.34 39,283,567,128.0 19,667,896,320.0 16404384.0 4000.0 100.00% 825282624.0
============================================================================================================================================================
Total params: 143,667,240
------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 119.34 MB
Total MAdd: 39.28 GMAdd
Total Flops: 19.67 GFlops
Total MemR+W: 787.05 MB
모델 클래스에 대한 개체를 정의한 후 간단히 모델 인쇄
class RNN(nn.Module):
def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.rnn = nn.RNN(embedding_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward():
...
model = RNN(input_dim, embedding_dim, hidden_dim, output_dim)
print(model)
summary(my_model, (3, 224, 224), device = 'cpu')
문제가 해결됩니다.
심플하게 하는 것이 좋다
from torchvision.models import resnet18
from torchscope import scope
model = resnet18()
scope(model, input_size=(3, 224, 224))
언급URL : https://stackoverflow.com/questions/42480111/how-do-i-print-the-model-summary-in-pytorch
'programing' 카테고리의 다른 글
"가격" 열에 가장 적합한 MySQL 유형은 무엇입니까? (0) | 2022.11.02 |
---|---|
쿼리 문자열 내에서 어레이를 전달하려면 어떻게 해야 합니까? (0) | 2022.11.02 |
PHP DateInterval에서 총 시간(초)을 계산합니다. (0) | 2022.11.02 |
현재 사용자가 MySQL/MariaDB에 대한 특정 권한을 가지고 있는지 확인하려면 어떻게 해야 합니까? (0) | 2022.11.02 |
phpMyAdmin에서 MySQL 쿼리를 기반으로 CSV 생성 (0) | 2022.11.02 |