pytorch的量化模块在observer中实现,不同的observer使用不同的量化方法,除了默认的max等方法,还有比较妙的HistogramObserver,正好阅读下代码涨涨姿势
特点
HistogramObserver的特点:
- 使用直方图存储输入数据(减少存储和计算量)
- 假设均匀分布,使用定积分计算l2范数误差(l2 norm error)
- 非线性查找量化阈值(threshold),等数据密度查找
相关源码位置
https://pytorch.org/docs/stable/_modules/torch/quantization/observer.html#HistogramObserver
https://github.com/pytorch/pytorch/blob/master/torch/quantization/observer.py#L725
接口
从接口来看,主要有
__init__
_non_linear_param_search
_compute_quantization_error(next_start_bin, next_end_bin, norm_type)
_adjust_min_max(self, combined_min, combined_max, upsample_rate)
_combine_histograms(self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins)
forward(self, x_orig)
calculate_qparams
//----
_save_to_state_dict
_load_from_state_dict
其中对外的接口是forward和calculate_qparams:
forward即处理一批新的数据,因为用于量化校准的数据可能很多,比如共m批,每批的batchsize为n,一次处理不过来,即需要多次调用forward。
calculate_qparams即输出校准后的参数,即根据l2 norm误差使用非线性查找(_non_linear_param_search)找到的最好的阈值
__init__即初始化,存放一些参数
_non_linear_param_search,非线性查找阈值
_adjust_min_max,来新一批数据时调整整体最小最大值
_combine_histograms,合并直方图,不同批次数据的最大最小值一般不同,因而需要合并直方图,下文会说
合并直方图
不同批次数据的最大最小值一般不同,因而需要合并直方图。
pytorch实现的直方图合并的方法是,先将原直方图(默认2048格)上采用(默认128倍),而后填入新数据,再下采样到原格数。
非线性查找
_non_linear_param_search实现了根据数据分布密度非均匀查找阈值。具体来说,使用alpha和beta表示上界和下界,初始为1和0,每次上界减少下界增加stepsize(1e-5),根据上界和下界确定选取的start_bin和end_bin,如果start_bin和end_bin没变化,alpha和beta继续变化stepsize;否则计算量化误差(_compute_quantization_error)。每次计算量化误差和先前计算的误差比较,若减小则继续移动上界和下界;否则结束。
l2 norm的计算
虽然使用了直方图近似表示输入数据,但计算l2 norm时没有使用直方图格子中点代表这个格子所有数值,而是假设了数据在这个直方图中均匀分布,使用定积分计算l2 norm。
数学好的同学可能一眼就知道是怎么算的,数学不太好的同学(比如我)可能要琢磨一会儿。
首先给出通常的l2 norm公式
$$\sum_{i}{(x_i-y_i)^2}$$
[latex]E=mc^2[/latex]