This documentation is for an out-of-date version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.

K-means

K-means is a commonly-used clustering algorithm. It groups given data points into a predefined number of clusters.

Input Columns

| Param name  | Type   | Default    | Description    |
|-------------|--------|------------|----------------|
| featuresCol | Vector | "features" | Feature vector |

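Output Columns

| Param name    | Type    | Default      | Description                           |
|---------------|---------|--------------|---------------------------------------|
| predictionCol | Integer | "prediction" | Index of the predicted cluster center |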

Parameters

Below are parameters required by KMeansModel.

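| Key             | Default                       | Type   | Required | Description                                                       |
|-----------------|-------------------------------|--------|----------|-------------------------------------------------------------------|
| distanceMeasure | EuclideanDistanceMeasure.NAME | String | no       | Distance measure. Supported values: EuclideanDistanceMeasure.NAME |
| featuresCol     | "features"                    | String | no       | Features column name.                                             |
| predictionCol   | "prediction"                  | String | no       | Prediction column name.                                           |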
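
To illustrate what a Euclidean distance measure implies for prediction, here is a minimal plain-Java sketch that assigns a point to its nearest centroid. It does not use Flink ML and is not the library's implementation; the class name `NearestCentroidSketch`, its methods, and the centroid values are hypothetical and chosen only for illustration.

```java
import java.util.Arrays;

// Hypothetical illustration class; not part of Flink ML.
public class NearestCentroidSketch {

    // Euclidean distance between two vectors of equal length.
    public static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Returns the index of the centroid closest to the given point.
    public static int predict(double[] point, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = euclidean(point, centroids[c]);
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Assumed centroids, made up for this example.
        double[][] centroids = {{0.1, 0.1}, {9.2, 0.2}};
        double[][] points = {{0.0, 0.3}, {9.0, 0.6}};
        for (double[] p : points) {
            System.out.println(Arrays.toString(p) + " -> cluster " + predict(p, centroids));
        }
    }
}
```

With these assumed centroids, the point (0.0, 0.3) falls into cluster 0 and (9.0, 0.6) into cluster 1. The returned index plays the role of the Integer value in the prediction column.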

KMeans requires the parameters above as well as the following.

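| Key      | Default  | Type    | Required | Description                                                |
|----------|----------|---------|----------|------------------------------------------------------------|
| initMode | "random" | String  | no       | The initialization algorithm. Supported options: "random". |
| seed     | null     | Long    | no       | The random seed.                                           |
| maxIter  | 20       | Integer | no       | Maximum number of iterations.                              |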
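
To show how an initialization mode, a seed, and an iteration limit typically interact during K-means training, here is a minimal single-machine sketch in plain Java. It is an assumption-laden illustration, not Flink ML's distributed implementation: the class `KMeansTrainingSketch`, its methods, the choice to sample initial centroids from the input points, and the handling of empty clusters are all hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration class; not part of Flink ML.
public class KMeansTrainingSketch {

    // "random" initialization (assumed here to mean: pick k distinct input points).
    // Requires k <= points.length.
    public static double[][] randomInit(double[][] points, int k, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < points.length; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        double[][] centroids = new double[k][];
        for (int c = 0; c < k; c++) {
            centroids[c] = points[indices.get(c)].clone();
        }
        return centroids;
    }

    // Index of the centroid with the smallest squared Euclidean distance to p.
    static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0.0;
            for (int d = 0; d < p.length; d++) {
                double diff = p[d] - centroids[c][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    // Runs exactly maxIter rounds of "assign points, then recompute centroids".
    public static double[][] fit(double[][] points, int k, long seed, int maxIter) {
        double[][] centroids = randomInit(points, k, seed);
        int dim = points[0].length;
        for (int iter = 0; iter < maxIter; iter++) {
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (double[] p : points) {
                int c = nearest(p, centroids);
                counts[c]++;
                for (int d = 0; d < dim; d++) {
                    sums[c][d] += p[d];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue; // Assumption: an empty cluster keeps its previous centroid.
                }
                for (int d = 0; d < dim; d++) {
                    centroids[c][d] = sums[c][d] / counts[c];
                }
            }
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[][] points = {
            {0.0, 0.0}, {0.0, 0.3}, {0.3, 0.0},
            {9.0, 0.0}, {9.0, 0.6}, {9.6, 0.0}
        };
        double[][] centroids = fit(points, 2, 1L, 20);
        System.out.println(Arrays.deepToString(centroids));
    }
}
```

In this sketch, fixing the seed makes the run reproducible: calling `fit` twice with the same data, `k`, seed, and `maxIter` yields identical centroids, while a different seed may start from different points.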

Examples

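import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class KMeansExample {
  public static void main(String[] args) throws Exception {
    // Sets up the stream and table environments used below.
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

    // Generates the data used for both training and prediction.
    DataStream<DenseVector> inputStream = env.fromElements(
      Vectors.dense(0.0, 0.0),
      Vectors.dense(0.0, 0.3),
      Vectors.dense(0.3, 0.0),
      Vectors.dense(9.0, 0.0),
      Vectors.dense(9.0, 0.6),
      Vectors.dense(9.6, 0.0)
    );
    Table input = tEnv.fromDataStream(inputStream).as("features");

    // Creates a K-means object and initializes its parameters.
    KMeans kmeans = new KMeans()
      .setK(2)
      .setSeed(1L);

    // Trains the K-means model.
    KMeansModel model = kmeans.fit(input);

    // Uses the K-means model to make predictions.
    Table output = model.transform(input)[0];

    // Extracts and displays the prediction results.
    for (CloseableIterator<Row> it = output.execute().collect(); it.hasNext(); ) {
      Row row = it.next();
      DenseVector vector = (DenseVector) row.getField("features");
      int clusterId = (Integer) row.getField("prediction");
      System.out.println("Vector: " + vector + "\tCluster ID: " + clusterId);
    }
  }
}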