本文最后更新于:2024年1月22日 下午
MetricLogger:大厂都在用的指标记录器
MetricLogger是现在比较流行的一个用来记录各种metric的类,它实际上最开始来源于DeiT项目,而DeiT项目又是从torchvision-classification-reference-utils.py里面抄过来的,所以总的来说是PyTorch提供的一个轮子。
1 前置模块:SmoothedValue
SmoothedValue提供了一个滑动窗口来自动计算窗口内的值,并且能够以特定的格式来进行打印。首先是其初始化部分。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 class SmoothedValue (object ): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__ (self, window_size=20 , fmt=None ): if fmt is None : fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt
fmt
是字符串输出的格式,这里默认是"{median:.4f} ({global_avg:.4f})"
,也就是打印四位小数的中位数和四位小数的全局平均数。
deque
则是一个数据结构–队列,一般可以用from collections import deque
来导入使用,这个数据结构类似与列表,但允许在左边和右边append新的值。当设定maxlen
之后,我们得到的是一个有界的队列,当元素达到最大队列长的时候,在一个方向上append新元素将会让队列在对面移除最旧的元素。
1 2 3 4 5 6 7 8 9 my_deque = deque(maxlen=3 ) my_deque.append(1 ) my_deque.append(2 ) my_deque.append(3 ) my_deque.append(4 ) my_deque.append(5 )
所以当数据流源源不断被append进deque的时候,deque就表现为一个具有maxlen大小的滑动窗口在数据流上滑动。
self.total
用来记录所有值的和,self.count
是目前看到的值的数量的总数。
接下来是它的第一个方法def update(self, value, n=1)
1 2 3 4 def update (self, value, n=1 ): self.deque.append(value) self.count += n self.total += value * n
首先,它将传进来的value
append到队列里,并且count
自身加n,total
自身加v a l u e × n value\times n v a l u e × n ,参数n
能够让其在某个值更新的时候给它赋权重。例如在mini batch训练中
1 value.update(acc1.item(), n=batch_size)
这样可以得到这个batch总和的acc1值(去除平均),方便后面算全局平均。
然后是def synchronize_between_processes(self)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def synchronize_between_processes (self ): """ Warning: does not synchronize the deque! """ if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda' ) dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int (t[0 ]) self.total = t[1 ]
这个方法就用来在分布式训练时同步不同进程间的值,首先其判断了是不是在进行分布式训练,如果是,先将值装入tensor,以此来使用PyTorch的分布式框架,然后调用barrier
来block进程,直到所有进程都进入这里,然后all_reduce
归纳整理所有进程间的这个tensor,这样就在不同进程间同步了count
和total
。最后将tensor转回list,再给object的count
,total
更新。
检查是否启用分布式的代码是
1 2 3 4 5 6 def is_dist_avail_and_initialized (): if not dist.is_available(): return False if not dist.is_initialized(): return False return True
其实就是调用dist
内部的is_available
和is_initialized
来判断。
接下来是中位数计算
1 2 3 4 @property def median (self ): d = torch.tensor(list (self.deque)) return d.median().item()
property
装饰器将这个方法转换为属性,这样调用的时候可以把median
直接当属性调用。方法内部,首先将队列转list再转tensor,然后使用tensor自带的中位数计算方法median()
得到中位数,然后用item()
取出tensor的值。这里也能体现出smoothed value是怎么回事:它以滑动窗口的形式,在窗口内计算中位数。
然后是其它一些统计特性的计算
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 @property def avg (self ): d = torch.tensor(list (self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg (self ): return self.total / self.count@property def max (self ): return max (self.deque) @property def value (self ): return self.deque[-1 ]
最后是重载的def __str__(self)
1 2 3 4 5 6 7 8 9 def __str__ (self ): return self.fmt.format ( median=self.median, avg=self.avg, global_avg=self.global_avg, max =self.max , value=self.value )
调用python字符串方法format()
来格式化字符串输出,这里因为使用了property
装饰器,所以用的是self.median
这样attribute like的调用。注意到,format()
方法不会填充不存在的占位符,也就是说,默认的占位符是median
和global_avg
,那么即使给format
传了avg
, max
, value
,也不会在返回的字符串中出现。这实际上给了我们控制返回的能力,我们在创建SmoothedValue
的时候就可以事先传入fmt
,例如我们要返回平均数,就可以fmt='{avg:.4f}'
,这样返回的时候就会只返回平均数。
以上就是SmoothedValue类的解析,其贡献在于滑动窗口计算统计值以及分布式训练的同步。
2 MetricLogger
接下来正式讲MetricLogger,这个类不仅可以用来追踪所有的指标,自带平滑功能,而且还内置迭代器,能够完成计时等等功能。
首先是用到的几个模块
1 2 3 4 from collections import defaultdictimport torchimport timeimport datetime
首先是初始化函数
1 2 3 4 class MetricLogger (object ): def __init__ (self, delimiter="\t" ): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter
delimiter
意思是分隔符,这是配合字符串的join()
方法使用的,默认为\t
,具体在后面可以看到。meters
则是用来装指标的,它是一个defaultdict,而defaultdict将默认字典内的value
都是SmoothedValue对象。也就是说往meters里面创建key-value
对的时候,value将是一个SmoothedValue对象。
接下来是第一个方法,def update(self, **kwargs)
1 2 3 4 5 6 7 8 def update (self, **kwargs ): for k, v in kwargs.items(): if v is None : continue if isinstance (v, torch.Tensor): v = v.item() assert isinstance (v, (float , int )) self.meters[k].update(v)
这是一个多参数传参的函数,函数内部将遍历参数字典,这里调用items()
将其转换为可以用key
, value
同时迭代的对象(简写为k
, v
)。如果发现value
是None
,那就跳过接下来的代码,如果发现value
是tensor object,那就调用item()
方法将tensor内部的值取出来。最后检查一下取出来的值是不是float
或者int
,然后将meters
内部对应的key
更新其value
。注意这里的update
方法实际上调用的是SmoothedValue里面的update
方法。一个例子如下
m e t e r s ( d e f a u l t d i c t ) { k e y = l o s s v a l u e = S m o o t h e d V a l u e { d e q u e t o t a l c o u n t f o r m a t \mathrm{meters(defaultdict)}
\begin{cases}
\mathrm{key}=\mathrm{loss}\\
{\mathrm{value}=\mathrm{SmoothedValue}\begin{cases}\mathrm{deque}\\ \mathrm{total}\\ \mathrm{count}\\ \mathrm{format}\end{cases}}
\end{cases}
m e t e r s ( d e f a u l t d i c t ) ⎩ ⎪ ⎪ ⎪ ⎪ ⎪ ⎪ ⎪ ⎨ ⎪ ⎪ ⎪ ⎪ ⎪ ⎪ ⎪ ⎧ k e y = l o s s v a l u e = S m o o t h e d V a l u e ⎩ ⎪ ⎪ ⎪ ⎪ ⎨ ⎪ ⎪ ⎪ ⎪ ⎧ d e q u e t o t a l c o u n t f o r m a t
接下来是重载的def __getattr__(self, attr)
1 2 3 4 5 6 7 def __getattr__ (self, attr ): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'" .format ( type (self).__name__, attr))
首先检查了查询的attr
在不在meters
这个字典里,如果在就返回meters
里面对应的value
,然后检查在不在object本身的字典里,即是不是object本身的attribute,如果是就返回。如果两者都不是,那么就raise一个error:这个物体里面没有这个attribute。这个重载的主要作用就是将meters
这个字典当作object的attribute来用。
然后是重载的def __str__(self)
1 2 3 4 5 6 7 def __str__ (self ): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}" .format (name, str (meter)) ) return self.delimiter.join(loss_str)
首先其创建了一个空的列表loss_str
,然后同样以key-value
的形式遍历meters字典,接着按照"{}: {}".format(name, str(meter))
的格式来生成字符串,例如loss: 0.6132
,然后将生成的字符串append进loss_str
这个列表中,最后使用join()
方法来生成一个用delimiter
来分隔的字符串。
当join
方法配合列表来使用的时候,它会挨个读取列表的内容组合成一个字符串,并在元素间插入一个字符串来分隔。例如
1 2 lst = ['a' , 'b' , 'c' ] ret = ', ' .join(lst)
结果为
所以这里就在不同的指标之间(例如loss: 1.1
, acc1: 42
, acc5: 61
)插入了\t
。
然后是进程间同步方法def synchronize_between_processes(self)
1 2 3 def synchronize_between_processes (self ): for meter in self.meters.values(): meter.synchronize_between_processes()
读取所有meters字典里面的值(SmoothedValue),并调用它们的synchronize_between_processes()
方法来进行同步,具体如何同步看之前SmoothedValue。
然后是添加指标def add_meter(self, name, meter)
1 2 def add_meter (self, name, meter ): self.meters[name] = meter
在字典中创建一个key
为name
,value
为meter
的key-value
对,一般调用的时候meter
传入SmoothedValue
。例如
1 2 3 4 metric_logger.add_meter( 'lr' , SmoothedValue(window_size=1 , fmt='{value:.6f}' ) )
虽然直接对字典写入key-value
对,在没有这个key
的时候也会创建key
,但是这样创建的key
默认的SmoothedValue
并不能调整其初始化函数的参数,所以这里提供了一个调整window_size
和fmt
的方法。
最后是def log_every(self, iterable, print_freq, header=None)
,这个方法是MetricLogger的精髓
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 def log_every (self, iterable, print_freq, header=None ): i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}' ) data_time = SmoothedValue(fmt='{avg:.4f}' ) space_fmt = ':' + str (len (str (len (iterable)))) + 'd' log_msg = [ header, '[{0' + space_fmt + '}/{1}]' , 'eta: {eta}' , '{meters}' , 'time: {time}' , 'data: {data}' ] if torch.cuda.is_available(): log_msg.append('max mem: {memory:.0f}' ) log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0
接下来进入一个迭代
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len (iterable) - 1 : eta_seconds = iter_time.global_avg * (len (iterable) - i) eta_string = str (datetime.timedelta(seconds=int (eta_seconds))) if torch.cuda.is_available(): print ( log_msg.format ( i, len (iterable), eta=eta_string, meters=str (self), time=str (iter_time), data=str (data_time), memory=torch.cuda.max_memory_allocated() / MB) ) else : print (log_msg.format ( i, len (iterable), eta=eta_string, meters=str (self), time=str (iter_time), data=str (data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str (datetime.timedelta(seconds=int (total_time)))print ('{} Total time: {} ({:.4f} s / it)' .format ( header, total_time_str, total_time / len (iterable)))
此时在外部我们只需要调用log_every
方法就可以像迭代data_loader
一样迭代数据,并且在更新频率上打印各种指标,还可以追踪各个部分消耗的时间。
接下来我们写一个test来看看这个东西的表现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 import utilsimport mathimport timeimport torch logger = utils.MetricLogger( delimiter="\t" ) loss = 25 acc = 1 data = [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ] loader = torch.utils.data.DataLoader( data, batch_size=2 , shuffle=True ) logger.add_meter( name='loss' , meter=utils.SmoothedValue(fmt="(avg){avg:.4f} (global_avg: {global_avg:.4f})" ) ) logger.add_meter( name='acc' , meter=utils.SmoothedValue(fmt="(avg){avg:.4f} (global_avg: {global_avg:.4f})" ) )for index, obj in enumerate (logger.log_every( loader, print_freq=1 , header='Test' )): print (obj) time.sleep(2 ) loss = loss*math.exp(-1 *(index+1 )) acc = acc*math.log(index+2 ) logger.update(loss=loss) logger.update(acc=acc) print ('done' )
输出为
1 2 3 4 5 6 7 8 9 10 11 12 tensor([2, 5])Test [0/5] eta: 0:00:10 loss: (avg)9.1970 (global_avg: 9.1970) acc: (avg)0.6931 (global_avg: 0.6931) time: 2.0060 data: 0.0005 tensor([8, 1])Test [1/5] eta: 0:00:08 loss: (avg)5.2208 (global_avg: 5.2208) acc: (avg)0.7273 (global_avg: 0.7273) time: 2.0064 data: 0.0007 tensor([3, 4])Test [2/5] eta: 0:00:06 loss: (avg)3.5012 (global_avg: 3.5012) acc: (avg)0.8368 (global_avg: 0.8368) time: 2.0060 data: 0.0006 tensor([6, 9])Test [3/5] eta: 0:00:04 loss: (avg)2.6262 (global_avg: 2.6262) acc: (avg)1.0523 (global_avg: 1.0523) time: 2.0061 data: 0.0006 tensor([10, 7])Test [4/5] eta: 0:00:02 loss: (avg)2.1010 (global_avg: 2.1010) acc: (avg)1.4507 (global_avg: 1.4507) time: 2.0062 data: 0.0006Test Total time: 0:00:10 (2.0081 s / it) done
这里窗口太大看不出来global_avg和avg的区别,如果print_freq
改成2,那么
1 2 3 4 5 6 7 8 9 10 tensor([6, 1])Test [0/5] eta: 0:00:10 loss: (avg)9.1970 (global_avg: 9.1970) acc: (avg)0.6931 (global_avg: 0.6931) time: 2.0059 data: 0.0004 tensor([ 3, 10]) tensor([8, 2])Test [2/5] eta: 0:00:06 loss: (avg)3.5012 (global_avg: 3.5012) acc: (avg)0.8368 (global_avg: 0.8368) time: 2.0064 data: 0.0005 tensor([5, 9]) tensor([7, 4])Test [4/5] eta: 0:00:02 loss: (avg)2.1010 (global_avg: 2.1010) acc: (avg)1.4507 (global_avg: 1.4507) time: 2.0065 data: 0.0006Test Total time: 0:00:10 (2.0075 s / it) done
除了首次和最后一次,就是每隔两个batch打印一次。
以上便是MetricLogger的详解。