详解大型项目中的AMP训练

本文最后更新于:2024年1月26日 下午

1 什么是AMP

Automatic Mixed Precision是百度联合英伟达一起推出的一个训练trick,通过在训练过程中部分使用FP16的半精度数据来极大节省内存,同时能加快训练速度。最开始要使用Apex框架来开启AMP训练,但现在PyTorch已经自带AMP相关功能。

2 AMP训练的挑战

AMP训练一般会遇到几个问题,第一是有可能遇到数值下溢和数值上溢,由于FP16能表示的范围要小很多,所以当数据转换为FP16之后可能会发生下溢出(0),和上溢出(inf),为了解决这个问题,我们需要调用torch.cuda.amp.GradScaler()来缩放梯度,让梯度始终在FP16的表示范围内。而缩放的大小torch会自动帮我们决定。在缩放梯度进行反向传播之后,我们需要在优化器step前将其缩放回去,这样才不会与原始设定的学习率的尺度产生冲突。

另外一个问题是在某些算子中(例如BatchNorm),AMP训练会造成不稳定,我们最好是能够随时监控梯度的大小。

3 大厂项目中AMP训练的代码

1
2
3
4
5
6
7
8
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"

def __init__(self):
"""
Loss scaler for AMP training, use torch.cuda.amp.GradScaler
"""
self._scaler = torch.cuda.amp.GradScaler()

这个类就用来进行梯度缩放,并且能够计算梯度的范数。通过计算范数,我们能很方便的知道目前梯度的尺度是怎样一个状态。self._sclaer则是PyTorch提供的GradScaler。

接下来是重载的def __call__()

1
2
3
4
5
6
7
8
9
def __call__(
self,
loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
clip_grad: float = None,
parameters=None,
create_graph: bool = False,
update_grad: bool = True
) -> torch.Tensor:

调用这个NativeScaler需要六个参数,

  • 第一个参数是loss,也就是目前得到的loss tensor,Scaler会scale这个tensor。
  • 第二个参数是optimizer,传入训练用的opimizer即可,unscale梯度的时候需要用到这个参数。第三个参数是
  • 第三个参数是clip_grad,这个参数用来控制是否裁切梯度,如果有传值,那么就会调用torch.nn.utils.clip_grad_norm_()来裁切梯度。
  • 第四个参数是parameters,传入模型的参数
  • 第五个参数是create_graph,用来控制是否创建计算图,这个参数与二阶优化器配合,反向传播时创建计算图可以求二阶导数。
  • 第六个参数是update_grad,即是否更新参数。
  • 返回一个torch.Tensor,里面装的是目前梯度的范数。

接下来是它的实现部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __call__(
self,
loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
clip_grad: float = None,
parameters=None,
create_graph: bool = False,
update_grad: bool = True
) -> torch.Tensor:
self._scaler.scale(loss).backward(create_graph=create_graph)
# scale梯度并用半精度(FP16)来反向传播
if update_grad: # 如果需要更新梯度
if clip_grad is not None: # 如果有裁剪梯度
assert parameters is not None
# 检查模型参数有没有问题
self._scaler.unscale_(optimizer)
# unscale给优化器传入的参数的梯度(在函数内完成参数变动)
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
'''
使用torch提供的`clip_grad_norm`来根据范数裁剪梯度,
最大范数阈值是传入的clip_norm参数,这个函数还会返回
梯度的范数(默认2-范数)
'''
else:
self._scaler.unscale_(optimizer)
# unscale给优化器传入的参数的梯度(在函数内完成参数变动)
norm = get_grad_norm_(parameters)
# 使用自己写的范数计算程序
self._scaler.step(optimizer)
# 检查梯度并让优化器更新参数
self._scaler.update()
# 更新scaler的state dict
else: # 如果不更新梯度
norm = None
# 给一个None值
return norm # 返回梯度的范数用于监控

注意到torch的GradScaler类是需要维护一个state dict的。所以接下来是保存和加载GradScaler的state dict。

1
2
3
4
5
def state_dict(self):
return self._scaler.state_dict()

def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)

第一个方法访问GradScalerstate_dict()方法来得到它的state dict,第二个方法调用load_state_dict来恢复状态字典,这些可以用在恢复训练中。

剩下还有一个自己写的范数计算程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
"""
Calculate the norm of gradient.
:param parameters: model parameters
:param norm_type: type of norm, e.g. l2 norm, infinity norm, etc.
:return: norm of all gradient, stored in torch.Tensor
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# 先判断传入的是不是Tensor,如果是,将其装入一个列表中
parameters = [p for p in parameters if p.grad is not None]
# 对列表中的参数遍历,只取出那些有梯度的参数
norm_type = float(norm_type)
# 转换norm_type,作为torch.norm的参数
if len(parameters) == 0:
return torch.tensor(0.)
# 如果是个空列表,那么没有需要计算的,返回0
device = parameters[0].grad.device
# 获取目前使用的设备
if torch.isinf(torch.Tensor([norm_type])).item():
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
# 如果使用无穷范数,那么找其绝对值的最大值,利用目前设备计算
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
norm_type)
# 如果不是,使用troch.norm来得到范数,范数的类型由参数决定
return total_norm # 返回梯度的范数

接下来在训练engine中,代码也有少许改动

1
2
3
4
5
6
7
8
9
if use_amp:
# 如果使用了AMP
with torch.cuda.amp.autocast():
# 需要添加autocast来进行前向推理
output = model(samples)
loss = criterion(output, targets)
else: # full precision
output = model(samples)
loss = criterion(output, targets)

然后是梯度部分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
max_norm: float = None
# 默认一般不启用梯度裁剪
if use_amp:
# this attribute is added by timm on one optimizer (adahessian)
is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
# 是否使用二阶优化器
loss /= update_freq
# 梯度累加缩放梯度
grad_norm = loss_scaler(
loss, # 传入loss
optimizer, # 传入优化器
clip_grad=max_norm, # 是否梯度裁剪,以及裁剪大小
parameters=model.parameters(), # 传入模型参数
create_graph=is_second_order, # 如果是二阶优化器,创建计算图
update_grad=(data_iter_step + 1) % update_freq == 0
# 在梯度累加的更新频率上更新梯度
)
if (data_iter_step + 1) % update_freq == 0:
# 如果在更新频率上
optimizer.zero_grad()
# 清空梯度
if model_ema is not None:
# 如果有使用EMA
model_ema.update(model)
# 用EMA更新模型
else: # full precision
loss /= update_freq
loss.backward()
if (data_iter_step + 1) % update_freq == 0:
optimizer.step()
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)

注意,loss反向传播和优化器step在loss_scaler里面已经完成了,所以这里和全精度相比,不需要在训练engine里面写loss.backward()optimizer.step()


详解大型项目中的AMP训练
https://jesseprince.github.io/2024/01/26/pytorch/amp/
作者
林正
发布于
2024年1月26日
许可协议