MetricLogger:大厂都在用的指标记录器

本文最后更新于:2024年1月22日 下午

MetricLogger:大厂都在用的指标记录器

MetricLogger是现在比较流行的一个用来记录各种metric的类,它实际上最开始来源于DeiT项目,而DeiT项目又是从torchvision-classification-reference-utils.py里面抄过来的,所以总的来说是PyTorch提供的一个轮子。

1 前置模块:SmoothedValue

SmoothedValue提供了一个滑动窗口来自动计算窗口内的值,并且能够以特定的格式来进行打印。首先是其初始化部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""

def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
# 记录值的和
self.count = 0
# 目前看到了多少个值
self.fmt = fmt

fmt是字符串输出的格式,这里默认是"{median:.4f} ({global_avg:.4f})",也就是打印四位小数的中位数和四位小数的全局平均数。

deque则是一个数据结构–队列,一般可以用from collections import deque来导入使用,这个数据结构类似与列表,但允许在左边和右边append新的值。当设定maxlen之后,我们得到的是一个有界的队列,当元素达到最大队列长的时候,在一个方向上append新元素将会让队列在对面移除最旧的元素。

1
2
3
4
5
6
7
8
9
my_deque = deque(maxlen=3)
# 创建一个最大长度为3的队列
my_deque.append(1)
my_deque.append(2)
my_deque.append(3)
my_deque.append(4)
# 此时会删除1, deque内包含[2,3,4]
my_deque.append(5)
# 此时会删除2, deque内包含[3,4,5]

所以当数据流源源不断被append进deque的时候,deque就表现为一个具有maxlen大小的滑动窗口在数据流上滑动。

self.total用来记录所有值的和,self.count是目前看到的值的数量的总数。

接下来是它的第一个方法def update(self, value, n=1)

1
2
3
4
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n

首先,它将传进来的valueappend到队列里,并且count自身加n,total自身加value×nvalue\times n,参数n能够让其在某个值更新的时候给它赋权重。例如在mini batch训练中

1
value.update(acc1.item(), n=batch_size)

这样可以得到这个batch总和的acc1值(去除平均),方便后面算全局平均。

然后是def synchronize_between_processes(self)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
# 判断是不是在进行分布式训练,如果不是就直接返回
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
# 装入tensor来使用PyTorch分布式框架的同步功能
dist.barrier()
# 等待整组的所有进程都进入这个函数
dist.all_reduce(t)
# all reduce这个值,让所有进程的这个值都得到最终结果
t = t.tolist()
# 转回list
self.count = int(t[0])
self.total = t[1]
# 更新self里的count和total

这个方法就用来在分布式训练时同步不同进程间的值,首先其判断了是不是在进行分布式训练,如果是,先将值装入tensor,以此来使用PyTorch的分布式框架,然后调用barrier来block进程,直到所有进程都进入这里,然后all_reduce归纳整理所有进程间的这个tensor,这样就在不同进程间同步了counttotal。最后将tensor转回list,再给object的counttotal更新。

检查是否启用分布式的代码是

1
2
3
4
5
6
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True

其实就是调用dist内部的is_availableis_initialized来判断。

接下来是中位数计算

1
2
3
4
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()

property装饰器将这个方法转换为属性,这样调用的时候可以把median直接当属性调用。方法内部,首先将队列转list再转tensor,然后使用tensor自带的中位数计算方法median()得到中位数,然后用item()取出tensor的值。这里也能体现出smoothed value是怎么回事:它以滑动窗口的形式,在窗口内计算中位数。

然后是其它一些统计特性的计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
@property
def avg(self):
# 平均值计算
d = torch.tensor(list(self.deque), dtype=torch.float32)
# 取出窗口(队列)的值装入tensor
return d.mean().item()
# 计算并返回

@property
def global_avg(self):
# 总体平均数计算
return self.total / self.count

@property
def max(self):
# 最大值
return max(self.deque)
# 直接返回窗口内最大值

@property
def value(self):
# 取值操作
return self.deque[-1]
# 取出队列最后一个元素值,即最新的值

最后是重载的def __str__(self)

1
2
3
4
5
6
7
8
9
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value
)

调用python字符串方法format()来格式化字符串输出,这里因为使用了property装饰器,所以用的是self.median这样attribute like的调用。注意到,format()方法不会填充不存在的占位符,也就是说,默认的占位符是medianglobal_avg,那么即使给format传了avg, max, value,也不会在返回的字符串中出现。这实际上给了我们控制返回的能力,我们在创建SmoothedValue的时候就可以事先传入fmt,例如我们要返回平均数,就可以fmt='{avg:.4f}',这样返回的时候就会只返回平均数。

