This documentation is for an unreleased version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.
Logistic Regression
Logistic Regression #
Logistic regression is a special case of the Generalized Linear Model. It is widely used to predict a binary response.
Input Columns #
Param name | Type | Default | Description |
---|---|---|---|
featuresCol | Vector | "features" | Feature vector. |
labelCol | Integer | "label" | Label to predict. |
weightCol | Double | "weight" | Weight of sample. |
Output Columns #
Param name | Type | Default | Description |
---|---|---|---|
predictionCol | Integer | "prediction" | Label of the max probability. |
rawPredictionCol | Vector | "rawPrediction" | Vector of the probability of each label. |
Parameters #
Below are the parameters required by `LogisticRegressionModel`.
Key | Default | Type | Required | Description |
---|---|---|---|---|
featuresCol | "features" | String | no | Features column name. |
predictionCol | "prediction" | String | no | Prediction column name. |
rawPredictionCol | "rawPrediction" | String | no | Raw prediction column name. |
`LogisticRegression` needs the parameters above and also the ones below.
Key | Default | Type | Required | Description |
---|---|---|---|---|
labelCol | "label" | String | no | Label column name. |
weightCol | null | String | no | Weight column name. |
maxIter | 20 | Integer | no | Maximum number of iterations. |
reg | 0. | Double | no | Regularization parameter. |
elasticNet | 0. | Double | no | ElasticNet parameter. |
learningRate | 0.1 | Double | no | Learning rate of optimization method. |
globalBatchSize | 32 | Integer | no | Global batch size of training algorithms. |
tol | 1e-6 | Double | no | Convergence tolerance for iterative algorithms. |
multiClass | "auto" | String | no | Classification type. Supported values: “auto”, “binomial”, “multinomial”. |
Examples #
import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
/**
 * Simple program that trains a LogisticRegression model and uses it for classification.
 *
 * <p>The program builds a small in-memory data set of (features, label, weight) rows, fits a
 * {@link LogisticRegression} on it, applies the resulting {@link LogisticRegressionModel} to the
 * same data, and prints one line per output row.
 */
public class LogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data. Each row is (feature vector, label, sample weight).
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LogisticRegression object and initializes its parameters.
        LogisticRegression lr = new LogisticRegression().setWeightCol("weight");

        // Trains the LogisticRegression Model.
        LogisticRegressionModel lrModel = lr.fit(inputTable);

        // Uses the LogisticRegression Model for predictions.
        Table outputTable = lrModel.transform(inputTable)[0];

        // Extracts and displays the results. try-with-resources guarantees that the result
        // iterator is closed even if reading or printing a row fails.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
                double expectedResult = (Double) row.getField(lr.getLabelCol());
                double predictionResult = (Double) row.getField(lr.getPredictionCol());
                DenseVector rawPredictionResult =
                        (DenseVector) row.getField(lr.getRawPredictionCol());
                System.out.printf(
                        "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                        features, expectedResult, predictionResult, rawPredictionResult);
            }
        } catch (RuntimeException e) {
            // Unchecked failures propagate unchanged, exactly as they did before.
            throw e;
        } catch (Exception e) {
            // close() declares a checked exception; wrap it so main's signature stays the same.
            throw new IllegalStateException("Failed to close the prediction result iterator.", e);
        }
    }
}
# Simple program that trains a LogisticRegression model and uses it for
# classification.
from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.classification.logisticregression import LogisticRegression
from pyflink.table import StreamTableEnvironment
# create a new StreamExecutionEnvironment
env = StreamExecutionEnvironment.get_execution_environment()

# create a StreamTableEnvironment
t_env = StreamTableEnvironment.create(env)

