public class L2Regularization extends Object
L_2 regularization penalty.
The regularization function is half the squared L2 norm, 1/2*||w||_2^2, with w being the weight vector. The function penalizes large weights, favoring solutions with many small weights over solutions with a few large ones.
Constructor and Description |
---|
L2Regularization() |
Modifier and Type | Method and Description |
---|---|
static double | regLoss(double oldLoss, Vector weightVector, double regularizationConstant) Adds regularization to the loss value. |
static Vector | takeStep(Vector weightVector, Vector gradient, double regularizationConstant, double learningRate) Calculates the new weights based on the gradient and the L2 regularization penalty. |
public static Vector takeStep(Vector weightVector, Vector gradient, double regularizationConstant, double learningRate)
The updated weight is w - learningRate * (gradient + lambda * w), where w is the weight vector and lambda is the regularization parameter (regularizationConstant).
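The update rule above can be sketched in plain Java. This is a minimal illustration, not the library's implementation: the class name L2StepSketch is hypothetical, and double[] arrays stand in for the Vector type so the example runs without any dependencies.

```java
import java.util.Arrays;

// Hypothetical stand-in for illustration only; double[] replaces Vector.
final class L2StepSketch {

    // Returns w - learningRate * (gradient + lambda * w), computed per component.
    static double[] takeStep(double[] weightVector, double[] gradient,
                             double regularizationConstant, double learningRate) {
        if (weightVector.length != gradient.length) {
            throw new IllegalArgumentException(
                "weightVector and gradient must have the same length");
        }
        double[] updated = new double[weightVector.length];
        for (int i = 0; i < weightVector.length; i++) {
            // Gradient of the loss plus the gradient of the L2 penalty (lambda * w).
            double penalizedGradient = gradient[i] + regularizationConstant * weightVector[i];
            updated[i] = weightVector[i] - learningRate * penalizedGradient;
        }
        return updated;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, -2.0};
        double[] gradient = {0.5, 0.5};
        double[] next = takeStep(weights, gradient, 0.1, 0.1);
        System.out.println(Arrays.toString(next));
    }
}
```

In this sketch the lambda * w term scales each weight by (1 - learningRate * lambda) before the plain gradient step is applied, which is how the penalty shrinks weights toward zero.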
weightVector - The weights to be updated
gradient - The gradient according to which we will update the weights
regularizationConstant - The regularization parameter to be applied
learningRate - The effective step size for this iteration
public static double regLoss(double oldLoss, Vector weightVector, double regularizationConstant)
The updated loss is oldLoss + lambda * 1/2*||w||_2^2, where w is the weight vector and lambda is the regularization parameter (regularizationConstant).
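The loss formula above can be sketched the same way. Again this is only an illustration under stated assumptions: the class name L2LossSketch is hypothetical, and a double[] array stands in for the Vector type.

```java
// Hypothetical stand-in for illustration only; double[] replaces Vector.
final class L2LossSketch {

    // Returns oldLoss + lambda * 1/2 * ||w||_2^2.
    static double regLoss(double oldLoss, double[] weightVector,
                          double regularizationConstant) {
        double squaredNorm = 0.0;
        for (double w : weightVector) {
            squaredNorm += w * w; // accumulate ||w||_2^2
        }
        return oldLoss + regularizationConstant * 0.5 * squaredNorm;
    }

    public static void main(String[] args) {
        double[] weights = {3.0, 4.0}; // squared L2 norm is 25
        System.out.println(regLoss(1.0, weights, 0.1));
    }
}
```

With a regularization constant of zero the sketch returns oldLoss unchanged, so the penalty can be switched off without a separate code path.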
oldLoss - The loss to be updated
weightVector - The weights used to update the loss
regularizationConstant - The regularization parameter to be applied
Copyright © 2014–2018 The Apache Software Foundation. All rights reserved.