KNN
This documentation is for an out-of-date version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.

KNN #

K Nearest Neighbors (KNN) is a classification algorithm. The basic assumption of KNN is that if most of the K nearest neighbors of a given sample belong to the same label, then the given sample very likely belongs to that label as well.

Input Columns #

Param name Type Default Description
featuresCol Vector "features" Feature vector
labelCol Integer "label" Label to predict

Output Columns #

Param name Type Default Description
predictionCol Integer "prediction" Predicted label

Parameters #

Below are the parameters required by KnnModel.

Key Default Type Required Description
K 5 Integer no The number of nearest neighbors.
featuresCol "features" String no Features column name.
predictionCol "prediction" String no Prediction column name.

Knn needs the parameters above, as well as the parameters below.

Key Default Type Required Description
labelCol "label" String no Label column name.

Examples #

import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Create the execution environment and the table environment that the rest of
// this example uses to build the input tables and run the job.
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

// Training samples: each row is (feature vector, label). Labels are written as
// Double values because the label column (f1) is declared as DataTypes.DOUBLE()
// in the schema below.
List<Row> trainRows =
    new ArrayList<>(
        Arrays.asList(
            Row.of(Vectors.dense(2.0, 3.0), 1.0),
            Row.of(Vectors.dense(2.1, 3.1), 1.0),
            Row.of(Vectors.dense(200.1, 300.1), 2.0),
            Row.of(Vectors.dense(200.2, 300.2), 2.0),
            Row.of(Vectors.dense(200.3, 300.3), 2.0),
            Row.of(Vectors.dense(200.4, 300.4), 2.0),
            Row.of(Vectors.dense(200.4, 300.4), 2.0),
            Row.of(Vectors.dense(200.6, 300.6), 2.0),
            Row.of(Vectors.dense(2.1, 3.1), 1.0),
            Row.of(Vectors.dense(2.1, 3.1), 1.0),
            Row.of(Vectors.dense(2.1, 3.1), 1.0),
            Row.of(Vectors.dense(2.1, 3.1), 1.0),
            Row.of(Vectors.dense(2.3, 3.2), 1.0),
            Row.of(Vectors.dense(2.3, 3.2), 1.0),
            Row.of(Vectors.dense(2.8, 3.2), 3.0),
            Row.of(Vectors.dense(300., 3.2), 4.0),
            Row.of(Vectors.dense(2.2, 3.2), 1.0),
            Row.of(Vectors.dense(2.4, 3.2), 5.0),
            Row.of(Vectors.dense(2.5, 3.2), 5.0),
            Row.of(Vectors.dense(2.5, 3.2), 5.0),
            Row.of(Vectors.dense(2.1, 3.1), 1.0)));

// Samples to classify. The second field fills the label column so that these
// rows fit the same schema as the training data.
List<Row> predictRows =
    new ArrayList<>(
        Arrays.asList(
            Row.of(Vectors.dense(4.0, 4.1), 5.0),
            Row.of(Vectors.dense(300, 42), 2.0)));

// Shared schema for both inputs: a dense feature vector and a Double label.
Schema schema =
    Schema.newBuilder()
        .column("f0", DataTypes.of(DenseVector.class))
        .column("f1", DataTypes.DOUBLE())
        .build();

// Convert both collections into tables, renaming the columns to match the
// default featuresCol ("features") and labelCol ("label") parameter values.
DataStream<Row> dataStream = env.fromCollection(trainRows);
Table trainData = tEnv.fromDataStream(dataStream, schema).as("features", "label");
DataStream<Row> predDataStream = env.fromCollection(predictRows);
Table predictData = tEnv.fromDataStream(predDataStream, schema).as("features", "label");

// Train the model with default parameters (K = 5), then apply it.
Knn knn = new Knn();
KnnModel knnModel = knn.fit(trainData);
// transform returns an array of tables; the first one is the result table.
Table output = knnModel.transform(predictData)[0];

// Execute the job and print the result rows.
output.execute().print();