Naive Bayes
This documentation is for an out-of-date version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.

Naive Bayes #

Naive Bayes is a multiclass classifier. Based on Bayes’ theorem, it assumes that there is strong (naive) independence between every pair of features.

Input Columns #

Param name Type Default Description
featuresCol Vector "features" Feature vector
labelCol Integer "label" Label to predict

Output Columns #

Param name Type Default Description
predictionCol Integer "prediction" Predicted label

Parameters #

Below are parameters required by NaiveBayesModel.

Key Default Type Required Description
modelType "multinomial" String no The model type. Supported values: “multinomial”
featuresCol "features" String no Features column name.
predictionCol "prediction" String no Prediction column name.

NaiveBayes needs parameters above and also below.

Key Default Type Required Description
labelCol "label" String no Label column name.
smoothing 1.0 Double no The smoothing parameter.

Examples #

import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
import org.apache.flink.ml.linalg.Vectors;

List<Row> trainData =
  Arrays.asList(
  Row.of(Vectors.dense(0, 0.), 11),
  Row.of(Vectors.dense(1, 0), 10),
  Row.of(Vectors.dense(1, 1.), 10));

Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");

List<Row> predictData =
  Arrays.asList(
  Row.of(Vectors.dense(0, 1.)),
  Row.of(Vectors.dense(0, 0.)),
  Row.of(Vectors.dense(1, 0)),
  Row.of(Vectors.dense(1, 1.)));

Table predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features");

NaiveBayes estimator =
  new NaiveBayes()
  .setSmoothing(1.0)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setModelType("multinomial");

NaiveBayesModel model = estimator.fit(trainTable);
Table outputTable = model.transform(predictTable)[0];

outputTable.execute().print();