OnlineStandardScaler

OnlineStandardScaler #

An Estimator that implements the online standard scaling algorithm, i.e., the online version of StandardScaler.

OnlineStandardScaler splits the input data according to the user-specified window strategy. For each window, it computes the mean and standard deviation using all data seen so far (i.e., not only the data in the current window, but also the historical data). The model data generated by OnlineStandardScaler is a model stream, with one model data record generated for each window.

During the inference phase (i.e., when using OnlineStandardScalerModel for prediction), users can output the version of the model that is used to predict each data point. Moreover,

  • When the training data and test data both contain event time, users can specify the maximum difference between the timestamps of the input and model data, which ensures that a relatively fresh model is used for prediction.
  • Otherwise, the prediction process always uses the current model data for prediction.

Input Columns #

Param name Type Default Description
inputCol Vector "input" Features to be scaled.

Output Columns #

Param name Type Default Description
outputCol Vector "output" Scaled features.
modelVersionCol String "version" The name of the column which contains the version of the model data that the input data is predicted with. The version should be a 64-bit integer.

Parameters #

Below are the parameters required by OnlineStandardScalerModel.

Key Default Type Required Description
inputCol "input" String no Input column name.
outputCol "output" String no Output column name.
withMean false Boolean no Whether to center the data with the mean before scaling.
withStd true Boolean no Whether to scale the data with the standard deviation.
modelVersionCol "version" String no The name of the column which contains the version of the model data that the input data is predicted with. The version should be a 64-bit integer.
maxAllowedModelDelayMs 0L Long no The maximum difference allowed between the timestamps of the input record and the model data that is used to predict that input record. This param only works when the input contains event time.

OnlineStandardScaler needs the parameters above as well as the parameters below.

Key Default Type Required Description
windows GlobalWindows.getInstance() Windows no Windowing strategy that determines how to create mini-batches from input data.

Examples #

import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.Arrays;
import java.util.List;

/** Simple program that trains an OnlineStandardScaler model and uses it for feature engineering. */
public class OnlineStandardScalerExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data. The first field of each row is used below as its event time
        // (in milliseconds); the second field is the feature vector to be scaled.
        List<Row> inputData =
                Arrays.asList(
                        Row.of(0L, Vectors.dense(-2.5, 9, 1)),
                        Row.of(1000L, Vectors.dense(1.4, -5, 1)),
                        Row.of(2000L, Vectors.dense(2, -1, -2)),
                        Row.of(6000L, Vectors.dense(0.7, 3, 1)),
                        Row.of(7000L, Vectors.dense(0, 1, 1)),
                        Row.of(8000L, Vectors.dense(0.5, 0, -2)),
                        Row.of(9000L, Vectors.dense(0.4, 1, 1)),
                        Row.of(10000L, Vectors.dense(0.3, 2, 1)),
                        Row.of(11000L, Vectors.dense(0.5, 1, -2)));

        DataStream<Row> inputStream = env.fromCollection(inputData);

        // Assigns each record's timestamp from its first field; the timestamps in the sample
        // data increase monotonically, so a monotonous watermark strategy is sufficient.
        DataStream<Row> inputStreamWithEventTime =
                inputStream.assignTimestampsAndWatermarks(
                        WatermarkStrategy.<Row>forMonotonousTimestamps()
                                .withTimestampAssigner(
                                        (SerializableTimestampAssigner<Row>)
                                                (element, recordTimestamp) ->
                                                        element.getFieldAs(0)));

        // Exposes the record timestamps as the "rowtime" column and propagates the watermarks
        // generated above, so that event-time windows can be applied to the table.
        Table inputTable =
                tEnv.fromDataStream(
                                inputStreamWithEventTime,
                                Schema.newBuilder()
                                        .column("f0", DataTypes.BIGINT())
                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
                                        .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
                                        .watermark("rowtime", "SOURCE_WATERMARK()")
                                        .build())
                        .as("id", "input");

        // Creates an OnlineStandardScaler object that emits one model per 3-second
        // event-time tumbling window.
        long windowSizeMs = 3000;
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler()
                        .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs)));

        // Trains the OnlineStandardScaler Model.
        OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable);

        // Uses the OnlineStandardScaler Model for predictions.
        Table outputTable = model.transform(inputTable)[0];

        // Extracts and displays the results. The iterator is closed via try-with-resources so
        // that the resources held by the collect operation are released even if printing fails.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector inputValue =
                        (DenseVector) row.getField(onlineStandardScaler.getInputCol());
                DenseVector outputValue =
                        (DenseVector) row.getField(onlineStandardScaler.getOutputCol());
                long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol());
                System.out.printf(
                        "Input Value: %s\tOutput Value: %s\tModel Version: %s%n",
                        inputValue, outputValue, modelVersion);
            }
        }
    }
}

# Simple program that trains an OnlineStandardScaler model and uses it for feature
# engineering.

from pyflink.common import Types
from pyflink.common.time import Time, Instant
from pyflink.java_gateway import get_gateway
from pyflink.table import Schema
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment
from pyflink.table.expressions import col

from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.feature.onlinestandardscaler import OnlineStandardScaler
from pyflink.ml.common.window import EventTimeTumblingWindows

# Creates a new StreamExecutionEnvironment.
env = StreamExecutionEnvironment.get_execution_environment()

# Creates a StreamTableEnvironment.
t_env = StreamTableEnvironment.create(env)

# Generates input data.
# Obtains the serializer string of the Java DenseVectorSerializer so that the "input" column
# can be declared below as a RAW type holding org.apache.flink.ml.linalg.DenseVector values.
dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType(
    get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(),
    get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer()
).getSerializerString()

# The builder chain is wrapped in parentheses: without them (or trailing backslashes) each
# ".column(...)" line would be parsed as a separate, invalid statement.
schema = (
    Schema.new_builder()
    .column("ts", "TIMESTAMP_LTZ(3)")
    .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')"
            .format(serializer=dense_vector_serializer))
    # Watermarks trail the observed "ts" values by one second.
    .watermark("ts", "ts - INTERVAL '1' SECOND")
    .build()
)

input_data = t_env.from_data_stream(
    env.from_collection([
        (Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1),),
        (Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1),),
        (Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2),),
        (Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1),),
        (Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1),),
        (Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2),),
        (Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1),),
        (Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1),),
        (Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2),)
    ],
        type_info=Types.ROW_NAMED(
            ['ts', 'input'],
            [Types.INSTANT(), DenseVectorTypeInfo()])),
    schema)

# Creates an online standard-scaler object and initializes its parameters: one model is
# produced per 3-second event-time tumbling window.
standard_scaler = (
    OnlineStandardScaler()
    .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(3000)))
    .set_max_allowed_model_delay_ms(0)
)

# Trains the online standard-scaler model.
model = standard_scaler.fit(input_data)

# Uses the online standard-scaler model for predictions.
output = model.transform(input_data)[0]

# Extracts and displays the results. Column names come from the scaler's parameters so that
# the selection stays consistent with the lookups below.
output = output.select(
    col(standard_scaler.get_input_col()),
    col(standard_scaler.get_output_col()),
    col(standard_scaler.get_model_version_col()))
field_names = output.get_schema().get_field_names()

for result in t_env.to_data_stream(output).execute_and_collect():
    input_value = result[field_names.index(standard_scaler.get_input_col())]
    output_value = result[field_names.index(standard_scaler.get_output_col())]
    model_version = result[field_names.index(standard_scaler.get_model_version_col())]
    print('Input Value: ' + str(input_value) + ' \tOutput Value: ' + str(output_value) +
          '\tModel Version: ' + str(model_version))