PyTorch参数自动命名规则

本文最后更新于:2024年1月23日 中午

PyTorch参数自动命名规则

当我们使用

1
2
for name, params in model.named_parameters():
print(f"Parameter Name: {name}, Parameter Shape: {params.shape}")

时可以看到模型的参数以及参数的名字,PyTorch内部实际上有一套命名的规则。

我们创建一个简单的CNN模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class Block(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=7,
groups=channels,
padding=3
)
self.proj1 = nn.Linear(
in_features=channels,
out_features=channels*4
)
self.proj2 = nn.Linear(
in_features=channels*4,
out_features=channels
)
self.act = nn.GELU()
self.norm = nn.BatchNorm2d(num_features=channels)

def forward(self, x):
b, c, h, w = x.shape
out = self.conv1(x)
out = self.norm(out)
out = out.permute(0, 2, 3, 1)
# (N, C, H, W) -> (N, H, W, C)
out = self.proj1(out)
out = self.act(out)
out = self.proj2(out)
out = out.permute(0, 3, 1, 2)
# (N, H, W, C) -> (N, C, H, W)

return out + x

class model(nn.Module):
def __init__(self, width, depth):
super().__init__()
self.num_stages = len(depth)
self.sampler = nn.ModuleList()
self.stages = nn.ModuleList()
stem = nn.Sequential(
nn.Conv2d(
in_channels=3,
out_channels=width[0],
kernel_size=4,
stride=4
),
nn.BatchNorm2d(num_features=width[0])
)
self.sampler.append(stem)
for i in range(self.num_stages-1):
m = nn.Sequential(
nn.BatchNorm2d(width[i]),
nn.Conv2d(
in_channels=width[i],
out_channels=width[i+1],
kernel_size=2,
stride=2
)
)
self.sampler.append(m)
for i in range(self.num_stages):
m = nn.Sequential()
for j in range(depth[i]):
m.append(Block(channels=width[i]))

self.stages.append(m)

def forward(self, x):
for i in range(self.num_stages):
x = self.sampler[i](x)
x = self.stages[i](x)

return x

我们创建一个两个stage,每个stage两个block的模型

1
2
3
4
m = model(
width=[64, 128],
depth=[2, 2]
)

接下来使用named_parameters来打印参数,得到

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
Parameter Name: sampler.0.0.weight       Parameter Shape: torch.Size([64, 3, 4, 4])
Parameter Name: sampler.0.0.bias Parameter Shape: torch.Size([64])
Parameter Name: sampler.0.1.weight Parameter Shape: torch.Size([64])
Parameter Name: sampler.0.1.bias Parameter Shape: torch.Size([64])
Parameter Name: sampler.1.0.weight Parameter Shape: torch.Size([64])
Parameter Name: sampler.1.0.bias Parameter Shape: torch.Size([64])
Parameter Name: sampler.1.1.weight Parameter Shape: torch.Size([128, 64, 2, 2])
Parameter Name: sampler.1.1.bias Parameter Shape: torch.Size([128])
Parameter Name: stages.0.0.conv1.weight Parameter Shape: torch.Size([64, 1, 7, 7])
Parameter Name: stages.0.0.conv1.bias Parameter Shape: torch.Size([64])
Parameter Name: stages.0.0.proj1.weight Parameter Shape: torch.Size([256, 64])
Parameter Name: stages.0.0.proj1.bias Parameter Shape: torch.Size([256])
Parameter Name: stages.0.0.proj2.weight Parameter Shape: torch.Size([64, 256])
Parameter Name: stages.0.0.proj2.bias Parameter Shape: torch.Size([64])
Parameter Name: stages.0.0.norm.weight Parameter Shape: torch.Size([64])
Parameter Name: stages.0.0.norm.bias Parameter Shape: torch.Size([64])
Parameter Name: stages.0.1.conv1.weight Parameter Shape: torch.Size([64, 1, 7, 7])
Parameter Name: stages.0.1.conv1.bias Parameter Shape: torch.Size([64])
Parameter Name: stages.0.1.proj1.weight Parameter Shape: torch.Size([256, 64])
Parameter Name: stages.0.1.proj1.bias Parameter Shape: torch.Size([256])
Parameter Name: stages.0.1.proj2.weight Parameter Shape: torch.Size([64, 256])
Parameter Name: stages.0.1.proj2.bias Parameter Shape: torch.Size([64])
Parameter Name: stages.0.1.norm.weight Parameter Shape: torch.Size([64])
Parameter Name: stages.0.1.norm.bias Parameter Shape: torch.Size([64])
Parameter Name: stages.1.0.conv1.weight Parameter Shape: torch.Size([128, 1, 7, 7])
Parameter Name: stages.1.0.conv1.bias Parameter Shape: torch.Size([128])
Parameter Name: stages.1.0.proj1.weight Parameter Shape: torch.Size([512, 128])
Parameter Name: stages.1.0.proj1.bias Parameter Shape: torch.Size([512])
Parameter Name: stages.1.0.proj2.weight Parameter Shape: torch.Size([128, 512])
Parameter Name: stages.1.0.proj2.bias Parameter Shape: torch.Size([128])
Parameter Name: stages.1.0.norm.weight Parameter Shape: torch.Size([128])
Parameter Name: stages.1.0.norm.bias Parameter Shape: torch.Size([128])
Parameter Name: stages.1.1.conv1.weight Parameter Shape: torch.Size([128, 1, 7, 7])
Parameter Name: stages.1.1.conv1.bias Parameter Shape: torch.Size([128])
Parameter Name: stages.1.1.proj1.weight Parameter Shape: torch.Size([512, 128])
Parameter Name: stages.1.1.proj1.bias Parameter Shape: torch.Size([512])
Parameter Name: stages.1.1.proj2.weight Parameter Shape: torch.Size([128, 512])
Parameter Name: stages.1.1.proj2.bias Parameter Shape: torch.Size([128])
Parameter Name: stages.1.1.norm.weight Parameter Shape: torch.Size([128])
Parameter Name: stages.1.1.norm.bias Parameter Shape: torch.Size([128])

可以看到,PyTorch首先会读取model attribute的名字,由于self.sampler是他第一个遇到的模型参数,所以参数列表里先是sampler开头。

这里使用了循环来构建,将模型用到的模块装进了nn.Sequentialnn.ModuleList里面,由于循环构建模块没有类似self.xxx的直接命名,PyTorch采用了类似数组的方式命名,中间用.隔开。

我们一共两个stage,在sampler中,每个stage有两个Module,所以可以看到类似0.0, 0.1这样的命名方式,其代表着第0个stage的第0个模块,和第0个stage的第1个模块,和stages[0][0]类似,但在PyTorch命名中都使用.隔开。再往后面,是具体到参数的weightbias,sampler的satge.1的第一个Module是nn.Conv2d(),所以可以看到有一个63x3x4x4大小的weight,和一个64大小的bias。第二个Module是nn.BatchNorm2d(),所以可以看到0.1后面有一个64weightbias

接下来遇到的参数是self.stages,这里面使用循环不断创建Block,前面的命名和sampler一样是stages.x.x,后面则是读取到了Block里面的参数名字。在Block中,分别有self.conv1, self.proj1, self.proj2, self.norm。激活函数没有参数所以不在里面。同样,具体到每个Module都有一个weightbias


PyTorch参数自动命名规则
https://jesseprince.github.io/2024/01/23/pytorch/namedparam/
作者
林正
发布于
2024年1月23日
许可协议