Stupid Hack for Single PyTorch Layer Quantization
Kind of.
Quantization and model packing with PyTorch and ONNX are in a weird state right now.
On one hand, quantization just works in most cases for PyTorch (there are competing and unstable new APIs, but that was to be expected).
For ONNX, it also just works, but adding a single "if" to the model proved to be a challenge, never mind more complex logic. Whether to move some logic out into external wrapper utilities (and how to obfuscate it) is a design decision, and it is out of scope for this short post.
The problem is that the pre-packaged versions of PyTorch do not work properly with quantized models on older CPUs (1, 2 + literally dozens of similar questions in Telegram chats). Typically people report having a "10-year-old laptop" with some old Intel CPU or something similar.
Of course, no one is going to tweak or rebuild anything. So unless a model (a TTS model, for example) is fully quantized or somehow cleverly packaged into ONNX, it does not make sense to quantize only some parts of it or to move some logic outside of the jit / pt package, even if that reduces the package size significantly.
But there is a third option. If the model has a single large layer / module (nn.Embedding is the best candidate), there is a dirty hack:
- Do not quantize the model;
- Quantize the weight matrix manually;
- Save the checkpoint with int8 weights;
- Store scale and zero_point separately;
- On loading, just convert int8 back into float32 manually.
(Basically the same approach as dynamic quantization).
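The steps above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not production code: `pack_embedding`, `unpack_embedding` and the checkpoint keys are made-up names, the in-memory buffer stands in for a real file, and rounding `zero_point` to an integer is an assumption.

```python
import io

import torch
import torch.nn as nn

QMIN, QMAX = -128, 127  # int8 range


def pack_embedding(emb: nn.Embedding) -> dict:
    """Return a checkpoint dict with int8 weights plus scale / zero_point."""
    w = emb.weight.detach()
    scale = ((w.max() - w.min()) / (QMAX - QMIN)).item()
    zero_point = round(QMIN - w.min().item() / scale)  # rounding is an assumption
    q = torch.round(w / scale + zero_point).clamp(QMIN, QMAX).to(torch.int8)
    return {"weight_int8": q, "scale": scale, "zero_point": zero_point}


def unpack_embedding(ckpt: dict) -> nn.Embedding:
    """Rebuild an ordinary float32 nn.Embedding from the int8 checkpoint."""
    q = ckpt["weight_int8"].to(torch.float32)
    w = (q - ckpt["zero_point"]) * ckpt["scale"]
    return nn.Embedding.from_pretrained(w, freeze=True)


torch.manual_seed(0)
original = nn.Embedding(1000, 64)  # toy stand-in for the large layer

buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(pack_embedding(original), buffer)
buffer.seek(0)
restored = unpack_embedding(torch.load(buffer))

# Largest absolute difference between original and restored weights
max_err = (original.weight.detach() - restored.weight).abs().max().item()
```

The restored layer is a regular float32 module, so no quantized kernels are involved at inference time. Only the stored weights get about 4x smaller.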
Your mileage may vary, but the basic conversion is as follows:
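# weight: the float32 weight tensor of the layer, for example embedding.weight
qmax = 127  # int8 upper bound
qmin = -128  # int8 lower bound
scale = (weight.max() - weight.min()) / (qmax - qmin)  # float step per int8 level
zero_point = qmin - weight.min() / scale  # maps weight.min() onto qmin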
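To make the arithmetic concrete, here is a tiny worked example with plain Python floats instead of tensors. The rounding and clamping steps, as well as rounding `zero_point` to an integer, are not part of the formulas above; they are assumed here as the usual way to land on the int8 grid. The toy values are made up.

```python
QMIN, QMAX = -128, 127

weights = [-1.0, 0.0, 0.333, 0.5, 1.55]  # toy stand-in for a weight matrix

w_min, w_max = min(weights), max(weights)
scale = (w_max - w_min) / (QMAX - QMIN)    # about 0.01 for this toy range
zero_point = round(QMIN - w_min / scale)   # rounded to an int (an assumption)


def quantize(w: float) -> int:
    """Map a float onto the int8 grid, clamping to [QMIN, QMAX]."""
    return max(QMIN, min(QMAX, round(w / scale + zero_point)))


def dequantize(q: int) -> float:
    """Invert the mapping; the result differs from w by the rounding error."""
    return (q - zero_point) * scale


quantized = [quantize(w) for w in weights]
restored = [dequantize(q) for q in quantized]

print(zero_point)  # -28
print(quantized)   # [-128, -28, 5, 22, 127]
```

The smallest weight lands on qmin and the largest on qmax. Everything in between is off by at most about half a step (scale / 2), e.g. 0.333 comes back as roughly 0.33.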
Obviously we tried going below int8, but the dynamic range for nn.Embedding was somewhere around 2**6, so we decided not to.
If this faces some further real world hurdles, I will provide an update.
#deep_learning