
Spark in me - Internet, data science, math, deep learning, philosophy

@snakers4

A channel about topics that interest me: the internet, statistics, data science. No ads and no bullshit.


5 years ago
Speeding Up Your Transformer-Based Networks for Real

You know how there are a lot of pruning / distillation papers that boast 90% sparsification, yet it is impossible to use any of it in production? Looks like I have stumbled upon a decent recipe.

Well, you can take a well-trained transformer-based network and simply replace all of its nn.Linear layers with low-rank counterparts initialized from the SVD of the original weights (spectral initialization).
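A minimal sketch of the swap, assuming a plain PyTorch model. Unlike the `FactorizedLinear` class at the end of this post, it packages each low-rank layer as two thinner `nn.Linear` layers inside an `nn.Sequential` (an alternative packaging, not the author's code), and the forward pass never materializes the full matrix. `low_rank_linear` and `factorize_model` are hypothetical helper names, not part of any library.

```python
import torch
import torch.nn as nn


def low_rank_linear(linear: nn.Linear, dim_ratio: float) -> nn.Sequential:
    """Hypothetical helper: approximate one nn.Linear by two thinner nn.Linear layers."""
    w = linear.weight.detach()                        # (out_features, in_features)
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    rank = max(1, int(s.numel() * dim_ratio))          # number of singular values kept
    sqrt_s = torch.sqrt(s[:rank])
    kw = dict(device=w.device, dtype=w.dtype)
    first = nn.Linear(linear.in_features, rank, bias=False, **kw)
    second = nn.Linear(rank, linear.out_features,
                       bias=linear.bias is not None, **kw)
    with torch.no_grad():
        # Spectral initialization: W ~= (U sqrt(S)) @ (sqrt(S) Vh)
        first.weight.copy_(sqrt_s[:, None] * vh[:rank])   # (rank, in_features)
        second.weight.copy_(u[:, :rank] * sqrt_s)         # (out_features, rank)
        if linear.bias is not None:
            second.bias.copy_(linear.bias)
    return nn.Sequential(first, second)


def factorize_model(model: nn.Module, dim_ratio: float) -> nn.Module:
    """Hypothetical helper: recursively swap every nn.Linear child of `model` in place."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, low_rank_linear(child, dim_ratio))
        else:
            factorize_model(child, dim_ratio)
    return model
```

Calling `factorize_model(model, dim_ratio=0.25)` and then fine-tuning roughly mirrors the experiment described in this post. With `dim_ratio=1.0` the outputs match the original model up to floating-point error, which is a handy sanity check. Modules that read a child's `.weight` directly instead of calling it (for example the `out_proj` inside `torch.nn.MultiheadAttention`) would need separate handling.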

But does it really work?

I tested it today, and it worked.

How well does it work?

On a not-fully-trained network, after 1 epoch of fine-tuning my loss was 25-30% higher with the factorized model than with the full model. I kept the top 25% of singular values, and they accounted for about 50% of their total sum.
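The share quoted above can be checked per layer before committing to a ratio. A small sketch, assuming the weight is a 2-D float tensor; `singular_value_share` is a hypothetical helper that follows the same `s_share` formula as the code at the end of the post (share of the *sum* of singular values).

```python
import torch


def singular_value_share(weight: torch.Tensor, dim_ratio: float) -> float:
    """Percentage of the sum of singular values kept when truncating to dim_ratio of the rank."""
    s = torch.linalg.svdvals(weight)          # singular values, in descending order
    dims = max(1, int(s.numel() * dim_ratio))  # number of singular values kept
    return (s[:dims].sum() / s.sum() * 100).item()
```

A flat spectrum, such as an identity matrix, gives exactly the kept fraction (25% for `dim_ratio=0.25`), so the further the share rises above that fraction, the closer the layer is to being low-rank.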

The metrics took a significant hit, but on simpler classification tasks I expect the approach to work better.

What are the benefits?

25% factorization (i.e. keeping only 25% of the singular values of each weight matrix) produces a model that is 50% smaller.
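The 50% figure follows from parameter counting. A rank-r factorization of a d_out × d_in matrix stores r·(d_in + d_out) weights instead of d_in·d_out; for a square n × n matrix with r = n/4 that is 2·n·(n/4) = n²/2, i.e. exactly half. A quick sketch of the arithmetic (`factorized_param_ratio` is a hypothetical helper; biases and non-Linear layers are ignored):

```python
def factorized_param_ratio(d_in: int, d_out: int, dim_ratio: float) -> float:
    """Weights of U (d_out x r) plus Vh (r x d_in), relative to the full d_out x d_in matrix."""
    # Rank is taken from the thin SVD, whose size is min(d_in, d_out)
    rank = max(1, int(min(d_in, d_out) * dim_ratio))
    return rank * (d_in + d_out) / (d_in * d_out)
```

For example, `factorized_param_ratio(768, 768, 0.25)` gives 0.5, while a rectangular 768 → 3072 feed-forward matrix (dimensions picked for illustration) gives 0.3125. Factorization saves parameters only while r < d_in·d_out / (d_in + d_out), i.e. below half the rank for a square matrix, and layers that are not factorized (embeddings, LayerNorm) keep their size, so the whole-model saving depends on the architecture.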

Is it worth it?

The main question is whether you can take either (i) a poorly trained or (ii) a well-trained network, factorize it, and train it until it reaches the same metrics as the full model.

This remains to be seen.

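```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FactorizedLinear(nn.Module):
    # Low-rank stand-in for nn.Linear: W (out x in) is approximated by U (out x r) @ Vh (r x in)
    def __init__(self,
                 or_linear: nn.Linear,
                 dim_ratio: float = 1.0):
        super().__init__()
        # Copy the original bias (nn.Linear may have been built with bias=False)
        if or_linear.bias is not None:
            self.bias = nn.Parameter(or_linear.bias.detach().clone())
        else:
            self.register_parameter('bias', None)
        u, vh = self.spectral_init(or_linear.weight.detach(), dim_ratio=dim_ratio)
        print(f'Doing SVD of tensor {tuple(or_linear.weight.shape)}, '
              f'U: {tuple(u.shape)}, Vh: {tuple(vh.shape)}')
        self.u = nn.Parameter(u)    # (out_features, rank)
        self.vh = nn.Parameter(vh)  # (rank, in_features)
        self.dim_ratio = dim_ratio
        # nn.Linear stores its weight as (out_features, in_features)
        self.in_features = vh.size(1)
        self.out_features = u.size(0)

    @staticmethod
    def spectral_init(m, dim_ratio=1.0):
        # Thin SVD: m = U @ diag(s) @ Vh, singular values s in descending order
        u, s, vh = torch.linalg.svd(m, full_matrices=False)
        # Spectral init: split sqrt(s) evenly between the two factors
        sqrt_s = torch.sqrt(s)
        u = u * sqrt_s              # same as u @ torch.diag(sqrt_s)
        vh = sqrt_s[:, None] * vh   # same as torch.diag(sqrt_s) @ vh
        if dim_ratio < 1:
            # Keep only the top `dims` singular directions
            dims = max(1, int(u.size(1) * dim_ratio))
            u = u[:, :dims].contiguous()
            vh = vh[:dims, :].contiguous()
            s_share = (s[:dims].sum() / s.sum() * 100).item()
            print(f'SVD singular value share {s_share:.2f}%')
        return u, vh

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, '
                f'out_features={self.out_features}, '
                f'bias={self.bias is not None}, dim_ratio={self.dim_ratio}')

    def forward(self, x):
        # x @ Vh.T @ U.T + b as two thin matmuls, without materializing U @ Vh
        return F.linear(F.linear(x, self.vh), self.u, self.bias)
```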

#deep_learning