Naive Bayes

Naive Bayes #

Naive Bayes is a multiclass classifier. Based on Bayes’ theorem, it assumes that there is strong (naive) independence between every pair of features.

Input Columns #

Param name Type Default Description
featuresCol Vector "features" Feature vector
labelCol Integer "label" Label to predict

Output Columns #

Param name Type Default Description
predictionCol Integer "prediction" Predicted label

Parameters #

Below are the parameters required by NaiveBayesModel.

Key Default Type Required Description
modelType "multinomial" String no The model type. Supported values: “multinomial”
featuresCol "features" String no Features column name.
predictionCol "prediction" String no Prediction column name.

NaiveBayes needs the parameters above as well as the ones below.

Key Default Type Required Description
labelCol "label" String no Label column name.
smoothing 1.0 Double no The smoothing parameter.

Examples #

import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

/** Simple program that trains a NaiveBayes model and uses it for classification. */
public class NaiveBayesExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input training and prediction data.
        DataStream<Row> trainStream =
                        Row.of(Vectors.dense(0, 0.), 11),
                        Row.of(Vectors.dense(1, 0), 10),
                        Row.of(Vectors.dense(1, 1.), 10));
        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");

        DataStream<Row> predictStream =
                        Row.of(Vectors.dense(0, 1.)),
                        Row.of(Vectors.dense(0, 0.)),
                        Row.of(Vectors.dense(1, 0)),
                        Row.of(Vectors.dense(1, 1.)));
        Table predictTable = tEnv.fromDataStream(predictStream).as("features");

        // Creates a NaiveBayes object and initializes its parameters.
        NaiveBayes naiveBayes =
                new NaiveBayes()

        // Trains the NaiveBayes Model.
        NaiveBayesModel naiveBayesModel =;

        // Uses the NaiveBayes Model for predictions.
        Table outputTable = naiveBayesModel.transform(predictTable)[0];

        // Extracts and displays the results.
        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
            Row row =;
            DenseVector features = (DenseVector) row.getField(naiveBayes.getFeaturesCol());
            double predictionResult = (Double) row.getField(naiveBayes.getPredictionCol());
            System.out.printf("Features: %s \tPrediction Result: %s\n", features, predictionResult);

# Simple program that trains a NaiveBayes model and uses it for classification.
#
# Before executing this program, please make sure you have followed Flink ML's
# quick start guideline to set up Flink ML and Flink environment. The guideline
# can be found in the "Try Flink ML > Quick Start" section of the Flink ML
# documentation.

from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
# NOTE(review): module paths below follow recent Flink ML releases; older
# releases expose these under pyflink.ml.core.linalg and
# pyflink.ml.lib.classification.naivebayes -- confirm against your version.
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.classification.naivebayes import NaiveBayes
from pyflink.table import StreamTableEnvironment

# create a new StreamExecutionEnvironment
env = StreamExecutionEnvironment.get_execution_environment()

# create a StreamTableEnvironment
t_env = StreamTableEnvironment.create(env)

# generate input training and prediction data
# each training row is (feature vector, label)
train_table = t_env.from_data_stream(
    env.from_collection([
        (Vectors.dense([0, 0.]), 11.),
        (Vectors.dense([1, 0]), 10.),
        (Vectors.dense([1, 1.]), 10.),
    ],
        type_info=Types.ROW_NAMED(
            ['features', 'label'],
            [DenseVectorTypeInfo(), Types.DOUBLE()])))

# prediction rows carry only the feature vector
predict_table = t_env.from_data_stream(
    env.from_collection([
        (Vectors.dense([0, 1.]),),
        (Vectors.dense([0, 0.]),),
        (Vectors.dense([1, 0]),),
        (Vectors.dense([1, 1.]),),
    ],
        type_info=Types.ROW_NAMED(
            ['features'],
            [DenseVectorTypeInfo()])))

# create a naive bayes object and initialize its parameters
# (column names must match the field names declared for the tables above)
naive_bayes = NaiveBayes() \
    .set_smoothing(1.0) \
    .set_features_col('features') \
    .set_label_col('label') \
    .set_prediction_col('prediction') \
    .set_model_type('multinomial')

# train the naive bayes model
model = naive_bayes.fit(train_table)

# use the naive bayes model for predictions
output = model.transform(predict_table)[0]

# extract and display the results
# rows are indexed by position, so look up each column's index by name
field_names = output.get_schema().get_field_names()
for result in t_env.to_data_stream(output).execute_and_collect():
    features = result[field_names.index(naive_bayes.get_features_col())]
    prediction_result = result[field_names.index(naive_bayes.get_prediction_col())]
    print('Features: ' + str(features) + ' \tPrediction Result: ' + str(prediction_result))