Speeding Up Your Transformer-Based Networks For Real
You know how there are a lot of pruning / distillation papers that boast 90% sparsification, yet in practice none of it can be used in production? Looks like I stumbled upon a decent recipe.
Well, you can take a well-trained transformer-based network and simply replace all of its Linear layers with low-rank counterparts initialized via SVD (spectral initialization).
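Written out, the spectral initialization is a truncated SVD in which the square roots of the singular values are split evenly between the two factors. The notation is mine, not the original author's: ρ stands for `dim_ratio`, A for `u` and B for `vh` in the code below, Σ_{1:r} is the top-left r×r block of Σ, and x is a column vector (PyTorch applies the transposed form x Wᵀ to row vectors).

```latex
\begin{aligned}
W &= U \Sigma V^\top, \quad W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}, \quad
  \Sigma = \operatorname{diag}(\sigma_1, \dots, \sigma_k), \quad
  \sigma_1 \ge \dots \ge \sigma_k, \quad k = \min(d_{\text{out}}, d_{\text{in}}) \\
r &= \lfloor \rho k \rfloor, \qquad
  A = U_{:,1:r}\, \Sigma_{1:r}^{1/2} \in \mathbb{R}^{d_{\text{out}} \times r}, \qquad
  B = \Sigma_{1:r}^{1/2}\, V_{:,1:r}^\top \in \mathbb{R}^{r \times d_{\text{in}}} \\
y &= W x + b \approx A (B x) + b \\
\text{share} &= \frac{\sum_{i=1}^{r} \sigma_i}{\sum_{i=1}^{k} \sigma_i} \times 100\%
\end{aligned}
```

A dense layer costs d_in·d_out multiply-adds per token, while A(Bx) costs r(d_in + d_out), so the factorized layer only pays off when r < d_in·d_out / (d_in + d_out), i.e. r < d/2 for a square d×d layer.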
But does it really work?
I tested it today, and it worked.
How well does it work?
On a not-fully-trained network, my loss after 1 epoch of tuning was 25-30% higher with the factorized model than with the full model. I kept the top 25% of the singular values, and they accounted for about 50% of their total sum.
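Before committing to a ratio, it can help to see how much of each layer's spectrum survives. The sketch below computes the same "share" number that the post reports, for every nn.Linear in a model. `spectrum_report` is a hypothetical helper name, and the randomly initialized toy model is only there to make the snippet runnable; its numbers say nothing about a trained network.

```python
import torch
import torch.nn as nn


def spectrum_report(model: nn.Module, dim_ratio: float = 0.25) -> dict:
    """For every nn.Linear in `model`, return the percentage of the sum of
    singular values carried by the top `dim_ratio` fraction of them."""
    report = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # svdvals returns the singular values sorted in descending order.
            s = torch.linalg.svdvals(module.weight.detach().float())
            rank = max(1, int(s.numel() * dim_ratio))
            report[name] = (s[:rank].sum() / s.sum() * 100).item()
    return report


if __name__ == '__main__':
    torch.manual_seed(0)
    # Toy stand-in for a real transformer (hypothetical sizes).
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 256))
    for name, share in spectrum_report(model, 0.25).items():
        print(f'layer {name}: top 25% of singular values carry {share:.1f}% of the total')
```

Layers whose share is far below the average are the ones most likely to hurt the loss after factorization, so they may be worth keeping dense.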
The metrics took a significant hit, but I expect it to work better on simpler classification tasks.
What are the benefits?
25% factorization (i.e. keeping only 25% of the singular values) produces a model that is 50% smaller.
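The 50% figure follows from counting weights per layer. Below is a small sketch of that arithmetic under stated assumptions: the hidden size of 768 is a hypothetical BERT-base-like value, the helper names are made up for illustration, and biases, embeddings and LayerNorm parameters are left out, so the whole-model saving will differ from these per-layer numbers.

```python
def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of a dense nn.Linear(d_in, d_out)."""
    return d_in * d_out + (d_out if bias else 0)


def factorized_params(d_in: int, d_out: int, dim_ratio: float, bias: bool = True) -> int:
    """Parameter count after keeping rank r = int(min(d_in, d_out) * dim_ratio):
    u is (d_out x r) and vh is (r x d_in)."""
    r = max(1, int(min(d_in, d_out) * dim_ratio))
    return r * (d_in + d_out) + (d_out if bias else 0)


if __name__ == '__main__':
    d = 768  # hypothetical hidden size
    # Attention projections (d -> d) and feed-forward projections (d -> 4d, 4d -> d).
    for d_in, d_out in [(d, d), (d, 4 * d), (4 * d, d)]:
        full = linear_params(d_in, d_out, bias=False)
        low = factorized_params(d_in, d_out, 0.25, bias=False)
        print(f'{d_in} -> {d_out}: full={full:,} factorized={low:,} ratio={low / full:.4f}')
```

For a square 768×768 projection the ratio is exactly 0.5, consistent with the 50% above. For the 768×3072 feed-forward projections the rank is capped by the smaller side (r = 192), so the ratio drops to 0.3125. Per-token matmul cost scales with the same counts, although actual wall-clock gains depend on how efficiently the hardware handles the thinner matrices.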
Is it worth it?
The main question is whether you can take either (i) a poorly trained or (ii) a well-trained network, factorize it, and keep training it until it matches the full model's metrics.
This remains to be seen.
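import torch
import torch.nn as nn
import torch.nn.functional as F


class FactorizedLinear(nn.Module):
    # Low-rank drop-in for nn.Linear: the (out x in) weight is approximated by u @ vh.
    def __init__(self,
                 or_linear,
                 dim_ratio=1.0):
        super().__init__()
        # Copy the bias instead of sharing storage with the original layer; it may be absent.
        if or_linear.bias is not None:
            self.bias = nn.Parameter(or_linear.bias.detach().clone())
        else:
            self.register_parameter('bias', None)
        u, vh = self.spectral_init(or_linear.weight.detach(), dim_ratio=dim_ratio)
        print(f'Doing SVD of tensor {or_linear.weight.shape}, U: {u.shape}, Vh: {vh.shape}')
        self.u = nn.Parameter(u)    # (out_features, rank)
        self.vh = nn.Parameter(vh)  # (rank, in_features)
        self.dim_ratio = dim_ratio
        # nn.Linear stores its weight as (out_features, in_features),
        # so in_features comes from vh and out_features from u.
        self.in_features = vh.size(1)
        self.out_features = u.size(0)

    @staticmethod
    def spectral_init(m,
                      dim_ratio=1.0):
        # Reduced SVD m = U diag(s) Vh; give each factor sqrt(s) so that u @ vh == m.
        u, s, vh = torch.linalg.svd(m, full_matrices=False)
        u = u @ torch.diag(torch.sqrt(s))
        vh = torch.diag(torch.sqrt(s)) @ vh
        if dim_ratio < 1:
            # Keep the top `dims` singular directions (s is sorted in descending order).
            # clone() so the parameters do not keep the full-size SVD buffers alive.
            dims = max(1, int(u.size(1) * dim_ratio))
            u = u[:, :dims].clone()
            vh = vh[:dims, :].clone()
            s_share = s[:dims].sum() / s.sum() * 100
            print(f'SVD singular value share {s_share:.2f}%')
        return u, vh

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, '
                f'out_features={self.out_features}, '
                f'bias={self.bias is not None}, dim_ratio={self.dim_ratio}')

    def forward(self, x):
        # Apply the thin factors one after the other: (x @ vh^T) @ u^T + bias.
        # Rebuilding the dense u @ vh on every call costs as much as the original
        # layer and throws the speed-up away.
        return F.linear(F.linear(x, self.vh), self.u, self.bias)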
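To actually swap the layers inside a model, something has to walk the module tree. Below is a minimal sketch of one way to do it, assuming plain PyTorch modules. Instead of the FactorizedLinear class above it uses a swapped-in variant: two stacked nn.Linear layers initialized the same way as `spectral_init`, so the snippet is self-contained. `low_rank_linear`, `factorize_linears` and `min_features` are hypothetical names. nn.MultiheadAttention is skipped because its forward reads `out_proj.weight` directly.

```python
import copy

import torch
import torch.nn as nn


def low_rank_linear(linear: nn.Linear, dim_ratio: float) -> nn.Sequential:
    """Replace y = x W^T + b by two thin layers, SVD-initialized with sqrt(s) on each side."""
    u, s, vh = torch.linalg.svd(linear.weight.detach().float(), full_matrices=False)
    rank = max(1, int(s.numel() * dim_ratio))
    sqrt_s = s[:rank].sqrt()
    kw = dict(device=linear.weight.device, dtype=linear.weight.dtype)
    first = nn.Linear(linear.in_features, rank, bias=False, **kw)  # plays the role of vh
    second = nn.Linear(rank, linear.out_features, bias=linear.bias is not None, **kw)  # plays u
    with torch.no_grad():
        first.weight.copy_(sqrt_s[:, None] * vh[:rank])
        second.weight.copy_(u[:, :rank] * sqrt_s)
        if linear.bias is not None:
            second.bias.copy_(linear.bias.detach())
    return nn.Sequential(first, second)


def factorize_linears(model: nn.Module, dim_ratio: float, min_features: int = 0) -> nn.Module:
    """Recursively replace (in place) every nn.Linear whose smaller side is >= min_features."""
    if isinstance(model, nn.MultiheadAttention):
        return model  # its forward uses out_proj.weight directly, so leave it dense
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear) and min(child.in_features, child.out_features) >= min_features:
            setattr(model, name, low_rank_linear(child, dim_ratio))
        else:
            factorize_linears(child, dim_ratio, min_features)
    return model


if __name__ == '__main__':
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 128))
    x = torch.randn(4, 32)
    # At full rank the factorized model should reproduce the original outputs.
    full_rank = factorize_linears(copy.deepcopy(model), dim_ratio=1.0)
    print(torch.allclose(full_rank(x), model(x), atol=1e-4))
    print(factorize_linears(copy.deepcopy(model), dim_ratio=0.25))
```

The `min_features` threshold is there because tiny layers such as a two-class head gain almost nothing from factorization and can lose a lot of accuracy.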
# deep_learning