本文最后更新于:2024年1月23日 中午
PyTorch参数自动命名规则
当我们使用
1 2
| for name, params in model.named_parameters(): print(f"Parameter Name: {name}, Parameter Shape: {params.shape}")
|
时可以看到模型的参数以及参数的名字,PyTorch内部实际上有一套命名的规则。
我们创建一个简单的CNN模型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
| class Block(nn.Module): def __init__(self, channels): super().__init__() self.conv1 = nn.Conv2d( in_channels=channels, out_channels=channels, kernel_size=7, groups=channels, padding=3 ) self.proj1 = nn.Linear( in_features=channels, out_features=channels*4 ) self.proj2 = nn.Linear( in_features=channels*4, out_features=channels ) self.act = nn.GELU() self.norm = nn.BatchNorm2d(num_features=channels) def forward(self, x): b, c, h, w = x.shape out = self.conv1(x) out = self.norm(out) out = out.permute(0, 2, 3, 1) out = self.proj1(out) out = self.act(out) out = self.proj2(out) out = out.permute(0, 3, 1, 2) return out + x class model(nn.Module): def __init__(self, width, depth): super().__init__() self.num_stages = len(depth) self.sampler = nn.ModuleList() self.stages = nn.ModuleList() stem = nn.Sequential( nn.Conv2d( in_channels=3, out_channels=width[0], kernel_size=4, stride=4 ), nn.BatchNorm2d(num_features=width[0]) ) self.sampler.append(stem) for i in range(self.num_stages-1): m = nn.Sequential( nn.BatchNorm2d(width[i]), nn.Conv2d( in_channels=width[i], out_channels=width[i+1], kernel_size=2, stride=2 ) ) self.sampler.append(m) for i in range(self.num_stages): m = nn.Sequential() for j in range(depth[i]): m.append(Block(channels=width[i])) self.stages.append(m) def forward(self, x): for i in range(self.num_stages): x = self.sampler[i](x) x = self.stages[i](x) return x
|
我们创建一个两个stage,每个stage两个block的模型
1 2 3 4
| m = model( width=[64, 128], depth=[2, 2] )
|
接下来使用named_parameters
来打印参数,得到
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| Parameter Name: sampler.0.0.weight Parameter Shape: torch.Size([64, 3, 4, 4]) Parameter Name: sampler.0.0.bias Parameter Shape: torch.Size([64]) Parameter Name: sampler.0.1.weight Parameter Shape: torch.Size([64]) Parameter Name: sampler.0.1.bias Parameter Shape: torch.Size([64]) Parameter Name: sampler.1.0.weight Parameter Shape: torch.Size([64]) Parameter Name: sampler.1.0.bias Parameter Shape: torch.Size([64]) Parameter Name: sampler.1.1.weight Parameter Shape: torch.Size([128, 64, 2, 2]) Parameter Name: sampler.1.1.bias Parameter Shape: torch.Size([128]) Parameter Name: stages.0.0.conv1.weight Parameter Shape: torch.Size([64, 1, 7, 7]) Parameter Name: stages.0.0.conv1.bias Parameter Shape: torch.Size([64]) Parameter Name: stages.0.0.proj1.weight Parameter Shape: torch.Size([256, 64]) Parameter Name: stages.0.0.proj1.bias Parameter Shape: torch.Size([256]) Parameter Name: stages.0.0.proj2.weight Parameter Shape: torch.Size([64, 256]) Parameter Name: stages.0.0.proj2.bias Parameter Shape: torch.Size([64]) Parameter Name: stages.0.0.norm.weight Parameter Shape: torch.Size([64]) Parameter Name: stages.0.0.norm.bias Parameter Shape: torch.Size([64]) Parameter Name: stages.0.1.conv1.weight Parameter Shape: torch.Size([64, 1, 7, 7]) Parameter Name: stages.0.1.conv1.bias Parameter Shape: torch.Size([64]) Parameter Name: stages.0.1.proj1.weight Parameter Shape: torch.Size([256, 64]) Parameter Name: stages.0.1.proj1.bias Parameter Shape: torch.Size([256]) Parameter Name: stages.0.1.proj2.weight Parameter Shape: torch.Size([64, 256]) Parameter Name: stages.0.1.proj2.bias Parameter Shape: torch.Size([64]) Parameter Name: stages.0.1.norm.weight Parameter Shape: torch.Size([64]) Parameter Name: stages.0.1.norm.bias Parameter Shape: torch.Size([64]) Parameter Name: stages.1.0.conv1.weight Parameter Shape: torch.Size([128, 1, 7, 7]) Parameter Name: stages.1.0.conv1.bias Parameter Shape: torch.Size([128]) Parameter Name: stages.1.0.proj1.weight Parameter Shape: torch.Size([512, 128]) Parameter Name: stages.1.0.proj1.bias Parameter Shape: torch.Size([512]) Parameter Name: stages.1.0.proj2.weight Parameter Shape: torch.Size([128, 512]) Parameter Name: stages.1.0.proj2.bias Parameter Shape: torch.Size([128]) Parameter Name: stages.1.0.norm.weight Parameter Shape: torch.Size([128]) Parameter Name: stages.1.0.norm.bias Parameter Shape: torch.Size([128]) Parameter Name: stages.1.1.conv1.weight Parameter Shape: torch.Size([128, 1, 7, 7]) Parameter Name: stages.1.1.conv1.bias Parameter Shape: torch.Size([128]) Parameter Name: stages.1.1.proj1.weight Parameter Shape: torch.Size([512, 128]) Parameter Name: stages.1.1.proj1.bias Parameter Shape: torch.Size([512]) Parameter Name: stages.1.1.proj2.weight Parameter Shape: torch.Size([128, 512]) Parameter Name: stages.1.1.proj2.bias Parameter Shape: torch.Size([128]) Parameter Name: stages.1.1.norm.weight Parameter Shape: torch.Size([128]) Parameter Name: stages.1.1.norm.bias Parameter Shape: torch.Size([128])
|
可以看到,PyTorch首先会读取model attribute的名字,由于self.sampler
是他第一个遇到的模型参数,所以参数列表里先是sampler
开头。
这里使用了循环来构建,将模型用到的模块装进了nn.Sequential
和nn.ModuleList
里面,由于循环构建模块没有类似self.xxx
的直接命名,PyTorch采用了类似数组的方式命名,中间用.
隔开。
我们一共两个stage,在sampler中,每个stage有两个Module,所以可以看到类似0.0
, 0.1
这样的命名方式,其代表着第0个stage的第0个模块,和第0个stage的第1个模块,和stages[0][0]
类似,但在PyTorch命名中都使用.
隔开。再往后面,是具体到参数的weight
和bias
,sampler的satge.1的第一个Module是nn.Conv2d()
,所以可以看到有一个63x3x4x4
大小的weight
,和一个64
大小的bias
。第二个Module是nn.BatchNorm2d()
,所以可以看到0.1
后面有一个64
的weight
和bias
。
接下来遇到的参数是self.stages
,这里面使用循环不断创建Block
,前面的命名和sampler一样是stages.x.x
,后面则是读取到了Block
里面的参数名字。在Block
中,分别有self.conv1
, self.proj1
, self.proj2
, self.norm
。激活函数没有参数所以不在里面。同样,具体到每个Module
都有一个weight
和bias
。