This documentation is for an unreleased version of Apache Flink Machine Learning Library. We recommend you use the latest stable version.
OnlineStandardScaler #
An Estimator that implements the online standard scaling algorithm, i.e., the online version of StandardScaler.
OnlineStandardScaler splits the input data according to the user-specified window strategy. For each window, it computes the mean and standard deviation using all data seen so far (i.e., not only the data in the current window but also the historical data). The model data generated by OnlineStandardScaler is a model stream that contains one set of model data for each window.
During the inference phase (i.e., when using OnlineStandardScalerModel for prediction), users can output the version of the model data that is used to predict each data point. Moreover:
- When both the training data and the test data contain event time, users can specify the maximum allowed difference between the timestamps of the input data and the model data (see `maxAllowedModelDelayMs`), which ensures that a relatively fresh model is used for prediction.
- Otherwise, the prediction process always uses the current model data.
Input Columns #
Param name | Type | Default | Description |
---|---|---|---|
inputCol | Vector | "input" | Features to be scaled. |
Output Columns #
Param name | Type | Default | Description |
---|---|---|---|
outputCol | Vector | "output" | Scaled features. |
modelVersionCol | Long | "version" | The name of the column which contains the version of the model data that the input data is predicted with. The version should be a 64-bit integer. |
Parameters #
Below are the parameters required by OnlineStandardScalerModel.
Key | Default | Type | Required | Description |
---|---|---|---|---|
inputCol | "input" | String | no | Input column name. |
outputCol | "output" | String | no | Output column name. |
withMean | false | Boolean | no | Whether to center the data with the mean before scaling. |
withStd | true | Boolean | no | Whether to scale the data with the standard deviation. |
modelVersionCol | "version" | String | no | The name of the column which contains the version of the model data that the input data is predicted with. The version should be a 64-bit integer. |
maxAllowedModelDelayMs | 0L | Long | no | The maximum difference allowed between the timestamps of the input record and the model data that is used to predict that input record. This param only takes effect when the input contains event time. |
OnlineStandardScaler needs the parameters above and also the parameters below.
Key | Default | Type | Required | Description |
---|---|---|---|---|
windows | GlobalWindows.getInstance() | Windows | no | Windowing strategy that determines how to create mini-batches from input data. |
Examples #
import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import java.util.Arrays;
import java.util.List;
/** Simple program that trains an OnlineStandardScaler model and uses it for feature engineering. */
public class OnlineStandardScalerExample {
    /**
     * Builds an event-time input table and fits an {@link OnlineStandardScaler} on it with
     * event-time tumbling windows. It then prints each input vector together with its scaled
     * output and the version of the model data that produced it.
     *
     * @param args unused
     * @throws Exception if executing the job, reading its results, or closing the result iterator
     *     fails
     */
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Generates input data. The first field of each row is its event time in milliseconds.
        List<Row> inputData =
                Arrays.asList(
                        Row.of(0L, Vectors.dense(-2.5, 9, 1)),
                        Row.of(1000L, Vectors.dense(1.4, -5, 1)),
                        Row.of(2000L, Vectors.dense(2, -1, -2)),
                        Row.of(6000L, Vectors.dense(0.7, 3, 1)),
                        Row.of(7000L, Vectors.dense(0, 1, 1)),
                        Row.of(8000L, Vectors.dense(0.5, 0, -2)),
                        Row.of(9000L, Vectors.dense(0.4, 1, 1)),
                        Row.of(10000L, Vectors.dense(0.3, 2, 1)),
                        Row.of(11000L, Vectors.dense(0.5, 1, -2)));
        DataStream<Row> inputStream = env.fromCollection(inputData);

        // Uses the first field of each row as the record's event-time timestamp.
        DataStream<Row> inputStreamWithEventTime =
                inputStream.assignTimestampsAndWatermarks(
                        WatermarkStrategy.<Row>forMonotonousTimestamps()
                                .withTimestampAssigner(
                                        (SerializableTimestampAssigner<Row>)
                                                (element, recordTimestamp) ->
                                                        element.getFieldAs(0)));

        // Exposes the stream's timestamps and watermarks to the Table API through the
        // "rowtime" metadata column.
        Table inputTable =
                tEnv.fromDataStream(
                                inputStreamWithEventTime,
                                Schema.newBuilder()
                                        .column("f0", DataTypes.BIGINT())
                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
                                        .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
                                        .watermark("rowtime", "SOURCE_WATERMARK()")
                                        .build())
                        .as("id", "input");

        // Creates an OnlineStandardScaler object and initializes its parameters.
        long windowSizeMs = 3000;
        OnlineStandardScaler onlineStandardScaler =
                new OnlineStandardScaler()
                        .setWindows(EventTimeTumblingWindows.of(Time.milliseconds(windowSizeMs)));

