Kmeans
This documentation is for an unreleased version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.

K-means #

K-means is a commonly-used clustering algorithm. It groups given data points into a predefined number of clusters.

Input Columns #

Param name Type Default Description
featuresCol Vector "features" Feature vector.

Output Columns #

Param name Type Default Description
predictionCol Integer "prediction" Predicted cluster center.

Parameters #

Below are the parameters required by KMeansModel.

Key Default Type Required Description
distanceMeasure euclidean String no Distance measure. Supported values: 'euclidean', 'manhattan', 'cosine'.
featuresCol "features" String no Features column name.
predictionCol "prediction" String no Prediction column name.
k 2 Integer no The max number of clusters to create.

KMeans needs parameters above and also below.

Key Default Type Required Description
initMode "random" String no The initialization algorithm. Supported options: ‘random’.
seed null Long no The random seed.
maxIter 20 Integer no Maximum number of iterations.

Examples #

import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a KMeans model and uses it for clustering. */
public class KMeansExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(0.0, 0.0),
                        Vectors.dense(0.0, 0.3),
                        Vectors.dense(0.3, 0.0),
                        Vectors.dense(9.0, 0.0),
                        Vectors.dense(9.0, 0.6),
                        Vectors.dense(9.6, 0.0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        // Creates a K-means object and initializes its parameters.
        KMeans kmeans = new KMeans().setK(2).setSeed(1L);

        // Trains the K-means Model.
        KMeansModel kmeansModel = kmeans.fit(inputTable);

        // Uses the K-means Model for predictions.
        Table outputTable = kmeansModel.transform(inputTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features = (DenseVector) row.getField(kmeans.getFeaturesCol());
            int clusterId = (Integer) row.getField(kmeans.getPredictionCol());
            System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
        }
    }
}

# Simple program that trains a KMeans model and uses it for clustering.

from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.clustering.kmeans import KMeans
from pyflink.table import StreamTableEnvironment

# create a new StreamExecutionEnvironment
env = StreamExecutionEnvironment.get_execution_environment()

# create a StreamTableEnvironment
t_env = StreamTableEnvironment.create(env)

# generate input data
input_data = t_env.from_data_stream(
    env.from_collection([
        (Vectors.dense([0.0, 0.0]),),
        (Vectors.dense([0.0, 0.3]),),
        (Vectors.dense([0.3, 3.0]),),
        (Vectors.dense([9.0, 0.0]),),
        (Vectors.dense([9.0, 0.6]),),
        (Vectors.dense([9.6, 0.0]),),
    ],
        type_info=Types.ROW_NAMED(
            ['features'],
            [DenseVectorTypeInfo()])))

# create a kmeans object and initialize its parameters
kmeans = KMeans().set_k(2).set_seed(1)

# train the kmeans model
model = kmeans.fit(input_data)

# use the kmeans model for predictions
output = model.transform(input_data)[0]

# extract and display the results
field_names = output.get_schema().get_field_names()
for result in t_env.to_data_stream(output).execute_and_collect():
    features = result[field_names.index(kmeans.get_features_col())]
    cluster_id = result[field_names.index(kmeans.get_prediction_col())]
    print('Features: ' + str(features) + ' \tCluster Id: ' + str(cluster_id))

Online K-means #

Online K-Means extends the function of K-Means, supporting to train a K-Means model continuously according to an unbounded stream of train data.

Online K-Means makes updates with the “mini-batch” K-Means rule, generalized to incorporate forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, Online K-Means computes the new centroids from the weighted average between the original and the estimated centroids. The weight of the estimated centroids is the number of points assigned to them. The weight of the original centroids is also the number of points, but additionally multiplying with the decay factor.

The decay factor scales the contribution of the clusters as estimated thus far. If the decay factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are determined entirely by recent data. Lower values correspond to more forgetting.

Input Columns #

Param name Type Default Description
featuresCol Vector "features" Feature vector

Output Columns #

Param name Type Default Description
predictionCol Integer "prediction" Predicted cluster center

Parameters #

Below are the parameters required by OnlineKMeansModel.

Key Default Type Required Description
distanceMeasure euclidean String no Distance measure. Supported values: 'euclidean', 'manhattan', 'cosine'.
featuresCol "features" String no Features column name.
predictionCol "prediction" String no Prediction column name.
k 2 Integer no The max number of clusters to create.

OnlineKMeans needs parameters above and also below.

Key Default Type Required Description
batchStrategy COUNT_STRATEGY String no Strategy to create mini batch from online train data.
globalBatchSize 32 Integer no Global batch size of training algorithms.
decayFactor 0. Double no The forgetfulness of the previous centroids.
seed null Long no The random seed.

Examples #

import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/** Simple program that trains an OnlineKMeans model and uses it for clustering. */
public class OnlineKMeansExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data. Both are infinite streams that periodically
        // sends out provided data to trigger model update and prediction.
        List<Row> trainData1 =
                Arrays.asList(
                        Row.of(Vectors.dense(0.0, 0.0)),
                        Row.of(Vectors.dense(0.0, 0.3)),
                        Row.of(Vectors.dense(0.3, 0.0)),
                        Row.of(Vectors.dense(9.0, 0.0)),
                        Row.of(Vectors.dense(9.0, 0.6)),
                        Row.of(Vectors.dense(9.6, 0.0)));

        List<Row> trainData2 =
                Arrays.asList(
                        Row.of(Vectors.dense(10.0, 100.0)),
                        Row.of(Vectors.dense(10.0, 100.3)),
                        Row.of(Vectors.dense(10.3, 100.0)),
                        Row.of(Vectors.dense(-10.0, -100.0)),
                        Row.of(Vectors.dense(-10.0, -100.6)),
                        Row.of(Vectors.dense(-10.6, -100.0)));

        List<Row> predictData =
                Arrays.asList(
                        Row.of(Vectors.dense(10.0, 10.0)), Row.of(Vectors.dense(-10.0, 10.0)));

        SourceFunction<Row> trainSource =
                new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2));
        DataStream<Row> trainStream =
                env.addSource(trainSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features");

        SourceFunction<Row> predictSource =
                new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
        DataStream<Row> predictStream =
                env.addSource(predictSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates an online K-means object and initializes its parameters and initial model data.
        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol("features")
                        .setPredictionCol("prediction")
                        .setGlobalBatchSize(6)
                        .setInitialModelData(
                                KMeansModelData.generateRandomModelData(tEnv, 2, 2, 0.0, 0));

        // Trains the online K-means Model.
        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);

        // Uses the online K-means Model for predictions.
        Table outputTable = onlineModel.transform(predictTable)[0];

        // Extracts and displays the results. As training data stream continuously triggers the
        // update of the internal k-means model data, clustering results of the same predict dataset
        // would change over time.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row1 = it.next();
            DenseVector features1 = (DenseVector) row1.getField(onlineKMeans.getFeaturesCol());
            Integer clusterId1 = (Integer) row1.getField(onlineKMeans.getPredictionCol());
            Row row2 = it.next();
            DenseVector features2 = (DenseVector) row2.getField(onlineKMeans.getFeaturesCol());
            Integer clusterId2 = (Integer) row2.getField(onlineKMeans.getPredictionCol());
            if (Objects.equals(clusterId1, clusterId2)) {
                System.out.printf("%s and %s are now in the same cluster.\n", features1, features2);
            } else {
                System.out.printf(
                        "%s and %s are now in different clusters.\n", features1, features2);
            }
        }
    }
}