A First Look at Mixed Precision Training
sixwalter

Automatic half-precision (mixed precision) training

1. apex and amp

apex is a PyTorch extension built by NVIDIA; amp is the module inside it that provides mixed-precision training.

2. Problems with FP16

2.1 Data overflow (underflow)

FP16 has a much narrower dynamic range than FP32: its largest representable value is 65504 and its smallest positive normal value is about 6.1e-5 (subnormals reach down to roughly 6.0e-8). Gradients in deep networks are often smaller than this, so they underflow to zero, while very large values overflow to inf.
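A tiny demonstration of both failure modes in PyTorch (the thresholds quoted in the comments are the standard IEEE half-precision limits):

import torch

# The smallest positive normal FP16 value is ~6.1e-5 (subnormals reach ~6.0e-8),
# so a typical tiny gradient such as 1e-8 is flushed to zero when cast to FP16.
print(torch.tensor(1e-8).half())      # tensor(0., dtype=torch.float16) -> the value is lost
# Very large values overflow instead: the FP16 maximum is 65504.
print(torch.tensor(70000.0).half())   # tensor(inf, dtype=torch.float16)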

2.2 Rounding error

FP16 has only a 10-bit mantissa, so when a weight update is much smaller than the gap between adjacent representable values around that weight, the addition rounds back to the original value and the update is lost.
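A small sketch of this effect: the update below is representable in FP16 on its own, but it is smaller than half the spacing between FP16 values around 1.0 (~9.8e-4), so adding it changes nothing.

import torch

w = torch.tensor(1.0, dtype=torch.float16)        # a weight stored in FP16
update = torch.tensor(1e-4, dtype=torch.float16)  # a small (but representable) update

# The sum rounds straight back to 1.0, so the update is silently discarded.
print(w + update)                # tensor(1., dtype=torch.float16)
print(bool((w + update) == w))   # True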

3. Solutions

3.1 FP32 weight backup

Keep an FP32 master copy of the weights: the forward and backward passes still run in FP16, and only the weight update is performed in FP32. A sketch of the idea follows.
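A minimal sketch of the master-weight idea (not apex's actual implementation); the toy Linear model and the SGD settings are placeholders chosen for illustration:

import torch

# Toy FP16 model used for the forward/backward pass
model = torch.nn.Linear(16, 4).cuda().half()

# FP32 master copy of every parameter; the optimizer only ever sees these
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = torch.optim.SGD(master_params, lr=1e-2)

x = torch.randn(8, 16, device="cuda", dtype=torch.float16)
loss = model(x).float().sum()
loss.backward()

# 1) copy the FP16 gradients into the FP32 master copies
for master, p in zip(master_params, model.parameters()):
    master.grad = p.grad.detach().float()

# 2) the update itself runs entirely in FP32
optimizer.step()
optimizer.zero_grad()

# 3) write the updated FP32 weights back into the FP16 model for the next step
with torch.no_grad():
    for master, p in zip(master_params, model.parameters()):
        p.copy_(master)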

3.2 Loss scaling

By the chain rule, scaling up the loss scales every gradient by the same factor, lifting small gradients out of the range where they would underflow or be rounded away; the gradients are divided by the same factor before the weight update, as in the sketch below.
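A manual sketch of the trick (apex and torch.cuda.amp do this automatically, including skipping steps whose gradients contain inf/NaN); `model`, `criterion`, `inputs`, `targets`, and `optimizer` are placeholders, and 128 is just a fixed example factor:

import torch

scale = 128.0

loss = criterion(model(inputs), targets)   # FP16 forward pass
(loss * scale).backward()                  # by the chain rule, every gradient is also 128x larger

with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(scale)             # unscale so the actual step sizes are unchanged
optimizer.step()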

3.3 Higher arithmetic precision

Use FP16 for multiplication and storage, but perform the additions (accumulation) in FP32; this reduces the rounding error introduced during accumulation and preserves precision. See the sketch below.
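A toy dot-product sketch of FP16-multiply / FP32-accumulate; the explicit Python loop is deliberately naive so the two accumulation strategies can be compared against an FP32 reference:

import torch

torch.manual_seed(0)
a = torch.rand(4096)
b = torch.rand(4096)
a16, b16 = a.half(), b.half()

ref = torch.dot(a, b)                          # FP32 reference

acc32 = torch.zeros((), dtype=torch.float32)   # FP16 multiply, FP32 accumulate
acc16 = torch.zeros((), dtype=torch.float16)   # FP16 multiply, FP16 accumulate (naive)
for x, y in zip(a16, b16):
    prod = x * y              # multiplication and storage in FP16
    acc32 += prod.float()     # addition carried out in FP32
    acc16 += prod             # addition carried out in FP16

print((acc32 - ref).abs().item())           # small: dominated by FP16 product rounding
print((acc16.float() - ref).abs().item())   # much larger: FP16 additions keep dropping low bits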

4. Quick start

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=128.0)  # note: "O1" is the letter O followed by the digit one, not zero-one
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

For opt_level, try O2 first; if training fails to converge, fall back to O1.
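For context, a minimal single-GPU loop built around the two calls above; `model`, `optimizer`, `criterion`, and `loader` are assumed to be constructed elsewhere, and O1/128.0 simply mirror the snippet above:

import torch
from apex import amp

model = model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=128.0)

for images, targets in loader:
    images, targets = images.cuda(), targets.cuda()
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    # this replaces the plain loss.backward() of an FP32 training loop
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()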

Below are concrete loss-scaler implementations for both NVIDIA apex and PyTorch's natively supported mixed precision (torch.cuda.amp):

import torch

try:
    from apex import amp
    has_apex = True
    print("successfully imported amp")
except ImportError:
    amp = None
    has_apex = False
    print("cannot import amp from apex")


class ApexScaler:
    state_dict_key = "amp"

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
        # apex scales the loss internally and unscales the gradients before the step
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(create_graph=create_graph)
        if clip_grad is not None:
            # clip against apex's FP32 master parameters
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
        optimizer.step()

    def state_dict(self):
        if 'state_dict' in amp.__dict__:
            return amp.state_dict()

    def load_state_dict(self, state_dict):
        if 'load_state_dict' in amp.__dict__:
            amp.load_state_dict(state_dict)


class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        # the GradScaler object handles gradient scaling automatically
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
        # scale the loss, then backprop to produce scaled gradients
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            assert parameters is not None
            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        # step() skips optimizer.step() if the gradients contain inf/NaN
        self._scaler.step(optimizer)
        # update() adjusts the scale factor for the next iteration
        self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
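As a usage sketch, NativeScaler above would typically be driven like this (assuming `model`, `optimizer`, `criterion`, and `loader` exist; the forward pass runs under torch.cuda.amp.autocast so that eligible ops execute in FP16):

loss_scaler = NativeScaler()

for images, targets in loader:
    images, targets = images.cuda(), targets.cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # eligible ops run in FP16 inside this context
        loss = criterion(model(images), targets)
    # scale -> backward -> (optional unscale + clip) -> step -> update, all inside __call__
    loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())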

apex + distributed:

apex DDP uses the current device by default, whereas torch DDP requires the device to be specified manually; otherwise the usage is similar to torch.

Note, however, that

model, optimizer = amp.initialize(model, optimizer, flags...)

must be called before

model = apex.parallel.DistributedDataParallel(model)

as in the sketch below.
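A minimal initialization-order sketch, assuming the script is launched with torch.distributed.launch / torchrun (so LOCAL_RANK is set in the environment) and that `model` and `optimizer` are already built:

import os
import torch
import apex.parallel
from apex import amp

# local rank is provided by torch.distributed.launch / torchrun via the environment
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = model.cuda()

# 1) initialize amp first ...
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# 2) ... then wrap with apex DDP, which uses the current device by default
model = apex.parallel.DistributedDataParallel(model)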