以上就是SmoothedValue类的解析,其贡献在于滑动窗口计算统计值以及分布式训练的同步。

2 MetricLogger

接下来正式讲MetricLogger,这个类不仅可以用来追踪所有的指标,自带平滑功能,而且还内置迭代器,能够完成计时等等功能。

首先是用到的几个模块

1
2
3
4
from collections import defaultdict
import torch
import time
import datetime

首先是初始化函数

1
2
3
4
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter

delimiter意思是分隔符,这是配合字符串的join()方法使用的,默认为\t,具体在后面可以看到。meters则是用来装指标的,它是一个defaultdict,而defaultdict将默认字典内的value都是SmoothedValue对象。也就是说往meters里面创建key-value对的时候,value将是一个SmoothedValue对象。

接下来是第一个方法,def update(self, **kwargs)

1
2
3
4
5
6
7
8
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)

这是一个多参数传参的函数,函数内部将遍历参数字典,这里调用items()将其转换为可以用key, value同时迭代的对象(简写为k, v)。如果发现valueNone,那就跳过接下来的代码,如果发现value是tensor object,那就调用item()方法将tensor内部的值取出来。最后检查一下取出来的值是不是float或者int,然后将meters内部对应的key更新其value。注意这里的update方法实际上调用的是SmoothedValue里面的update方法。一个例子如下

meters(defaultdict){key=lossvalue=SmoothedValue{dequetotalcountformat\mathrm{meters(defaultdict)} \begin{cases} \mathrm{key}=\mathrm{loss}\\ {\mathrm{value}=\mathrm{SmoothedValue}\begin{cases}\mathrm{deque}\\ \mathrm{total}\\ \mathrm{count}\\ \mathrm{format}\end{cases}} \end{cases}

接下来是重载的def __getattr__(self, attr)

1
2
3
4
5
6
7
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))

首先检查了查询的attr在不在meters这个字典里,如果在就返回meters里面对应的value,然后检查在不在object本身的字典里,即是不是object本身的attribute,如果是就返回。如果两者都不是,那么就raise一个error:这个物体里面没有这个attribute。这个重载的主要作用就是将meters这个字典当作object的attribute来用。

然后是重载的def __str__(self)

1
2
3
4
5
6
7
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)

首先其创建了一个空的列表loss_str,然后同样以key-value的形式遍历meters字典,接着按照"{}: {}".format(name, str(meter))的格式来生成字符串,例如loss: 0.6132,然后将生成的字符串append进loss_str这个列表中,最后使用join()方法来生成一个用delimiter来分隔的字符串。

join方法配合列表来使用的时候,它会挨个读取列表的内容组合成一个字符串,并在元素间插入一个字符串来分隔。例如

1
2
lst = ['a', 'b', 'c']
ret = ', '.join(lst)

结果为

1
a, b, c

所以这里就在不同的指标之间(例如loss: 1.1, acc1: 42, acc5: 61)插入了\t

然后是进程间同步方法def synchronize_between_processes(self)

1
2
3
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()

读取所有meters字典里面的值(SmoothedValue),并调用它们的synchronize_between_processes()方法来进行同步,具体如何同步看之前SmoothedValue。

然后是添加指标def add_meter(self, name, meter)

1
2
def add_meter(self, name, meter):
self.meters[name] = meter

在字典中创建一个keynamevaluemeterkey-value对,一般调用的时候meter传入SmoothedValue。例如

1
2
3
4
metric_logger.add_meter(
'lr',
SmoothedValue(window_size=1, fmt='{value:.6f}')
)

虽然直接对字典写入key-value对,在没有这个key的时候也会创建key,但是这样创建的key默认的SmoothedValue并不能调整其初始化函数的参数,所以这里提供了一个调整window_sizefmt的方法。

最后是def log_every(self, iterable, print_freq, header=None),这个方法是MetricLogger的精髓

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def log_every(self, iterable, print_freq, header=None):
i = 0
# 计数器
if not header:
header = ''
# 如果没有传header参数,则默认为空字符串
start_time = time.time()
end = time.time()
# 首先记录现在的时间点
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
# 创建两个SmoothedValue类来记录迭代时间和数据读取时间
# 仅返回四位小数的平均值
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
# 创建空格分隔符,用来在打印的时候对齐
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]', # 会被替换为目前所在的迭代次数与总迭代次数
'eta: {eta}', # 剩余时间
'{meters}', # 各种指标
'time: {time}', # 迭代时间
'data: {data}' # 数据读取时间
]
# 将要打印的信息放入列表
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
# 如果有cuda,那么列表多一项最大内存占用
log_msg = self.delimiter.join(log_msg)
# 使用分隔符分隔这些信息
MB = 1024.0 * 1024.0
# 计算一个常数

