Linear SVC
This documentation is for an unreleased version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.

Linear Support Vector Machine #

Linear Support Vector Machine (Linear SVC) is a binary classification algorithm that finds a separating hyperplane which maximizes the margin between samples of the two classes.

Input Columns #

Param name Type Default Description
featuresCol Vector "features" Feature vector.
labelCol Integer "label" Label to predict.
weightCol Double "weight" Weight of sample.

Output Columns #

Param name Type Default Description
predictionCol Integer "prediction" Predicted label, obtained by applying threshold to the raw prediction.
rawPredictionCol Vector "rawPrediction" Raw prediction (decision value) vector of each label.

Parameters #

Below are the parameters used by LinearSVCModel.

Key Default Type Required Description
featuresCol "features" String no Features column name.
predictionCol "prediction" String no Prediction column name.
rawPredictionCol "rawPrediction" String no Raw prediction column name.
threshold 0.0 Double no Threshold in binary classification prediction applied to rawPrediction.

In addition to the parameters above, LinearSVC also uses the parameters below.

Key Default Type Required Description
labelCol "label" String no Label column name.
weightCol null String no Weight column name.
maxIter 20 Integer no Maximum number of iterations.
reg 0.0 Double no Regularization parameter.
elasticNet 0.0 Double no ElasticNet parameter.
learningRate 0.1 Double no Learning rate of optimization method.
globalBatchSize 32 Integer no Global batch size of training algorithms.
tol 1e-6 Double no Convergence tolerance for iterative algorithms.

Examples #

import org.apache.flink.ml.classification.linearsvc.LinearSVC;
import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/**
 * Simple program that trains a LinearSVC model and uses it for classification.
 *
 * <p>The program builds a small in-memory table of (features, label, weight) rows, fits a {@link
 * LinearSVC} estimator on it, applies the resulting {@link LinearSVCModel} to the same table and
 * prints each row's features, expected label, predicted label and raw prediction.
 */
public class LinearSVCExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data. Each row is (features, label, weight): rows whose first feature
        // is in 1..5 are labeled 0, rows whose first feature is in 11..15 are labeled 1.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                        Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                        Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                        Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                        Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                        Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                        Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                        Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                        Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                        Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features", "label", "weight");

        // Creates a LinearSVC object and initializes its parameters. Only the weight column is
        // set explicitly; every other parameter keeps its default value.
        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");

        // Trains the LinearSVC Model.
        LinearSVCModel linearSVCModel = linearSVC.fit(inputTable);

        // Uses the LinearSVC Model for predictions.
        Table outputTable = linearSVCModel.transform(inputTable)[0];

        // Extracts and displays the results. try-with-resources guarantees the result iterator
        // is closed even if an exception is thrown while reading or printing rows.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector features = (DenseVector) row.getField(linearSVC.getFeaturesCol());
                double expectedResult = (Double) row.getField(linearSVC.getLabelCol());
                double predictionResult = (Double) row.getField(linearSVC.getPredictionCol());
                DenseVector rawPredictionResult =
                        (DenseVector) row.getField(linearSVC.getRawPredictionCol());
                System.out.printf(
                        "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s%n",
                        features, expectedResult, predictionResult, rawPredictionResult);
            }
        }
    }
}

# Simple program that trains a LinearSVC model and uses it for classification.

from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.classification.linearsvc import LinearSVC
from pyflink.table import StreamTableEnvironment

# create a new StreamExecutionEnvironment
env = StreamExecutionEnvironment.get_execution_environment()

# create a StreamTableEnvironment
t_env = StreamTableEnvironment.create(env)

# generate input data; each row is (features, label, weight): rows whose first
# feature is in 1..5 are labeled 0, rows whose first feature is in 11..15 are labeled 1
input_table = t_env.from_data_stream(
    env.from_collection(
        [
            (Vectors.dense([1, 2, 3, 4]), 0., 1.),
            (Vectors.dense([2, 2, 3, 4]), 0., 2.),
            (Vectors.dense([3, 2, 3, 4]), 0., 3.),
            (Vectors.dense([4, 2, 3, 4]), 0., 4.),
            (Vectors.dense([5, 2, 3, 4]), 0., 5.),
            (Vectors.dense([11, 2, 3, 4]), 1., 1.),
            (Vectors.dense([12, 2, 3, 4]), 1., 2.),
            (Vectors.dense([13, 2, 3, 4]), 1., 3.),
            (Vectors.dense([14, 2, 3, 4]), 1., 4.),
            (Vectors.dense([15, 2, 3, 4]), 1., 5.),
        ],
        type_info=Types.ROW_NAMED(
            ['features', 'label', 'weight'],
            [DenseVectorTypeInfo(), Types.DOUBLE(), Types.DOUBLE()])
    ))

# create a linear svc object and initialize its parameters
# (only the weight column is set; all other parameters keep their defaults)
linear_svc = LinearSVC().set_weight_col('weight')

# train the linear svc model
model = linear_svc.fit(input_table)

# use the linear svc model for predictions
output = model.transform(input_table)[0]

# extract and display the results
# the column positions do not change between rows, so look them up once
field_names = output.get_schema().get_field_names()
features_idx = field_names.index(linear_svc.get_features_col())
label_idx = field_names.index(linear_svc.get_label_col())
prediction_idx = field_names.index(linear_svc.get_prediction_col())
raw_prediction_idx = field_names.index(linear_svc.get_raw_prediction_col())

# the `with` block closes the result iterator once iteration finishes or fails
with t_env.to_data_stream(output).execute_and_collect() as results:
    for result in results:
        features = result[features_idx]
        expected_result = result[label_idx]
        prediction_result = result[prediction_idx]
        raw_prediction_result = result[raw_prediction_idx]
        print('Features: ' + str(features) + ' \tExpected Result: ' + str(expected_result)
              + ' \tPrediction Result: ' + str(prediction_result)
              + ' \tRaw Prediction Result: ' + str(raw_prediction_result))