pytorch的量化模块在observer中实现,不同的observer使用不同的量化方法,除了默认的max等方法,还有比较妙的HistogramObserver,正好阅读下代码涨涨姿势


特点

HistogramObserver的特点:

  • 使用直方图存储输入数据(减少存储和计算量)
  • 假设均匀分布,使用定积分计算l2范数误差(l2 norm error)
  • 非线性查找量化阈值(threshold),等数据密度查找

相关源码位置

https://pytorch.org/docs/stable/_modules/torch/quantization/observer.html#HistogramObserver

https://github.com/pytorch/pytorch/blob/master/torch/quantization/observer.py#L725


接口

从接口来看,主要有

__init__
_non_linear_param_search
_compute_quantization_error(next_start_bin, next_end_bin, norm_type)
_adjust_min_max(self, combined_min, combined_max, upsample_rate)
_combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins)
forward(self, x_orig)
calculate_qparams
//----
_save_to_state_dict
_load_from_state_dict

其中对外的接口是forward和calculate_qparams:

forward即处理一批新的数据,因为用于量化校准的数据可能很多,比如共m批,每批的batchsize为n,一次处理不过来,即需要多次调用forward。

calculate_qparams即输出校准后的参数,即根据l2 norm误差使用非线性查找(_non_linear_param_search)找到的最好的阈值

__init__即初始化,存放一些参数

_non_linear_param_search,非线性查找阈值

_adjust_min_max,来新一批数据时调整整体最小最大值

_combine_histograms,合并直方图,不同批次数据的最大最小值一般不同,因而需要合并直方图,下文会说


合并直方图

不同批次数据的最大最小值一般不同,因而需要合并直方图。

pytorch实现的直方图合并的方法是,先将原直方图(默认2048格)上采用(默认128倍),而后填入新数据,再下采样到原格数。

非线性查找

_non_linear_param_search实现了根据数据分布密度非均匀查找阈值。具体来说,使用alpha和beta表示上界和下界,初始为1和0,每次上界减少下界增加stepsize(1e-5),根据上界和下界确定选取的start_bin和end_bin,如果start_bin和end_bin没变化,alpha和beta继续变化stepsize;否则计算量化误差(_compute_quantization_error)。每次计算量化误差和先前计算的误差比较,若减小则继续移动上界和下界;否则结束。

l2 norm的计算

虽然使用了直方图近似表示输入数据,但计算l2 norm时没有使用直方图格子中点代表这个格子所有数值,而是假设了数据在这个直方图中均匀分布,使用定积分计算l2 norm。

数学好的同学可能一眼就知道是怎么算的,数学不太好的同学(比如我)可能要琢磨一会儿。

首先给出通常的l2 norm公式

$$\sum_{i}{(x_i-y_i)^2}$$

[latex]E=mc^2[/latex]