# generate input data; each row is (features, label, weight)
input_data = t_env.from_data_stream(
    env.from_collection([
        (Vectors.dense([1, 2, 3, 4]), 0., 1.),
        (Vectors.dense([2, 2, 3, 4]), 0., 2.),
        (Vectors.dense([3, 2, 3, 4]), 0., 3.),
        (Vectors.dense([4, 2, 3, 4]), 0., 4.),
        (Vectors.dense([5, 2, 3, 4]), 0., 5.),
        (Vectors.dense([11, 2, 3, 4]), 1., 1.),
        (Vectors.dense([12, 2, 3, 4]), 1., 2.),
        (Vectors.dense([13, 2, 3, 4]), 1., 3.),
        (Vectors.dense([14, 2, 3, 4]), 1., 4.),
        (Vectors.dense([15, 2, 3, 4]), 1., 5.),
    ],
        type_info=Types.ROW_NAMED(
            ['features', 'label', 'weight'],
            [DenseVectorTypeInfo(), Types.DOUBLE(), Types.DOUBLE()])
    ))

# create a logistic regression object and initialize its parameters
logistic_regression = LogisticRegression().set_weight_col('weight')

# train the logistic regression model
model = logistic_regression.fit(input_data)

# use the logistic regression model for predictions
output = model.transform(input_data)[0]

# extract and display the results
field_names = output.get_schema().get_field_names()

# the column positions do not change between rows, so look them up once
features_idx = field_names.index(logistic_regression.get_features_col())
label_idx = field_names.index(logistic_regression.get_label_col())
prediction_idx = field_names.index(logistic_regression.get_prediction_col())
raw_prediction_idx = field_names.index(logistic_regression.get_raw_prediction_col())

# the `with` block closes the result iterator even if printing a row fails
with t_env.to_data_stream(output).execute_and_collect() as results:
    for result in results:
        features = result[features_idx]
        expected_result = result[label_idx]
        prediction_result = result[prediction_idx]
        raw_prediction_result = result[raw_prediction_idx]
        print('Features: ' + str(features) + ' \tExpected Result: ' + str(expected_result)
              + ' \tPrediction Result: ' + str(prediction_result)
              + ' \tRaw Prediction Result: ' + str(raw_prediction_result))
OnlineLogisticRegression #
Online Logistic Regression supports training an online regression model on an unbounded stream of training data.
The online optimizer of this algorithm is FTRL-Proximal, proposed by H. Brendan McMahan et al. See H. Brendan McMahan et al., "Ad click prediction: a view from the trenches".
Input Columns #
Param name | Type | Default | Description |
---|---|---|---|
featuresCol | Vector | "features" | Feature vector |
labelCol | Integer | "label" | Label to predict |
weightCol | Double | "weight" | Weight of sample |
Output Columns #
Param name | Type | Default | Description |
---|---|---|---|
predictionCol | Integer | "prediction" | Label of the max probability |
rawPredictionCol | Vector | "rawPrediction" | Vector of the probability of each label |
modelVersionCol | Long | "modelVersion" | The version of the model data used for this prediction |
Parameters #
Below are the parameters required by `OnlineLogisticRegressionModel`.
Key | Default | Type | Required | Description |
---|---|---|---|---|
featuresCol | "features" | String | no | Features column name. |
predictionCol | "prediction" | String | no | Prediction column name. |
rawPredictionCol | "rawPrediction" | String | no | Raw prediction column name. |
modelVersionCol | "modelVersion" | String | no | Model version column name. |
`OnlineLogisticRegression` needs the parameters above and also the ones below.
Key | Default | Type | Required | Description |
---|---|---|---|---|
labelCol | "label" | String | no | Label column name. |
weightCol | null | String | no | Weight column name. |
batchStrategy | COUNT_STRATEGY | String | no | Strategy to create mini batch from online train data. |
globalBatchSize | 32 | Integer | no | Global batch size of training algorithms. |
reg | 0. | Double | no | Regularization parameter. |
elasticNet | 0. | Double | no | ElasticNet parameter. |
Examples #
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
 * Simple program that trains an OnlineLogisticRegression model and uses it for classification.
 *
 * <p>Training and prediction rows are emitted periodically by {@link PeriodicSourceFunction}
 * sources. An {@link OnlineLogisticRegression} is fitted on the training stream starting from
 * explicitly provided initial model data, and the resulting {@link OnlineLogisticRegressionModel}
 * is applied to the prediction stream. One line is printed per prediction row.
 */
public class OnlineLogisticRegressionExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data. Both are infinite streams that periodically
        // sends out provided data to trigger model update and prediction.
        List<Row> trainData1 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.1, 2.), 0.),
                        Row.of(Vectors.dense(0.2, 2.), 0.),
                        Row.of(Vectors.dense(0.3, 2.), 0.),
                        Row.of(Vectors.dense(0.4, 2.), 0.),
                        Row.of(Vectors.dense(0.5, 2.), 0.),
                        Row.of(Vectors.dense(11., 12.), 1.),
                        Row.of(Vectors.dense(12., 11.), 1.),
                        Row.of(Vectors.dense(13., 12.), 1.),
                        Row.of(Vectors.dense(14., 12.), 1.),
                        Row.of(Vectors.dense(15., 12.), 1.));

        List<Row> trainData2 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.2, 3.), 0.),
                        Row.of(Vectors.dense(0.8, 1.), 0.),
                        Row.of(Vectors.dense(0.7, 1.), 0.),
                        Row.of(Vectors.dense(0.6, 2.), 0.),
                        Row.of(Vectors.dense(0.2, 2.), 0.),
                        Row.of(Vectors.dense(14., 17.), 1.),
                        Row.of(Vectors.dense(15., 10.), 1.),
                        Row.of(Vectors.dense(16., 16.), 1.),
                        Row.of(Vectors.dense(17., 10.), 1.),
                        Row.of(Vectors.dense(18., 13.), 1.));

        List<Row> predictData =
                Arrays.asList(
                        Row.of(Vectors.dense(0.8, 2.7), 0.0),
                        Row.of(Vectors.dense(15.5, 11.2), 1.0));

        // Row layout shared by the training and prediction sources: (features, label).
        RowTypeInfo typeInfo =
                new RowTypeInfo(
                        new TypeInformation<?>[] {DenseVectorTypeInfo.INSTANCE, Types.DOUBLE},
                        new String[] {"features", "label"});

        SourceFunction<Row> trainSource =
                new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2));
        DataStream<Row> trainStream = env.addSource(trainSource, typeInfo);
        Table trainTable = tEnv.fromDataStream(trainStream).as("features");

        SourceFunction<Row> predictSource =
                new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
        DataStream<Row> predictStream = env.addSource(predictSource, typeInfo);
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates an online LogisticRegression object and initializes its parameters and initial
        // model data.
        Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L);
        Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData));

        OnlineLogisticRegression olr =
                new OnlineLogisticRegression()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setPredictionCol("prediction")
                        .setReg(0.2)
                        .setElasticNet(0.5)
                        .setGlobalBatchSize(10)
                        .setInitialModelData(initModelDataTable);

        // Trains the online LogisticRegression Model.
        OnlineLogisticRegressionModel onlineModel = olr.fit(trainTable);

        // Uses the online LogisticRegression Model for predictions.
        Table outputTable = onlineModel.transform(predictTable)[0];

        // Extracts and displays the results. As training data stream continuously triggers the
        // update of the internal model data, raw prediction results of the same predict dataset
        // would change over time.
        // try-with-resources guarantees that the result iterator is closed if reading or
        // printing a row fails.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector features = (DenseVector) row.getField(olr.getFeaturesCol());
                Double expectedResult = (Double) row.getField(olr.getLabelCol());
                Double predictionResult = (Double) row.getField(olr.getPredictionCol());
                DenseVector rawPredictionResult =
                        (DenseVector) row.getField(olr.getRawPredictionCol());
                System.out.printf(
                        "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                        features, expectedResult, predictionResult, rawPredictionResult);
            }
        } catch (RuntimeException e) {
            // Unchecked failures propagate unchanged, exactly as they did before.
            throw e;
        } catch (Exception e) {
            // close() declares a checked exception; wrap it so main's signature stays the same.
            throw new IllegalStateException("Failed to close the prediction result iterator.", e);
        }
    }
}