AgglomerativeClustering #
AgglomerativeClustering performs hierarchical clustering using a bottom-up approach. Each observation starts in its own cluster, and clusters are merged together one by one.
The output contains two tables. The first one assigns a cluster Id to each data point. The second one records the two clusters merged at each step. The data format of the merging information is (clusterId1, clusterId2, distance, sizeOfMergedCluster).
Input Columns #
| Param name  | Type   | Default    | Description     |
|-------------|--------|------------|-----------------|
| featuresCol | Vector | "features" | Feature vector. |
Output Columns #
| Param name    | Type    | Default      | Description           |
|---------------|---------|--------------|-----------------------|
| predictionCol | Integer | "prediction" | Predicted cluster Id. |
Parameters #
| Key               | Default                     | Type    | Required | Description                                                                     |
|-------------------|-----------------------------|---------|----------|---------------------------------------------------------------------------------|
| numClusters       | 2                           | Integer | no       | The max number of clusters to create.                                           |
| distanceThreshold | null                        | Double  | no       | Threshold to decide whether two clusters should be merged.                      |
| linkage           | "ward"                      | String  | no       | Criterion for computing distance between two clusters.                          |
| computeFullTree   | false                       | Boolean | no       | Whether to compute the full tree after convergence.                             |
| distanceMeasure   | "euclidean"                 | String  | no       | Distance measure.                                                               |
| featuresCol       | "features"                  | String  | no       | Features column name.                                                           |
| predictionCol     | "prediction"                | String  | no       | Prediction column name.                                                         |
| windows           | GlobalWindows.getInstance() | Windows | no       | Windowing strategy that determines how to create mini-batches from input data.  |
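As a hedged illustration of these parameters, the sketch below configures a non-default setup that stops merging once three clusters remain while still recording the full tree. The setter names setNumClusters and setComputeFullTree are assumed to mirror the parameter keys above (only setLinkage, setDistanceMeasure, and setPredictionCol appear in the examples that follow), and "average" is assumed to be an accepted linkage value besides the default "ward".

```java
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;

public class AgglomerativeClusteringConfigSketch {
    public static void main(String[] args) {
        // Sketch only: setter names below are assumed to mirror the parameter keys
        // in the table above; "average" is assumed to be a supported linkage value.
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setNumClusters(3)          // stop merging once three clusters remain
                        .setLinkage("average")      // average distance between members of two clusters
                        .setComputeFullTree(true);  // keep recording merges past the stopping point
    }
}
```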
Examples #
```java
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that creates an AgglomerativeClustering instance and uses it for clustering. */
public class AgglomerativeClusteringExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data.
        DataStream<DenseVector> inputStream =
                env.fromElements(
                        Vectors.dense(1, 1),
                        Vectors.dense(1, 4),
                        Vectors.dense(1, 0),
                        Vectors.dense(4, 1.5),
                        Vectors.dense(4, 4),
                        Vectors.dense(4, 0));
        Table inputTable = tEnv.fromDataStream(inputStream).as("features");

        // Creates an AgglomerativeClustering object and initializes its parameters.
        AgglomerativeClustering agglomerativeClustering =
                new AgglomerativeClustering()
                        .setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)
                        .setDistanceMeasure(EuclideanDistanceMeasure.NAME)
                        .setPredictionCol("prediction");

        // Uses the AgglomerativeClustering object for clustering.
        Table[] outputs = agglomerativeClustering.transform(inputTable);

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputs[0].execute().collect(); it.hasNext(); ) {
            Row row = it.next();
            DenseVector features =
                    (DenseVector) row.getField(agglomerativeClustering.getFeaturesCol());
            int clusterId = (Integer) row.getField(agglomerativeClustering.getPredictionCol());
            System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
        }
    }
}
```
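The second output table, outputs[1], holds the merge information described at the top of this page. The following hedged sketch continues the Java example above and reads that table; the field positions are assumed to follow the (clusterId1, clusterId2, distance, sizeOfMergedCluster) format, since the column names are not listed on this page.

```java
// Sketch only: iterates over the second output table produced by transform() above.
// Field positions are assumed to match the documented merge-info format
// (clusterId1, clusterId2, distance, sizeOfMergedCluster).
for (CloseableIterator<Row> it = outputs[1].execute().collect(); it.hasNext(); ) {
    Row mergeInfo = it.next();
    System.out.printf(
            "Merged clusters %s and %s at distance %s, merged cluster size %s%n",
            mergeInfo.getField(0),
            mergeInfo.getField(1),
            mergeInfo.getField(2),
            mergeInfo.getField(3));
}
```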
```python
# Simple program that creates an AgglomerativeClustering instance and uses it for clustering.
from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.clustering.agglomerativeclustering import AgglomerativeClustering
from pyflink.table import StreamTableEnvironment
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram

# Creates a new StreamExecutionEnvironment.
env = StreamExecutionEnvironment.get_execution_environment()

# Creates a StreamTableEnvironment.
t_env = StreamTableEnvironment.create(env)

# Generates input data.
input_data = t_env.from_data_stream(
    env.from_collection([
        (Vectors.dense([1, 1]),),
        (Vectors.dense([1, 4]),),
        (Vectors.dense([1, 0]),),
        (Vectors.dense([4, 1.5]),),
        (Vectors.dense([4, 4]),),
        (Vectors.dense([4, 0]),),
    ],
        type_info=Types.ROW_NAMED(
            ['features'],
            [DenseVectorTypeInfo()])))

# Creates an AgglomerativeClustering object and initializes its parameters.
agglomerative_clustering = AgglomerativeClustering() \
    .set_linkage('ward') \
    .set_distance_measure('euclidean') \
    .set_prediction_col('prediction')

# Uses the AgglomerativeClustering object for clustering.
outputs = agglomerative_clustering.transform(input_data)

# Extracts and displays the clustering results.
field_names = outputs[0].get_schema().get_field_names()
for result in t_env.to_data_stream(outputs[0]).execute_and_collect():
    features = result[field_names.index(agglomerative_clustering.features_col)]
    cluster_id = result[field_names.index(agglomerative_clustering.prediction_col)]
    print('Features: ' + str(features) + '\tCluster ID: ' + str(cluster_id))

# Visualizes the merge info.
merge_info = [result for result in
              t_env.to_data_stream(outputs[1]).execute_and_collect()]
plt.title("Agglomerative Clustering Dendrogram")
dendrogram(merge_info)
plt.xlabel("Index of data point.")
plt.ylabel("Distances between merged clusters.")
plt.show()
```