public interface LossFunction
extends scala.Serializable
A loss function determines the loss term $L(w)$ of the objective function $f(w) = L(w) + \lambda R(w)$ for prediction tasks; the other term is the regularization $R(w)$.
The regularization is specific to the optimization algorithm used and is therefore implemented there.
We currently support only differentiable loss functions. In the future, this class could be changed to DiffLossFunction in order to support other types, such as absolute loss.
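To make the split between $L(w)$ and $\lambda R(w)$ concrete, here is a minimal, hypothetical Java sketch. It assumes squared loss averaged over the data (no intercept) and an L2 regularizer $R(w) = \tfrac{1}{2}\lVert w \rVert^2$; neither choice is prescribed by `LossFunction`, and all names are illustrative. The loss part only supplies $\nabla L(w)$, while the optimizer step adds $\lambda \nabla R(w) = \lambda w$ itself, mirroring the statement that regularization lives in the optimization algorithm.

```java
import java.util.Arrays;

// Hypothetical sketch: loss and regularization of f(w) = L(w) + lambda * R(w)
// are computed separately, as an optimizer might do.
final class RegularizedObjective {

    // L(w): mean of 0.5 * (w.x - y)^2 over all examples (no intercept, for brevity).
    static double loss(double[][] xs, double[] ys, double[] w) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double residual = dot(w, xs[i]) - ys[i];
            sum += 0.5 * residual * residual;
        }
        return sum / xs.length;
    }

    // Gradient of L(w) only: mean of (w.x - y) * x. No regularization term here.
    static double[] gradientOfLoss(double[][] xs, double[] ys, double[] w) {
        double[] grad = new double[w.length];
        for (int i = 0; i < xs.length; i++) {
            double residual = dot(w, xs[i]) - ys[i];
            for (int j = 0; j < w.length; j++) {
                grad[j] += residual * xs[i][j] / xs.length;
            }
        }
        return grad;
    }

    // f(w) = L(w) + lambda * R(w) with R(w) = 0.5 * ||w||^2.
    static double objective(double[][] xs, double[] ys, double[] w, double lambda) {
        return loss(xs, ys, w) + lambda * 0.5 * dot(w, w);
    }

    // One gradient-descent step; the optimizer, not the loss function, adds lambda * w.
    static double[] step(double[][] xs, double[] ys, double[] w, double lambda, double lr) {
        double[] g = gradientOfLoss(xs, ys, w);
        double[] next = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            next[j] = w[j] - lr * (g[j] + lambda * w[j]);
        }
        return next;
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; j++) s += a[j] * b[j];
        return s;
    }

    public static void main(String[] args) {
        double[][] xs = {{1.0, 0.0}, {0.0, 1.0}};
        double[] ys = {1.0, 2.0};
        double[] w = {0.0, 0.0};
        for (int t = 0; t < 200; t++) w = step(xs, ys, w, 0.1, 0.5);
        System.out.println(Arrays.toString(w) + " f(w) = " + objective(xs, ys, w, 0.1));
    }
}
```

With $\lambda = 0.1$ the weights settle slightly below the unregularized solution $(1, 2)$, which is the shrinkage the regularizer is meant to produce.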
Modifier and Type | Method and Description |
---|---|
`WeightVector` | `gradient(LabeledVector dataPoint, WeightVector weightVector)`: Calculates the gradient of the loss function given a data point and weight vector |
`double` | `loss(LabeledVector dataPoint, WeightVector weightVector)`: Calculates the loss given the prediction and label value |
`scala.Tuple2<Object,WeightVector>` | `lossGradient(LabeledVector dataPoint, WeightVector weightVector)`: Calculates the gradient as well as the loss given a data point and the weight vector |
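The three methods form one contract: `loss` and `gradient` can both be derived from `lossGradient`. Below is a hypothetical Java analogue (Java 16+ records) with simplified stand-ins for `LabeledVector` and `WeightVector` built on plain arrays, a record in place of `scala.Tuple2`, and a squared loss $\tfrac{1}{2}(w \cdot x + b - y)^2$ for a linear model. These types and the choice of loss are assumptions for illustration, not the Flink ML classes.

```java
// Hypothetical stand-in for LabeledVector: a label plus a dense feature array.
record LabeledPoint(double label, double[] features) {}

// Hypothetical stand-in for WeightVector: weights plus an intercept.
record Weights(double[] weights, double intercept) {}

// Stand-in for scala.Tuple2<Object, WeightVector>: the loss and its gradient.
record LossAndGradient(double loss, Weights gradient) {}

// Java analogue of the LossFunction contract: implement lossGradient once,
// derive loss and gradient from it.
interface DiffLoss {
    LossAndGradient lossGradient(LabeledPoint p, Weights w);

    default double loss(LabeledPoint p, Weights w) {
        return lossGradient(p, w).loss();
    }

    default Weights gradient(LabeledPoint p, Weights w) {
        return lossGradient(p, w).gradient();
    }
}

// Squared loss 0.5 * (prediction - label)^2 with prediction = w.x + b.
final class SquaredLoss implements DiffLoss {
    @Override
    public LossAndGradient lossGradient(LabeledPoint p, Weights w) {
        double prediction = w.intercept();
        for (int j = 0; j < w.weights().length; j++) {
            prediction += w.weights()[j] * p.features()[j];
        }
        double residual = prediction - p.label();
        double[] grad = new double[w.weights().length];
        for (int j = 0; j < grad.length; j++) {
            grad[j] = residual * p.features()[j]; // partial derivative w.r.t. w_j
        }
        // The partial derivative w.r.t. the intercept b is the residual itself.
        return new LossAndGradient(0.5 * residual * residual, new Weights(grad, residual));
    }
}
```

Deriving `loss` and `gradient` from `lossGradient` keeps the prediction logic in one place; an implementation may still override them when computing only one of the two values is cheaper.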
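`double loss(LabeledVector dataPoint, WeightVector weightVector)`
Calculates the loss given the prediction and label value.
Parameters:
- `dataPoint` - the labeled data point (label and feature vector)
- `weightVector` - the current weights and intercept of the model
Returns: the loss value for `dataPoint`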
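As a further illustration of what `loss` computes from a prediction and a label, here is a hypothetical sketch of logistic loss $\log\left(1 + e^{-y\,(w \cdot x + b)}\right)$ for labels $y \in \{-1, +1\}$. The label convention, the class name and the use of plain arrays are assumptions; the point shown is evaluating the loss in a numerically stable form so that large margins do not overflow `Math.exp`.

```java
// Hypothetical sketch: logistic loss of one labeled point for a linear model.
final class LogisticLossDemo {

    // log(1 + exp(-margin)) with margin = label * (w.x + b), label in {-1, +1}.
    static double logisticLoss(double label, double[] x, double[] w, double b) {
        double prediction = b;
        for (int j = 0; j < w.length; j++) prediction += w[j] * x[j];
        double z = -label * prediction;
        // Stable evaluation of log(1 + e^z): for z > 0 use z + log1p(e^-z),
        // so Math.exp never receives a large positive argument.
        return z > 0 ? z + Math.log1p(Math.exp(-z)) : Math.log1p(Math.exp(z));
    }

    public static void main(String[] args) {
        double[] w = {2.0, -1.0};
        System.out.println(logisticLoss(+1.0, new double[]{1.0, 1.0}, w, 0.0));
        System.out.println(logisticLoss(-1.0, new double[]{1000.0, 0.0}, w, 0.0));
    }
}
```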
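`WeightVector gradient(LabeledVector dataPoint, WeightVector weightVector)`
Calculates the gradient of the loss function given a data point and weight vector.
Parameters:
- `dataPoint` - the labeled data point (label and feature vector)
- `weightVector` - the current weights and intercept of the model
Returns: the gradient of the loss with respect to the weights and intercept, as a `WeightVector`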
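A common way to test a `gradient` implementation is a central finite-difference check, comparing each analytic component with $\frac{L(w + h e_j) - L(w - h e_j)}{2h}$. The hypothetical sketch below does this for a single-example squared loss without an intercept; the loss, the step size `h` and the plain-array representation are illustrative assumptions.

```java
// Hypothetical sketch: check an analytic gradient against central finite differences.
final class GradientCheck {

    // Squared loss 0.5 * (w.x - y)^2 for one example (no intercept).
    static double loss(double[] w, double[] x, double y) {
        double r = dot(w, x) - y;
        return 0.5 * r * r;
    }

    // Analytic gradient: (w.x - y) * x.
    static double[] gradient(double[] w, double[] x, double y) {
        double r = dot(w, x) - y;
        double[] g = new double[w.length];
        for (int j = 0; j < w.length; j++) g[j] = r * x[j];
        return g;
    }

    // Largest absolute difference between analytic and numerical gradient components.
    static double maxError(double[] w, double[] x, double y, double h) {
        double[] analytic = gradient(w, x, y);
        double worst = 0.0;
        for (int j = 0; j < w.length; j++) {
            double[] plus = w.clone();
            double[] minus = w.clone();
            plus[j] += h;
            minus[j] -= h;
            double numeric = (loss(plus, x, y) - loss(minus, x, y)) / (2 * h);
            worst = Math.max(worst, Math.abs(numeric - analytic[j]));
        }
        return worst;
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; j++) s += a[j] * b[j];
        return s;
    }

    public static void main(String[] args) {
        double err = maxError(new double[]{0.3, -1.2, 2.0}, new double[]{1.0, 0.5, -2.0}, 0.7, 1e-5);
        System.out.println("max gradient error: " + err);
    }
}
```

For a quadratic loss the central difference has no truncation error, so any noticeable mismatch points to a bug in the analytic gradient rather than to the step size.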
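`scala.Tuple2<Object,WeightVector> lossGradient(LabeledVector dataPoint, WeightVector weightVector)`
Calculates the gradient as well as the loss given a data point and the weight vector.
Parameters:
- `dataPoint` - the labeled data point (label and feature vector)
- `weightVector` - the current weights and intercept of the model
Returns: a pair whose first element is the loss (a Scala `Double`, which appears as `Object` in the Java signature) and whose second element is the gradient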
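Because `lossGradient` returns both values from a single prediction, a caller can accumulate losses and gradients over many data points in one pass. The hypothetical sketch below averages per-point (loss, gradient) pairs over a mini-batch, using a Java record in place of `scala.Tuple2` and a squared loss without intercept; the averaging scheme and all names are assumptions, not a description of any particular Flink optimizer.

```java
import java.util.List;

// Hypothetical stand-ins: a labeled point and a (loss, gradient) pair.
record Point(double label, double[] features) {}

record LossGrad(double loss, double[] gradient) {}

final class BatchLossGradient {

    // Per-point squared loss and its gradient, both computed from one prediction.
    static LossGrad lossGradient(Point p, double[] w) {
        double prediction = 0.0;
        for (int j = 0; j < w.length; j++) prediction += w[j] * p.features()[j];
        double residual = prediction - p.label();
        double[] g = new double[w.length];
        for (int j = 0; j < w.length; j++) g[j] = residual * p.features()[j];
        return new LossGrad(0.5 * residual * residual, g);
    }

    // Sum the pairs element-wise over the batch, then divide by the batch size.
    static LossGrad averageOver(List<Point> batch, double[] w) {
        double lossSum = 0.0;
        double[] gradSum = new double[w.length];
        for (Point p : batch) {
            LossGrad lg = lossGradient(p, w);
            lossSum += lg.loss();
            for (int j = 0; j < w.length; j++) gradSum[j] += lg.gradient()[j];
        }
        for (int j = 0; j < w.length; j++) gradSum[j] /= batch.size();
        return new LossGrad(lossSum / batch.size(), gradSum);
    }

    public static void main(String[] args) {
        List<Point> batch = List.of(new Point(1.0, new double[]{1.0}), new Point(3.0, new double[]{2.0}));
        LossGrad avg = averageOver(batch, new double[]{1.0});
        System.out.println(avg.loss() + " " + avg.gradient()[0]);
    }
}
```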
Copyright © 2014–2017 The Apache Software Foundation. All rights reserved.