K Nearest Neighbor (KNN) is a classification algorithm. The basic assumption of
KNN is that if most of the K nearest neighbors of the provided sample belong to
the same label, then it is highly probable that the provided sample also belongs
to that label.
import org.apache.flink.ml.classification.knn.Knn;
import org.apache.flink.ml.classification.knn.KnnModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
// Schema shared by both inputs: a dense-vector column followed by a double column.
Schema schema =
        Schema.newBuilder()
                .column("f0", DataTypes.of(DenseVector.class))
                .column("f1", DataTypes.DOUBLE())
                .build();

// Training samples: each row pairs a two-element dense vector with a double label.
List<Row> trainRows =
        new ArrayList<>(
                Arrays.asList(
                        Row.of(Vectors.dense(2.0, 3.0), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(200.1, 300.1), 2.0),
                        Row.of(Vectors.dense(200.2, 300.2), 2.0),
                        Row.of(Vectors.dense(200.3, 300.3), 2.0),
                        Row.of(Vectors.dense(200.4, 300.4), 2.0),
                        Row.of(Vectors.dense(200.4, 300.4), 2.0),
                        Row.of(Vectors.dense(200.6, 300.6), 2.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0),
                        Row.of(Vectors.dense(2.3, 3.2), 1.0),
                        Row.of(Vectors.dense(2.3, 3.2), 1.0),
                        Row.of(Vectors.dense(2.8, 3.2), 3.0),
                        Row.of(Vectors.dense(300.0, 3.2), 4.0),
                        Row.of(Vectors.dense(2.2, 3.2), 1.0),
                        Row.of(Vectors.dense(2.4, 3.2), 5.0),
                        Row.of(Vectors.dense(2.5, 3.2), 5.0),
                        Row.of(Vectors.dense(2.5, 3.2), 5.0),
                        Row.of(Vectors.dense(2.1, 3.1), 1.0)));

// Samples to run through the fitted model; same row layout as the training data.
List<Row> predictRows =
        new ArrayList<>(
                Arrays.asList(
                        Row.of(Vectors.dense(4.0, 4.1), 5.0),
                        Row.of(Vectors.dense(300.0, 42.0), 2.0)));

// Turn each collection into a stream, then into a table whose columns are
// renamed from f0/f1 to "features"/"label".
Table trainData =
        tEnv.fromDataStream(env.fromCollection(trainRows), schema).as("features", "label");
Table predictData =
        tEnv.fromDataStream(env.fromCollection(predictRows), schema).as("features", "label");

// Fit a Knn estimator on the training table; no parameters are set explicitly.
KnnModel knnModel = new Knn().fit(trainData);

// transform(...) returns an array of tables; the first entry is the one printed.
Table output = knnModel.transform(predictData)[0];
output.execute().print();