- FLOPS: floating point operations per second,每秒浮点运算次数,计算速度,衡量硬件性能的指标 (大写S代表秒)
- FLOPs: floating point operations,浮点运算数,计算量,用来衡量算法/模型的复杂度。
- Params:没有固定的名称,大小写均可,表示模型的参数量,也是用来衡量算法/模型的复杂度。通常我们在论文中见到的是这样:#Params,那个井号是表示 number of 的意思,因此 #Params 的意思就是:参数的数量。
FLOPs与模型时间复杂度、GPU利用率有关,Params与模型空间复杂度、显存占用有关
MAC:Multiply Accumulate,乘加运算。乘积累加运算(英语:Multiply Accumulate, MAC)是在数字信号处理器或一些微处理器中的特殊运算。实现此运算操作的硬件电路单元,被称为“乘数累加器”。这种运算的操作,是将乘法的乘积结果和累加器的值相加,再存入累加器:
a←a+b×c
使用MAC可以将原本需要的两个指令操作减少到一个指令操作,从而提高运算效率。
FLOPs的计算
不考虑激活函数的计算量
卷积层
全连接层
Code
完整参数量统计
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| import torch from torchvision.models import resnet50 import numpy as np
Total_params = 0 Trainable_params = 0 NonTrainable_params = 0
model = resnet50() for param in model.parameters(): mulValue = np.prod(param.size()) Total_params += mulValue if param.requires_grad: Trainable_params += mulValue else: NonTrainable_params += mulValue
print(f'Total params: {Total_params / 1e6}M') print(f'Trainable params: {Trainable_params/ 1e6}M') print(f'Non-trainable params: {NonTrainable_params/ 1e6}M')
|
1 2 3 4 5
| >>> output:
Total params: 25.557032M Trainable params: 25.557032M Non-trainable params: 0.0M
|
简单统计可训练参数量
1 2 3 4 5 6
| import torchvision.models as models
model = models.resnet50(pretrained=False)
Trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'Trainable params: {Trainable_params/ 1e6}M')
|
1 2 3
| >>> output:
Trainable params: 25.557032M
|
统计每一层
1 2 3
| model = vgg16() for name, parameters in model.named_parameters(): print(name, ':', np.prod(parameters.size()))
|
1 2 3 4 5 6 7 8
| >>> output:
features.0.weight : 1728 features.0.bias : 64 features.2.weight : 36864 features.2.bias : 64 features.5.weight : 73728 ...
|
使用thop库来获取模型的FLOPs(计算量)和Params(参数量)
1 2 3 4 5 6 7 8 9 10
| import torch from thop import profile from archs.ViT_model import get_vit, ViT_Aes from torchvision.models import resnet50
model = resnet50() input1 = torch.randn(4, 3, 224, 224) flops, params = profile(model, inputs=(input1, )) print('FLOPs = ' + str(flops/1000**3) + 'G') print('Params = ' + str(params/1000**2) + 'M')
|
1 2 3 4
| >>> output:
FLOPs = 16.446058496G Params = 25.557032M
|
Ref:
https://blog.csdn.net/weixin_44966641/article/details/120104600