K-means
K-means is a commonly-used clustering algorithm. It groups given data points into a predefined number of clusters.
Input Columns
Param name | Type | Default | Description |
---|---|---|---|
featuresCol | Vector | "features" | Feature vector |
Output Columns
Param name | Type | Default | Description |
---|---|---|---|
predictionCol | Integer | "prediction" | Predicted cluster center |
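The prediction column holds an integer cluster ID. As a rough illustration of how such an ID can be derived, the sketch below assigns a point to the closest of a set of cluster centers under Euclidean distance. It is plain Java and independent of Flink; the class `NearestCenterSketch`, its methods, and the center coordinates are hypothetical and are not part of the Flink ML API.

```java
public class NearestCenterSketch {
    // Squared Euclidean distance; the square root is not needed to compare distances.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Returns the index of the center closest to the given point.
    static int predict(double[][] centers, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centers.length; i++) {
            double dist = squaredDistance(centers[i], point);
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical centers, chosen only for illustration.
        double[][] centers = {{0.1, 0.1}, {9.2, 0.2}};
        System.out.println(predict(centers, new double[] {0.0, 0.3})); // prints 0
        System.out.println(predict(centers, new double[] {9.6, 0.0})); // prints 1
    }
}
```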
Parameters
The parameters below are used by KMeansModel.
Key | Default | Type | Required | Description |
---|---|---|---|---|
distanceMeasure | EuclideanDistanceMeasure.NAME | String | no | Distance measure. Supported values: EuclideanDistanceMeasure.NAME |
featuresCol | "features" | String | no | Features column name. |
predictionCol | "prediction" | String | no | Prediction column name. |
KMeans needs the parameters above as well as the ones below.
Key | Default | Type | Required | Description |
---|---|---|---|---|
initMode | "random" | String | no | The initialization algorithm. Supported options: "random". |
seed | null | Long | no | The random seed. |
maxIter | 20 | Integer | no | Maximum number of iterations. |
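To make the roles of the initialization mode, the seed, and the iteration limit concrete, here is a minimal sketch of a textbook K-means loop in plain Java. It is not the Flink ML implementation: the class `KMeansSketch` and its methods are hypothetical, "random" initialization is assumed to mean picking distinct input points as starting centers, and the loop always runs the full number of iterations rather than stopping early.

```java
import java.util.Arrays;
import java.util.Random;

public class KMeansSketch {
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Index of the center closest to point p.
    static int nearest(double[][] centers, double[] p) {
        int best = 0;
        for (int i = 1; i < centers.length; i++) {
            if (squaredDistance(centers[i], p) < squaredDistance(centers[best], p)) {
                best = i;
            }
        }
        return best;
    }

    // Runs maxIter assignment/update rounds; requires k <= points.length.
    static double[][] fit(double[][] points, int k, long seed, int maxIter) {
        Random random = new Random(seed); // the seed makes the initialization reproducible
        int n = points.length;
        int dim = points[0].length;

        // Assumed "random" initialization: choose k distinct input points
        // with a partial Fisher-Yates shuffle of the point indices.
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        double[][] centers = new double[k][];
        for (int i = 0; i < k; i++) {
            int j = i + random.nextInt(n - i);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
            centers[i] = points[idx[i]].clone();
        }

        for (int iter = 0; iter < maxIter; iter++) {
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            // Assignment step: attach each point to its nearest center.
            for (double[] p : points) {
                int c = nearest(centers, p);
                counts[c]++;
                for (int d = 0; d < dim; d++) {
                    sums[c][d] += p[d];
                }
            }
            // Update step: move each center to the mean of its points.
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue; // an empty cluster keeps its previous center
                }
                for (int d = 0; d < dim; d++) {
                    centers[c][d] = sums[c][d] / counts[c];
                }
            }
        }
        return centers;
    }

    public static void main(String[] args) {
        double[][] points = {
            {0.0, 0.0}, {0.0, 0.3}, {0.3, 0.0}, {9.0, 0.0}, {9.0, 0.6}, {9.6, 0.0}
        };
        System.out.println(Arrays.deepToString(fit(points, 2, 1L, 20)));
    }
}
```

In this sketch, calling `fit` twice with the same seed yields the same centers, while a different seed may start from different points.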
Examples
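import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class KMeansExample {
    public static void main(String[] args) {
        // Sets up the stream and table environments used below.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates the input data, used for both training and prediction.
        DataStream<DenseVector> inputStream = env.fromElements(
                Vectors.dense(0.0, 0.0),
                Vectors.dense(0.0, 0.3),
                Vectors.dense(0.3, 0.0),
                Vectors.dense(9.0, 0.0),
                Vectors.dense(9.0, 0.6),
                Vectors.dense(9.6, 0.0)
        );
        Table input = tEnv.fromDataStream(inputStream).as("features");

        // Creates a K-means object and initializes its parameters.
        KMeans kmeans = new KMeans()
                .setK(2)
                .setSeed(1L);

        // Trains the K-means model.
        KMeansModel model = kmeans.fit(input);

        // Uses the K-means model to make predictions.
        Table output = model.transform(input)[0];

        // Extracts and displays the prediction results.
        for (CloseableIterator<Row> it = output.execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector vector = (DenseVector) row.getField("features");
            int clusterId = (Integer) row.getField("prediction");
            System.out.println("Vector: " + vector + "\tCluster ID: " + clusterId);
        }
    }
}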