This documentation is for an out-of-date version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.
KNN
K Nearest Neighbor (KNN) is a classification algorithm. Its basic assumption is that if most of the K nearest neighbors of a given sample belong to the same label, then the sample very likely belongs to that label as well. In other words, KNN predicts a sample's label by a majority vote among the K training samples closest to it.
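The voting rule described above can be sketched in plain Java. This is a hypothetical brute-force illustration, not Flink ML's implementation: it assumes Euclidean distance and an unweighted majority vote, and the class `SimpleKnn` and its methods are made up for this sketch. The sample points are a subset of the training rows used in the example at the end of this page.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical brute-force KNN classifier (illustration only, not Flink ML code). */
public class SimpleKnn {

    /** Returns the label that occurs most often among the k nearest training samples. */
    static double predict(double[][] features, double[] labels, double[] query, int k) {
        // Sort the indices of all training samples by squared Euclidean distance to the query.
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < features.length; i++) {
            order.add(i);
        }
        order.sort(Comparator.comparingDouble(i -> squaredDistance(features[i], query)));

        // Count the labels of the k closest samples.
        Map<Double, Integer> votes = new HashMap<>();
        for (int j = 0; j < Math.min(k, order.size()); j++) {
            votes.merge(labels[order.get(j)], 1, Integer::sum);
        }

        // Pick the label with the most votes (ties are resolved arbitrarily in this sketch).
        double best = Double.NaN;
        int bestCount = -1;
        for (Map.Entry<Double, Integer> e : votes.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    /** Squared Euclidean distance; the square root is unnecessary for ranking neighbors. */
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int d = 0; d < a.length; d++) {
            double diff = a[d] - b[d];
            sum += diff * diff;
        }
        return sum;
    }

    public static void main(String[] args) {
        // A subset of the training rows from the Flink ML example on this page.
        double[][] features = {
            {2.1, 3.1}, {2.3, 3.2}, {2.2, 3.2}, {2.8, 3.2},
            {2.4, 3.2}, {2.5, 3.2}, {2.5, 3.2}, {200.1, 300.1}
        };
        double[] labels = {1.0, 1.0, 1.0, 3.0, 5.0, 5.0, 5.0, 2.0};
        double[] query = {4.0, 4.1};

        System.out.println("k=1: " + predict(features, labels, query, 1)); // k=1: 3.0
        System.out.println("k=5: " + predict(features, labels, query, 5)); // k=5: 5.0
    }
}
```

With k = 1 the single closest point, (2.8, 3.2) labeled 3.0, decides the prediction; with k = 5 the three nearby points labeled 5.0 outvote it. This is the trade-off controlled by the K parameter: a larger K makes the prediction less sensitive to any individual training point.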
Input Columns
Param name | Type | Default | Description |
---|---|---|---|
featuresCol | Vector | "features" | Feature vector. |
labelCol | Integer | "label" | Label to predict. |
Output Columns
Param name | Type | Default | Description |
---|---|---|---|
predictionCol | Integer | "prediction" | Predicted label. |
Parameters
Below are the parameters required by KnnModel.
Key | Default | Type | Required | Description |
---|---|---|---|---|
K | 5 | Integer | no | The number of nearest neighbors. |
featuresCol | "features" | String | no | Features column name. |
predictionCol | "prediction" | String | no | Prediction column name. |
Knn needs the parameters above as well as the following one.
Key | Default | Type | Required | Description |
---|---|---|---|---|
labelCol | "label" | String | no | Label column name. |
Examples
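```java
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class KnnExample {
    public static void main(String[] args) {
        // Create the stream and table environments used below.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Training samples: (feature vector, label).
        List<Row> trainRows =
                new ArrayList<>(
                        Arrays.asList(
                                Row.of(Vectors.dense(2.0, 3.0), 1.0),
                                Row.of(Vectors.dense(2.1, 3.1), 1.0),
                                Row.of(Vectors.dense(200.1, 300.1), 2.0),
                                Row.of(Vectors.dense(200.2, 300.2), 2.0),
                                Row.of(Vectors.dense(200.3, 300.3), 2.0),
                                Row.of(Vectors.dense(200.4, 300.4), 2.0),
                                Row.of(Vectors.dense(200.4, 300.4), 2.0),
                                Row.of(Vectors.dense(200.6, 300.6), 2.0),
                                Row.of(Vectors.dense(2.1, 3.1), 1.0),
                                Row.of(Vectors.dense(2.1, 3.1), 1.0),
                                Row.of(Vectors.dense(2.1, 3.1), 1.0),
                                Row.of(Vectors.dense(2.1, 3.1), 1.0),
                                Row.of(Vectors.dense(2.3, 3.2), 1.0),
                                Row.of(Vectors.dense(2.3, 3.2), 1.0),
                                Row.of(Vectors.dense(2.8, 3.2), 3.0),
                                Row.of(Vectors.dense(300., 3.2), 4.0),
                                Row.of(Vectors.dense(2.2, 3.2), 1.0),
                                Row.of(Vectors.dense(2.4, 3.2), 5.0),
                                Row.of(Vectors.dense(2.5, 3.2), 5.0),
                                Row.of(Vectors.dense(2.5, 3.2), 5.0),
                                Row.of(Vectors.dense(2.1, 3.1), 1.0)));

        // Samples to classify; the model predicts from the features column.
        List<Row> predictRows =
                new ArrayList<>(
                        Arrays.asList(
                                Row.of(Vectors.dense(4.0, 4.1), 5.0),
                                Row.of(Vectors.dense(300, 42), 2.0)));

        // Both tables share one schema: a dense feature vector and a double label.
        Schema schema =
                Schema.newBuilder()
                        .column("f0", DataTypes.of(DenseVector.class))
                        .column("f1", DataTypes.DOUBLE())
                        .build();

        // Rename f0/f1 to the default column names expected by Knn and KnnModel.
        DataStream<Row> dataStream = env.fromCollection(trainRows);
        Table trainData = tEnv.fromDataStream(dataStream, schema).as("features", "label");
        DataStream<Row> predDataStream = env.fromCollection(predictRows);
        Table predictData = tEnv.fromDataStream(predDataStream, schema).as("features", "label");

        // Train with the default parameters (K = 5) and classify the prediction samples.
        Knn knn = new Knn();
        KnnModel knnModel = knn.fit(trainData);
        Table output = knnModel.transform(predictData)[0];

        // Print the resulting table, which includes the "prediction" column.
        output.execute().print();
    }
}
```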