When your neural network is trained with two or more losses, weighting them is often not straightforward. Usually the total loss is:
loss = loss_0 + lambda_1 * loss_1 + ...
Of course, you can tune these "lambdas" manually, with some naïve NAS, or with an ad hoc heuristic (e.g. "this loss is more important"), but all these approaches have two drawbacks:
- They are slow, compute-intensive, or ad hoc;
- There is no guarantee that these values are always optimal.
Usually, when something is unstable (and multiple losses often explode at initialization), some sort of adaptive clipping is employed. I just stumbled upon a technique called Gradient Adaptive Factor; see an example here.
The idea is simple: balance your losses so that their gradients have roughly similar magnitudes.
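To make the balancing idea concrete, here is a minimal sketch in plain Python. It is not the actual Gradient Adaptive Factor implementation, just one simple way to pick the lambdas. It assumes you already have the gradient of each loss with respect to the same shared parameters as a flat list of floats. The helper names (`grad_norm`, `balance_weights`) and the choice of the first loss as the reference are made up for illustration.

```python
import math

def grad_norm(grad):
    """L2 norm of a gradient given as a flat list of floats."""
    return math.sqrt(sum(g * g for g in grad))

def balance_weights(grads, eps=1e-12):
    """Return one weight per loss so that each weighted gradient has
    roughly the same norm as the first (reference) loss's gradient.

    Assumption: the first loss is the reference, so its weight is ~1,
    matching `loss = loss_0 + lambda_1 * loss_1 + ...`.
    """
    norms = [grad_norm(g) for g in grads]
    ref = norms[0]
    # eps guards against division by zero for a vanishing gradient
    return [ref / (n + eps) for n in norms]

# Toy gradients of two losses w.r.t. the same shared parameters.
grad_0 = [3.0, 4.0]      # norm 5
grad_1 = [300.0, 400.0]  # norm 500, would dominate an unweighted sum

lambdas = balance_weights([grad_0, grad_1])
weighted_norms = [w * grad_norm(g) for w, g in zip(lambdas, [grad_0, grad_1])]
print(lambdas, weighted_norms)
```

In a real training loop you would probably recompute these weights every step (or every few steps), treat them as constants rather than differentiating through them, and perhaps smooth them over time. All of that is a design choice, not something the sketch above settles.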
#deep_learning