Building your own Flink ML project

This document provides a quick introduction to using Flink ML. It guides you through creating a simple Flink job that trains a machine learning model and then uses that model to serve predictions.

Maven Setup #

To use Flink ML in a Maven project, add the following dependency to pom.xml.

<dependency>
    <groupId>org.apache.flink</groupId>
    <artifactId>flink-ml-uber</artifactId>
    <version>2.1.0</version>
</dependency>

The example code in this document also depends on the Flink Table API. To run it successfully, make sure the following dependencies are present in pom.xml as well.

<dependency>
  <groupId>org.apache.flink</groupId>
  <artifactId>flink-clients</artifactId>
  <version>1.15.0</version>
</dependency>
<dependency>
  <groupId>org.apache.flink</groupId>
  <artifactId>flink-table-planner-loader</artifactId>
  <version>1.15.0</version>
</dependency>

K-means is a widely used clustering algorithm that Flink ML supports. The example code below creates a Flink job with Flink ML that initializes and trains a K-means model, and then uses it to predict the cluster ID of each input data point.

import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class QuickStart {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        String featuresCol = "features";
        String predictionCol = "prediction";

        // Generates the training and prediction data as a DataStream.
        DataStream<DenseVector> inputStream = env.fromElements(
                Vectors.dense(0.0, 0.0),
                Vectors.dense(0.0, 0.3),
                Vectors.dense(0.3, 0.0),
                Vectors.dense(9.0, 0.0),
                Vectors.dense(9.0, 0.6),
                Vectors.dense(9.6, 0.0)
        );

        // Convert data from DataStream to Table, as Flink ML uses Table API.
        Table input = tEnv.fromDataStream(inputStream).as(featuresCol);

        // Creates a K-means object and initializes its parameters.
        KMeans kmeans = new KMeans()
                .setK(2)
                .setSeed(1L)
                .setFeaturesCol(featuresCol)
                .setPredictionCol(predictionCol);

        // Trains the K-means Model.
        KMeansModel model = kmeans.fit(input);

        // Use the K-means Model for predictions.
        Table output = model.transform(input)[0];

        // Extracts and displays prediction result.
        for (CloseableIterator<Row> it = output.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector vector = (DenseVector) row.getField(featuresCol);
            int clusterId = (Integer) row.getField(predictionCol);
            System.out.println("Vector: " + vector + "\tCluster ID: " + clusterId);
        }
    }
}

After you place the code above in your Maven project and run it, output like the following is printed to your terminal window.

Vector: [0.3, 0.0]	Cluster ID: 1
Vector: [9.6, 0.0]	Cluster ID: 0
Vector: [9.0, 0.6]	Cluster ID: 0
Vector: [0.0, 0.0]	Cluster ID: 1
Vector: [0.0, 0.3]	Cluster ID: 1
Vector: [9.0, 0.0]	Cluster ID: 0
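
For readers new to the algorithm, the following plain-Java sketch shows the kind of computation a K-means job performs on the same six points: a "fit" step that refines cluster centers and a "transform" step that assigns each point to its nearest center. It is only an illustration and is not Flink ML's implementation. The class ToyKMeans, its method names, the fixed iteration count, and the hand-picked initial centers are all assumptions made for this example. Because cluster IDs are just indices of the centers, the IDs produced here need not match the ones Flink ML prints.

```java
import java.util.Arrays;

// Toy K-means (Lloyd's algorithm) in plain Java. Hypothetical helper, not a Flink ML class.
class ToyKMeans {
    // "fit": starting from the given centers, repeatedly assign points and recompute centers.
    static double[][] fit(double[][] points, double[][] initialCentroids, int maxIter) {
        int k = initialCentroids.length;
        int dim = points[0].length;
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = initialCentroids[c].clone(); // do not mutate the caller's arrays
        }
        for (int iter = 0; iter < maxIter; iter++) {
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (double[] p : points) {
                int c = nearest(centroids, p);
                counts[c]++;
                for (int d = 0; d < dim; d++) {
                    sums[c][d] += p[d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue; // an empty cluster keeps its previous center
                }
                for (int d = 0; d < dim; d++) {
                    centroids[c][d] = sums[c][d] / counts[c];
                }
            }
        }
        return centroids;
    }

    // Index of the center closest to p, by squared Euclidean distance.
    static int nearest(double[][] centroids, double[] p) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0.0;
            for (int d = 0; d < p.length; d++) {
                double diff = p[d] - centroids[c][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    // "transform": cluster ID for every point.
    static int[] transform(double[][] centroids, double[][] points) {
        int[] ids = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            ids[i] = nearest(centroids, points[i]);
        }
        return ids;
    }

    public static void main(String[] args) {
        double[][] points = {
            {0.0, 0.0}, {0.0, 0.3}, {0.3, 0.0}, {9.0, 0.0}, {9.0, 0.6}, {9.6, 0.0}
        };
        // Assumption: use the first and fourth points as initial centers.
        double[][] init = {points[0], points[3]};
        double[][] centroids = fit(points, init, 10);
        System.out.println(Arrays.toString(transform(centroids, points))); // prints [0, 0, 0, 1, 1, 1]
    }
}
```

The three points near the origin end up in one cluster and the three points near x = 9 in the other, which is the same grouping as in the Flink ML output above.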

Breaking Down The Code #

The Execution Environment #

The first lines set up the StreamExecutionEnvironment that executes the Flink ML job. If you have used Flink before, this concept will already be familiar. For the example program in this document, a StreamExecutionEnvironment with the default configuration is enough.

Because Flink ML uses Flink’s Table API, the program also needs a StreamTableEnvironment.

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

Creating Training & Inference Data Table #

Next, the program creates the Table that holds the data used for both training and prediction by the K-means algorithm that follows. Flink ML operators locate their input data by column name in the input Table, and write prediction results to a designated column of the output Table.

DataStream<DenseVector> inputStream = env.fromElements(
        Vectors.dense(0.0, 0.0),
        Vectors.dense(0.0, 0.3),
        Vectors.dense(0.3, 0.0),
        Vectors.dense(9.0, 0.0),
        Vectors.dense(9.0, 0.6),
        Vectors.dense(9.6, 0.0)
);

Table input = tEnv.fromDataStream(inputStream).as(featuresCol);

Creating, Configuring, Training & Using K-means #

The Flink ML classes for the K-means algorithm are KMeans and KMeansModel. KMeans implements the training process on the provided training data and produces a KMeansModel. The KMeansModel.transform() method encodes the transformation logic of the algorithm and is used for predictions.

Both KMeans and KMeansModel provide getter and setter methods for the K-means configuration parameters. The example program explicitly sets the following parameters; all other parameters keep their default values.

  • K, the number of clusters to create
  • seed, the random seed to initialize cluster centers
  • featuresCol, name of the column containing input feature vectors
  • predictionCol, name of the column to output prediction results

When the program invokes KMeans.fit() to generate a KMeansModel, the KMeansModel inherits the KMeans object’s configuration parameters. You can therefore set the KMeansModel’s parameters directly on the KMeans object.

KMeans kmeans = new KMeans()
        .setK(2)
        .setSeed(1L)
        .setFeaturesCol(featuresCol)
        .setPredictionCol(predictionCol);

KMeansModel model = kmeans.fit(input);

Table output = model.transform(input)[0];
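
The idea of a model inheriting its estimator's parameters can be pictured with a small plain-Java sketch. ToyEstimator and ToyModel are hypothetical stand-ins invented for this illustration. They are not Flink ML classes, and the sketch makes no claim about how Flink ML stores or copies parameters internally. It only shows the pattern: values set on the estimator are visible on the model that fit() returns.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model: holds its own copy of the parameters it was created with.
class ToyModel {
    private final Map<String, Object> params;

    ToyModel(Map<String, Object> params) {
        this.params = new HashMap<>(params);
    }

    Object get(String name) {
        return params.get(name);
    }
}

// Hypothetical estimator: collects parameters through chained setters, then hands them to the model.
class ToyEstimator {
    private final Map<String, Object> params = new HashMap<>();

    ToyEstimator set(String name, Object value) {
        params.put(name, value);
        return this; // returning this allows chained calls
    }

    // A real fit() would also train on data; this one only passes the parameters on.
    ToyModel fit() {
        return new ToyModel(params);
    }

    public static void main(String[] args) {
        ToyModel model = new ToyEstimator()
                .set("k", 2)
                .set("predictionCol", "prediction")
                .fit();
        System.out.println(model.get("predictionCol")); // prints prediction
    }
}
```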

Collecting Prediction Result #

As in all other Flink programs, the code described in the sections above only configures the computation graph of the Flink job. The program evaluates the computation logic and collects outputs only after the execute() method is invoked. The outputs collected from the output table are Rows in which featuresCol contains the input feature vectors and predictionCol contains the prediction results, i.e., the cluster IDs.

for (CloseableIterator<Row> it = output.execute().collect(); it.hasNext(); ) {
    Row row = it.next();
    DenseVector vector = (DenseVector) row.getField(featuresCol);
    int clusterId = (Integer) row.getField(predictionCol);
    System.out.println("Vector: " + vector + "\tCluster ID: " + clusterId);
}

Vector: [0.3, 0.0]	Cluster ID: 1
Vector: [9.6, 0.0]	Cluster ID: 0
Vector: [9.0, 0.6]	Cluster ID: 0
Vector: [0.0, 0.0]	Cluster ID: 1
Vector: [0.0, 0.3]	Cluster ID: 1
Vector: [9.0, 0.0]	Cluster ID: 0
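
The build-first, run-later behavior described in this section has a close analogy in the standard java.util.stream API, which can be tried without a Flink setup. The sketch below is only an analogy, not Flink code. LazyPipelineDemo and buildPipeline are names invented for this example. Declaring the pipeline processes nothing, and elements are processed only when a terminal operation (here collect) is called, much as a Flink job runs only after execute().

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Analogy only: a java.util.stream pipeline does no work until a terminal operation runs.
class LazyPipelineDemo {
    // Declares a pipeline that doubles each value and logs every element it processes.
    static Stream<Integer> buildPipeline(List<Integer> input, List<String> log) {
        return input.stream().map(x -> {
            log.add("processed " + x);
            return x * 2;
        });
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        Stream<Integer> pipeline = buildPipeline(Arrays.asList(1, 2, 3), log);
        System.out.println(log.size()); // prints 0: nothing has been processed yet

        List<Integer> result = pipeline.collect(Collectors.toList());
        System.out.println(log.size()); // prints 3: the terminal call triggered the work
        System.out.println(result);     // prints [2, 4, 6]
    }
}
```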