接下来进入一个迭代

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
for obj in iterable:
data_time.update(time.time() - end)
# 更新obj从iterable拿出来的时间
yield obj
# 生成obj返回到调用的迭代器那里
iter_time.update(time.time() - end)
# 更新外循环处理obj用的时间
if i % print_freq == 0 or i == len(iterable) - 1:
# 如果在打印频率上,或者是最后一个循环
eta_seconds = iter_time.global_avg * (len(iterable) - i)
# 估算剩余时间
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
# 格式化剩余时间,将秒转为"h:m:s"
if torch.cuda.is_available():
# 如果在用cuda训练
print(
log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self), # 重载过str方法
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB)
)
else:
# 不使用cuda就不打印显存占用
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
# 计数器+1
end = time.time()
# 更新一个迭代完成后的时间点

total_time = time.time() - start_time
# 总花费时间是现在的时间减去最开始记录的时间
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
# 秒转化为"h:m:s"
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
# 格式化输出总时间,以及平均每个obj消耗的时间

此时在外部我们只需要调用log_every方法就可以像迭代data_loader一样迭代数据,并且在更新频率上打印各种指标,还可以追踪各个部分消耗的时间。

接下来我们写一个test来看看这个东西的表现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import utils
import math
import time
import torch

logger = utils.MetricLogger(
delimiter="\t"
)

loss = 25
acc = 1

data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
loader = torch.utils.data.DataLoader(
data,
batch_size=2,
shuffle=True
)

logger.add_meter(
name='loss',
meter=utils.SmoothedValue(fmt="(avg){avg:.4f} (global_avg: {global_avg:.4f})")
)
logger.add_meter(
name='acc',
meter=utils.SmoothedValue(fmt="(avg){avg:.4f} (global_avg: {global_avg:.4f})")
)

for index, obj in enumerate(logger.log_every(
loader,
print_freq=1,
header='Test'
)):
print(obj)
time.sleep(2)
loss = loss*math.exp(-1*(index+1))
acc = acc*math.log(index+2)

logger.update(loss=loss)
logger.update(acc=acc)

print('done')

输出为

1
2
3
4
5
6
7
8
9
10
11
12
tensor([2, 5])
Test [0/5] eta: 0:00:10 loss: (avg)9.1970 (global_avg: 9.1970) acc: (avg)0.6931 (global_avg: 0.6931) time: 2.0060 data: 0.0005
tensor([8, 1])
Test [1/5] eta: 0:00:08 loss: (avg)5.2208 (global_avg: 5.2208) acc: (avg)0.7273 (global_avg: 0.7273) time: 2.0064 data: 0.0007
tensor([3, 4])
Test [2/5] eta: 0:00:06 loss: (avg)3.5012 (global_avg: 3.5012) acc: (avg)0.8368 (global_avg: 0.8368) time: 2.0060 data: 0.0006
tensor([6, 9])
Test [3/5] eta: 0:00:04 loss: (avg)2.6262 (global_avg: 2.6262) acc: (avg)1.0523 (global_avg: 1.0523) time: 2.0061 data: 0.0006
tensor([10, 7])
Test [4/5] eta: 0:00:02 loss: (avg)2.1010 (global_avg: 2.1010) acc: (avg)1.4507 (global_avg: 1.4507) time: 2.0062 data: 0.0006
Test Total time: 0:00:10 (2.0081 s / it)
done

这里窗口太大看不出来global_avg和avg的区别,如果print_freq改成2,那么

1
2
3
4
5
6
7
8
9
10
tensor([6, 1])
Test [0/5] eta: 0:00:10 loss: (avg)9.1970 (global_avg: 9.1970) acc: (avg)0.6931 (global_avg: 0.6931) time: 2.0059 data: 0.0004
tensor([ 3, 10])
tensor([8, 2])
Test [2/5] eta: 0:00:06 loss: (avg)3.5012 (global_avg: 3.5012) acc: (avg)0.8368 (global_avg: 0.8368) time: 2.0064 data: 0.0005
tensor([5, 9])
tensor([7, 4])
Test [4/5] eta: 0:00:02 loss: (avg)2.1010 (global_avg: 2.1010) acc: (avg)1.4507 (global_avg: 1.4507) time: 2.0065 data: 0.0006
Test Total time: 0:00:10 (2.0075 s / it)
done

除了首次和最后一次,就是每隔两个batch打印一次。

以上便是MetricLogger的详解。


MetricLogger:大厂都在用的指标记录器
https://jesseprince.github.io/2024/01/22/pytorch/metriclogger/
作者
林正
发布于
2024年1月22日
许可协议