        // Trains the OnlineStandardScaler Model.
        OnlineStandardScalerModel model = onlineStandardScaler.fit(inputTable);

        // Uses the OnlineStandardScaler Model for predictions.
        Table outputTable = model.transform(inputTable)[0];

        // Extracts and displays the results. The result iterator is closeable; try-with-resources
        // guarantees it is closed even if reading or printing a row throws.
        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
            while (it.hasNext()) {
                Row row = it.next();
                DenseVector inputValue =
                        (DenseVector) row.getField(onlineStandardScaler.getInputCol());
                DenseVector outputValue =
                        (DenseVector) row.getField(onlineStandardScaler.getOutputCol());
                long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol());
                System.out.printf(
                        "Input Value: %s\tOutput Value: %s\tModel Version: %s\n",
                        inputValue, outputValue, modelVersion);
            }
        }
    }
}
# Simple program that trains an OnlineStandardScaler model and uses it for feature
# engineering.
from pyflink.common import Types
from pyflink.common.time import Time, Instant
from pyflink.java_gateway import get_gateway
from pyflink.table import Schema
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.table import StreamTableEnvironment
from pyflink.table.expressions import col
from pyflink.ml.linalg import Vectors, DenseVectorTypeInfo
from pyflink.ml.feature.onlinestandardscaler import OnlineStandardScaler
from pyflink.ml.common.window import EventTimeTumblingWindows
# Creates a new StreamExecutionEnvironment.
env = StreamExecutionEnvironment.get_execution_environment()

# Creates a StreamTableEnvironment.
t_env = StreamTableEnvironment.create(env)

# Generates input data.
# The "input" column is declared as a RAW type whose definition embeds the serializer string
# of the Java DenseVectorSerializer, obtained from the JVM through the gateway.
dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType(
    get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(),
    get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer()
).getSerializerString()

# The builder chain is wrapped in parentheses so that it can span multiple lines.
schema = (
    Schema.new_builder()
    .column("ts", "TIMESTAMP_LTZ(3)")
    .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')"
            .format(serializer=dense_vector_serializer))
    # The watermark is derived from "ts" minus one second.
    .watermark("ts", "ts - INTERVAL '1' SECOND")
    .build()
)

input_data = t_env.from_data_stream(
    env.from_collection(
        [
            (Instant.of_epoch_milli(0), Vectors.dense(-2.5, 9, 1)),
            (Instant.of_epoch_milli(1000), Vectors.dense(1.4, -5, 1)),
            (Instant.of_epoch_milli(2000), Vectors.dense(2, -1, -2)),
            (Instant.of_epoch_milli(6000), Vectors.dense(0.7, 3, 1)),
            (Instant.of_epoch_milli(7000), Vectors.dense(0, 1, 1)),
            (Instant.of_epoch_milli(8000), Vectors.dense(0.5, 0, -2)),
            (Instant.of_epoch_milli(9000), Vectors.dense(0.4, 1, 1)),
            (Instant.of_epoch_milli(10000), Vectors.dense(0.3, 2, 1)),
            (Instant.of_epoch_milli(11000), Vectors.dense(0.5, 1, -2)),
        ],
        type_info=Types.ROW_NAMED(
            ['ts', 'input'],
            [Types.INSTANT(), DenseVectorTypeInfo()])),
    schema)

# Creates an online standard-scaler object and initializes its parameters.
standard_scaler = (
    OnlineStandardScaler()
    .set_windows(EventTimeTumblingWindows.of(Time.milliseconds(3000)))
    .set_max_allowed_model_delay_ms(0)
)

# Trains the online standard-scaler model.
model = standard_scaler.fit(input_data)

# Uses the online standard-scaler model for predictions.
output = model.transform(input_data)[0]

# Extracts and displays the results. Column names come from the estimator's params so the
# selection stays in sync with how the fields are read below.
output = output.select(
    col(standard_scaler.get_input_col()),
    col(standard_scaler.get_output_col()),
    col(standard_scaler.get_model_version_col()))
field_names = output.get_schema().get_field_names()

# Using the result iterator as a context manager closes it once the results have been read.
with t_env.to_data_stream(output).execute_and_collect() as results:
    for result in results:
        input_value = result[field_names.index(standard_scaler.get_input_col())]
        output_value = result[field_names.index(standard_scaler.get_output_col())]
        model_version = result[field_names.index(standard_scaler.get_model_version_col())]
        print('Input Value: ' + str(input_value) + ' \tOutput Value: ' + str(output_value) +
              '\tModel Version: ' + str(model_version))