Logistic Regression

Logistic Regression #

Logistic regression is a special case of the generalized linear model. It is widely used to predict a binary response.
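As a rough illustration of the binary case, a logistic regression model scores a feature vector with a linear function and maps that score to a probability through the sigmoid function. The sketch below is plain Java and independent of Flink ML; the class name, method names, and coefficient values are hypothetical and chosen only for illustration.

```java
// Hypothetical stand-alone sketch; not part of the Flink ML API.
final class BinaryLogisticSketch {
    /** Maps a real-valued score to a probability in (0, 1). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Probability that the label is 1 under the given coefficients. */
    static double probabilityOfOne(double[] coefficients, double[] features) {
        if (coefficients.length != features.length) {
            throw new IllegalArgumentException("Dimension mismatch");
        }
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += coefficients[i] * features[i];
        }
        return sigmoid(dot);
    }

    public static void main(String[] args) {
        double[] coefficients = {0.5, -0.25}; // made-up values for illustration
        double p = probabilityOfOne(coefficients, new double[] {2.0, 4.0});
        // The score is 0.5 * 2 - 0.25 * 4 = 0, so the probability is 0.5.
        System.out.println(p);
    }
}
```

A probability above 0.5 corresponds to predicting label 1, and a probability below 0.5 to predicting label 0.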

Input Columns #

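| Param name  | Type    | Default      | Description      |
|-------------|---------|--------------|------------------|
| featuresCol | Vector  | "features"   | Feature vector   |
| labelCol    | Integer | "label"      | Label to predict |
| weightCol   | Double  | "weight"     | Weight of sample |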

Output Columns #

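| Param name       | Type    | Default         | Description                                 |
|------------------|---------|-----------------|---------------------------------------------|
| predictionCol    | Integer | "prediction"    | Label with the highest probability          |
| rawPredictionCol | Vector  | "rawPrediction" | Vector holding the probability of each label |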
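To make the relationship between the two output columns concrete: in the binary case, the raw prediction can be viewed as a two-element vector of label probabilities, and the prediction is the label whose probability is largest. The sketch below is plain Java, not Flink ML code; the class and method names are hypothetical, and it assumes that index i of the vector corresponds to label i.

```java
// Hypothetical stand-alone sketch; not part of the Flink ML API.
final class OutputColumnsSketch {
    /** Builds [P(label = 0), P(label = 1)] from the probability of label 1. */
    static double[] rawPrediction(double probabilityOfOne) {
        return new double[] {1.0 - probabilityOfOne, probabilityOfOne};
    }

    /** Returns the index of the largest entry, i.e. the most probable label. */
    static int prediction(double[] rawPrediction) {
        int best = 0;
        for (int i = 1; i < rawPrediction.length; i++) {
            if (rawPrediction[i] > rawPrediction[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] raw = rawPrediction(0.75);
        // Prints "0.25, 0.75 -> 1".
        System.out.println(raw[0] + ", " + raw[1] + " -> " + prediction(raw));
    }
}
```

In this sketch a tie resolves to the lower label index; a real implementation may break ties differently.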

Parameters #

LogisticRegressionModel uses the following parameters.

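| Key              | Default         | Type   | Required | Description                 |
|------------------|-----------------|--------|----------|-----------------------------|
| featuresCol      | "features"      | String | no       | Features column name.       |
| predictionCol    | "prediction"    | String | no       | Prediction column name.     |
| rawPredictionCol | "rawPrediction" | String | no       | Raw prediction column name. |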

LogisticRegression uses the parameters above as well as those below.

| Key             | Default | Type    | Required | Description                                                                  |
|-----------------|---------|---------|----------|------------------------------------------------------------------------------|
| labelCol        | "label" | String  | no       | Label column name.                                                           |
| weightCol       | null    | String  | no       | Weight column name.                                                          |
| maxIter         | 20      | Integer | no       | Maximum number of iterations.                                                |
| reg             | 0.0     | Double  | no       | Regularization parameter.                                                    |
| learningRate    | 0.1     | Double  | no       | Learning rate of optimization method.                                        |
| globalBatchSize | 32      | Integer | no       | Global batch size of training algorithms.                                    |
| tol             | 1e-6    | Double  | no       | Convergence tolerance for iterative algorithms.                              |
| multiClass      | "auto"  | String  | no       | Classification type. Supported values: "auto", "binomial", "multinomial".    |
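To give a feel for how the training parameters interact, the following plain-Java sketch runs mini-batch gradient descent on a weighted log loss. It is an assumption-laden illustration, not Flink ML's optimizer: the class and method names are hypothetical, the L2 form of the regularization term and the stopping rule (stop once the mean batch loss is at most tol) are choices made for this sketch, and sample weights are assumed to be positive.

```java
// Hypothetical stand-alone sketch; Flink ML's optimizer may differ in its details.
final class TrainingLoopSketch {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Fits coefficients with mini-batch gradient descent on the weighted log loss.
     * Runs at most maxIter updates and stops early once the mean batch loss is <= tol.
     */
    static double[] fit(
            double[][] features, double[] labels, double[] weights,
            int maxIter, double learningRate, int globalBatchSize, double reg, double tol) {
        int n = features.length;
        int dim = features[0].length;
        double[] coef = new double[dim];
        int batchSize = Math.min(globalBatchSize, n);
        int offset = 0;
        for (int iter = 0; iter < maxIter; iter++) {
            double[] grad = new double[dim];
            double loss = 0.0;
            double weightSum = 0.0;
            for (int k = 0; k < batchSize; k++) {
                int i = (offset + k) % n; // cycle through the data batch by batch
                double dot = 0.0;
                for (int j = 0; j < dim; j++) {
                    dot += coef[j] * features[i][j];
                }
                double p = sigmoid(dot);
                // Weighted log loss; the small constant guards against log(0).
                loss -= weights[i] * (labels[i] * Math.log(p + 1e-12)
                        + (1.0 - labels[i]) * Math.log(1.0 - p + 1e-12));
                for (int j = 0; j < dim; j++) {
                    grad[j] += weights[i] * (p - labels[i]) * features[i][j];
                }
                weightSum += weights[i];
            }
            offset = (offset + batchSize) % n;
            for (int j = 0; j < dim; j++) {
                // Gradient step with an L2 penalty scaled by reg (an assumption of this sketch).
                coef[j] -= learningRate * (grad[j] / weightSum + reg * coef[j]);
            }
            if (loss / weightSum <= tol) {
                break; // treated as converged
            }
        }
        return coef;
    }

    public static void main(String[] args) {
        double[][] x = {{-2.0}, {-1.0}, {1.0}, {2.0}};
        double[] y = {0.0, 0.0, 1.0, 1.0};
        double[] w = {1.0, 1.0, 1.0, 1.0};
        // Uses the default values listed in the table above.
        double[] coef = fit(x, y, w, 20, 0.1, 32, 0.0, 1e-6);
        System.out.println("coefficient: " + coef[0]);
    }
}
```

In this sketch, a larger learningRate takes bigger steps per iteration, a larger reg pulls coefficients toward zero, and globalBatchSize bounds how many samples contribute to each update.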

Examples #

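import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Create the stream and table environments used below.
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

// Training data: each row is (features, label, weight).
List<Row> binomialTrainData =
    Arrays.asList(
        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
Collections.shuffle(binomialTrainData);

// Convert the rows into a table with named, typed columns.
Table binomialDataTable =
    tEnv.fromDataStream(
        env.fromCollection(
            binomialTrainData,
            new RowTypeInfo(
                new TypeInformation[] {
                    TypeInformation.of(DenseVector.class),
                    Types.DOUBLE,
                    Types.DOUBLE
                },
                new String[] {"features", "label", "weight"})));

// Train a model, using the "weight" column as the sample weight.
LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);

// transform returns an array of tables; the first one holds the predictions.
Table output = model.transform(binomialDataTable)[0];

output.execute().print();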