import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
/** Simple program that creates an AgglomerativeClustering instance and uses it for clustering. */
public class AgglomerativeClusteringExample {
public static void main(String[] args) {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
// Generates input data.
DataStream<DenseVector> inputStream =
env.fromElements(
Vectors.dense(1, 1),
Vectors.dense(1, 4),
Vectors.dense(1, 0),
Vectors.dense(4, 1.5),
Vectors.dense(4, 4),
Vectors.dense(4, 0));
Table inputTable = tEnv.fromDataStream(inputStream).as("features");
// Creates an AgglomerativeClustering object and initializes its parameters.
AgglomerativeClustering agglomerativeClustering =
new AgglomerativeClustering()
.setLinkage(AgglomerativeClusteringParams.LINKAGE_WARD)
.setDistanceMeasure(EuclideanDistanceMeasure.NAME)
.setPredictionCol("prediction");
// Uses the AgglomerativeClustering object for clustering.
Table[] outputs = agglomerativeClustering.transform(inputTable);
// Extracts and displays the results.
for (CloseableIterator<Row> it = outputs[0].execute().collect(); it.hasNext(); ) {
Row row = it.next();
DenseVector features =
(DenseVector) row.getField(agglomerativeClustering.getFeaturesCol());
int clusterId = (Integer) row.getField(agglomerativeClustering.getPredictionCol());
System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
}
}
}
# Simple program that creates an agglomerativeclustering instance and uses it for clustering.
from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.clustering.agglomerativeclustering import AgglomerativeClustering
from pyflink.table import StreamTableEnvironment
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
# Creates a new StreamExecutionEnvironment.
env = StreamExecutionEnvironment.get_execution_environment()
# Creates a StreamTableEnvironment.
t_env = StreamTableEnvironment.create(env)
# Generates input data.
input_data = t_env.from_data_stream(
env.from_collection([
(Vectors.dense([1, 1]),),
(Vectors.dense([1, 4]),),
(Vectors.dense([1, 0]),),
(Vectors.dense([4, 1.5]),),
(Vectors.dense([4, 4]),),
(Vectors.dense([4, 0]),),
],
type_info=Types.ROW_NAMED(
['features'],
[DenseVectorTypeInfo()])))
# Creates an AgglomerativeClustering object and initializes its parameters.
agglomerative_clustering = AgglomerativeClustering() \
.set_linkage('ward') \
.set_distance_measure('euclidean') \
.set_prediction_col('prediction')
# Uses the AgglomerativeClustering for clustering.
outputs = agglomerative_clustering.transform(input_data)
# Extracts and display the clustering results.
field_names = outputs[0].get_schema().get_field_names()
for result in t_env.to_data_stream(outputs[0]).execute_and_collect():
features = result[field_names.index(agglomerative_clustering.features_col)]
cluster_id = result[field_names.index(agglomerative_clustering.prediction_col)]
print('Features: ' + str(features) + '\tCluster ID: ' + str(cluster_id))
# Visualizes the merge info.
merge_info = [result for result in
t_env.to_data_stream(outputs[1]).execute_and_collect()]
plt.title("Agglomerative Clustering Dendrogram")
dendrogram(merge_info)
plt.xlabel("Index of data point.")
plt.ylabel("Distances between merged clusters.")
plt.show()