################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import json
import logging
import sys
from pyflink.common import Row
from pyflink.table import (DataTypes, TableEnvironment, EnvironmentSettings)
from pyflink.table.expressions import *
from pyflink.table.udf import udtf, udf, udaf, AggregateFunction, TableAggregateFunction, udtaf
def basic_operations():
    """Demonstrate the everyday Table API operations on a tiny in-memory table.

    Every query is executed and printed immediately. The commented tables that
    follow each call show the output recorded for that call.
    """
    env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

    # source rows: a numeric id plus a JSON document per record
    contact_rows = [
        (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
        (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
        (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'),
        (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
    ]
    contacts = env.from_elements(elements=contact_rows, schema=['id', 'data'])

    # second table, used as the right-hand side of the join further down
    ages = env.from_elements(
        elements=[(1, 18), (2, 30), (3, 25), (4, 10)],
        schema=['id', 'age'])

    # extract three values from the JSON document, then discard the raw document
    contacts = contacts.add_columns(
        col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
        col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
        col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country'))
    contacts = contacts.drop_columns(col('data'))
    contacts.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | tel | country |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | Flink | 123 | Germany |
    # | +I | 2 | hello | 135 | China |
    # | +I | 3 | world | 124 | USA |
    # | +I | 4 | PyFlink | 32 | China |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+

    # limit: keep only the first three rows
    first_three = contacts.limit(3)
    first_three.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | tel | country |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | Flink | 123 | Germany |
    # | +I | 2 | hello | 135 | China |
    # | +I | 3 | world | 124 | USA |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+

    # filter: every row whose id is not 3
    not_third = contacts.filter(col('id') != 3)
    not_third.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | tel | country |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | Flink | 123 | Germany |
    # | +I | 2 | hello | 135 | China |
    # | +I | 4 | PyFlink | 32 | China |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+

    # aggregation: per country, count the ids and take the largest tel (cast to BIGINT)
    per_country = contacts.group_by(col('country')).select(
        col('country'),
        col('id').count,
        col('tel').cast(DataTypes.BIGINT()).max)
    per_country.execute().print()
    # +----+--------------------------------+----------------------+----------------------+
    # | op | country | EXPR$0 | EXPR$1 |
    # +----+--------------------------------+----------------------+----------------------+
    # | +I | Germany | 1 | 123 |
    # | +I | USA | 1 | 124 |
    # | +I | China | 1 | 135 |
    # | -U | China | 1 | 135 |
    # | +U | China | 2 | 135 |
    # +----+--------------------------------+----------------------+----------------------+

    # distinct: the set of countries
    countries = contacts.select(col('country')).distinct()
    countries.execute().print()
    # +----+--------------------------------+
    # | op | country |
    # +----+--------------------------------+
    # | +I | Germany |
    # | +I | China |
    # | +I | USA |
    # +----+--------------------------------+

    # join
    # Duplicate column names across the joined tables are still not supported,
    # so the right-hand `id` column is renamed before joining.
    ages_renamed = ages.rename_columns(col('id').alias('r_id'))
    joined = contacts.join(ages_renamed, col('id') == col('r_id'))
    joined.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+
    # | op | id | name | tel | country | r_id | age |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+
    # | +I | 4 | PyFlink | 32 | China | 4 | 10 |
    # | +I | 1 | Flink | 123 | Germany | 1 | 18 |
    # | +I | 2 | hello | 135 | China | 2 | 30 |
    # | +I | 3 | world | 124 | USA | 3 | 25 |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+

    # join lateral: the table function below yields one row per piece of `name`
    @udtf(result_types=[DataTypes.STRING()])
    def split(r: Row):
        """Yield the pieces of the row's `name` obtained by splitting on "i"."""
        yield from r.name.split("i")

    contacts.join_lateral(split.alias('a')).execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | tel | country | a |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | Flink | 123 | Germany | Fl |
    # | +I | 1 | Flink | 123 | Germany | nk |
    # | +I | 2 | hello | 135 | China | hello |
    # | +I | 3 | world | 124 | USA | world |
    # | +I | 4 | PyFlink | 32 | China | PyFl |
    # | +I | 4 | PyFlink | 32 | China | nk |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+

    # print the schema of the table
    contacts.print_schema()
    # (
    # `id` BIGINT,
    # `name` STRING,
    # `tel` STRING,
    # `country` STRING
    # )

    # print the execution plan of the lateral join
    plan = contacts.join_lateral(split.alias('a')).explain()
    print(plan)
    # == Abstract Syntax Tree ==
    # LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{}])
    # :- LogicalProject(id=[$0], name=[JSON_VALUE($1, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], tel=[JSON_VALUE($1, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], country=[JSON_VALUE($1, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))])
    # : +- LogicalTableScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]])
    # +- LogicalTableFunctionScan(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], rowType=[RecordType(VARCHAR(2147483647) a)], elementType=[class [Ljava.lang.Object;])
    #
    # == Optimized Physical Plan ==
    # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], correlate=[table(split(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER])
    # +- Calc(select=[id, JSON_VALUE(data, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS name, JSON_VALUE(data, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS tel, JSON_VALUE(data, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS country])
    # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
    #
    # == Optimized Execution Plan ==
    # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], correlate=[table(split(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER])
    # +- Calc(select=[id, JSON_VALUE(data, '$.name', NULL, ON EMPTY, NULL, ON ERROR) AS name, JSON_VALUE(data, '$.tel', NULL, ON EMPTY, NULL, ON ERROR) AS tel, JSON_VALUE(data, '$.addr.country', NULL, ON EMPTY, NULL, ON ERROR) AS country])
    # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
def sql_operations():
    """Demonstrate SQL statements issued through the TableEnvironment.

    Runs a plain SELECT, then registers a Python table function and uses it in
    a LATERAL TABLE query, and finally prints the plan of that same query.
    """
    env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

    # source rows: a numeric id plus a JSON document per record
    source_rows = [
        (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
        (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
        (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'),
        (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
    ]
    source = env.from_elements(elements=source_rows, schema=['id', 'data'])

    # the Table object is %-formatted straight into the statement text
    select_all = "SELECT * FROM %s" % source
    env.sql_query(select_all).execute().print()
    # +----+----------------------+--------------------------------+
    # | op | id | data |
    # +----+----------------------+--------------------------------+
    # | +I | 1 | {"name": "Flink", "tel": 12... |
    # | +I | 2 | {"name": "hello", "tel": 13... |
    # | +I | 3 | {"name": "world", "tel": 12... |
    # | +I | 4 | {"name": "PyFlink", "tel": ... |
    # +----+----------------------+--------------------------------+

    # table function that turns the JSON document into (name, tel, country)
    @udtf(result_types=[DataTypes.STRING(), DataTypes.INT(), DataTypes.STRING()])
    def parse_data(data: str):
        """Parse the JSON text and yield its name, tel and address country."""
        doc = json.loads(data)
        yield doc['name'], doc['tel'], doc['addr']['country']

    env.create_temporary_function('parse_data', parse_data)

    # one statement template, formatted separately for execution and for explain
    lateral_template = (
        "\n"
        "SELECT *\n"
        "FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country)\n"
    )

    # execute the statement
    env.execute_sql(lateral_template % source).print()
    # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+
    # | op | id | data | name | tel | country |
    # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+
    # | +I | 1 | {"name": "Flink", "tel": 12... | Flink | 123 | Germany |
    # | +I | 2 | {"name": "hello", "tel": 13... | hello | 135 | China |
    # | +I | 3 | {"name": "world", "tel": 12... | world | 124 | USA |
    # | +I | 4 | {"name": "PyFlink", "tel": ... | PyFlink | 32 | China |
    # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+

    # print the plan of the same statement
    plan = env.explain_sql(lateral_template % source)
    print(plan)
    # == Abstract Syntax Tree ==
    # LogicalProject(id=[$0], data=[$1], name=[$2], tel=[$3], country=[$4])
    # +- LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{1}])
    # :- LogicalTableScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]])
    # +- LogicalTableFunctionScan(invocation=[parse_data($cor1.data)], rowType=[RecordType:peek_no_expand(VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)])
    #
    # == Optimized Physical Plan ==
    # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER])
    # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
    #
    # == Optimized Execution Plan ==
    # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER])
    # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
def column_operations():
    """Demonstrate adding, dropping, renaming and replacing table columns.

    Each step builds on the table produced by the previous one and prints the
    intermediate result.
    """
    env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

    # source rows: a numeric id plus a JSON document per record
    raw_rows = [
        (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
        (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
        (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'),
        (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
    ]
    raw = env.from_elements(elements=raw_rows, schema=['id', 'data'])

    # step 1 - add columns extracted from the JSON document
    with_fields = raw.add_columns(
        col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
        col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
        col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country'))
    with_fields.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | data | name | tel | country |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | {"name": "Flink", "tel": 12... | Flink | 123 | Germany |
    # | +I | 2 | {"name": "hello", "tel": 13... | hello | 135 | China |
    # | +I | 3 | {"name": "world", "tel": 12... | world | 124 | USA |
    # | +I | 4 | {"name": "PyFlink", "tel": ... | PyFlink | 32 | China |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+

    # step 2 - drop the raw JSON column
    without_data = with_fields.drop_columns(col('data'))
    without_data.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | tel | country |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | Flink | 123 | Germany |
    # | +I | 2 | hello | 135 | China |
    # | +I | 3 | world | 124 | USA |
    # | +I | 4 | PyFlink | 32 | China |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+

    # step 3 - rename `tel` to `telephone`
    renamed = without_data.rename_columns(col('tel').alias('telephone'))
    renamed.execute().print()
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | telephone | country |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1 | Flink | 123 | Germany |
    # | +I | 2 | hello | 135 | China |
    # | +I | 3 | world | 124 | USA |
    # | +I | 4 | PyFlink | 32 | China |
    # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+

    # step 4 - replace `id` with "<id>_<name>"
    combined_id = concat(col('id').cast(DataTypes.STRING()), '_', col('name'))
    replaced = renamed.add_or_replace_columns(combined_id.alias('id'))
    replaced.execute().print()
    # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
    # | op | id | name | telephone | country |
    # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
    # | +I | 1_Flink | Flink | 123 | Germany |
    # | +I | 2_hello | hello | 135 | China |
    # | +I | 3_world | world | 124 | USA |
    # | +I | 4_PyFlink | PyFlink | 32 | China |
    # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
def row_operations():
    """Demonstrate the row-based operations map, flat_map, aggregate and flat_aggregate.

    Each operation is driven by a Python user-defined function declared inside
    this function. Every query is executed and printed immediately; the
    commented tables show the output recorded for each call.
    """
    t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())

    # define the source
    table = t_env.from_elements(
        elements=[
            (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
            (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
            (3, '{"name": "world", "tel": 124, "addr": {"country": "China", "city": "NewYork"}}'),
            (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
        ],
        schema=['id', 'data'])

    # map operation
    @udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
                                    DataTypes.FIELD("country", DataTypes.STRING())]))
    def extract_country(input_row: Row):
        """Return a Row of the input id and the country found in its JSON data."""
        data = json.loads(input_row.data)
        return Row(input_row.id, data['addr']['country'])

    table.map(extract_country) \
        .execute().print()
    # +----+----------------------+--------------------------------+
    # | op | id | country |
    # +----+----------------------+--------------------------------+
    # | +I | 1 | Germany |
    # | +I | 2 | China |
    # | +I | 3 | China |
    # | +I | 4 | China |
    # +----+----------------------+--------------------------------+

    # flat_map operation
    @udtf(result_types=[DataTypes.BIGINT(), DataTypes.STRING()])
    def extract_city(input_row: Row):
        """Yield the input id together with the city found in its JSON data."""
        data = json.loads(input_row.data)
        yield input_row.id, data['addr']['city']

    table.flat_map(extract_city) \
        .execute().print()
    # +----+----------------------+--------------------------------+
    # | op | f0 | f1 |
    # +----+----------------------+--------------------------------+
    # | +I | 1 | Berlin |
    # | +I | 2 | Shanghai |
    # | +I | 3 | NewYork |
    # | +I | 4 | Hangzhou |
    # +----+----------------------+--------------------------------+

    # Both aggregations below work on the same projection of the source:
    # the JSON fields name, tel and country exposed as STRING columns.
    parsed_table = table.add_columns(
        col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
        col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
        col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country'))

    # aggregate operation
    class CountAndSumAggregateFunction(AggregateFunction):
        """Aggregate rows into (row count, sum of `tel` converted to int)."""

        def get_value(self, accumulator):
            """Return the current (count, sum) as a new Row."""
            return Row(accumulator[0], accumulator[1])

        def create_accumulator(self):
            """Return the initial accumulator: count 0 and sum 0."""
            return Row(0, 0)

        def accumulate(self, accumulator, input_row):
            """Add one row to the count and its `tel` to the sum."""
            accumulator[0] += 1
            accumulator[1] += int(input_row.tel)

        def retract(self, accumulator, input_row):
            """Undo the effect of a previously accumulated row."""
            accumulator[0] -= 1
            accumulator[1] -= int(input_row.tel)

        def merge(self, accumulator, accumulators):
            """Fold the counts and sums of the other accumulators into this one."""
            for other_acc in accumulators:
                accumulator[0] += other_acc[0]
                accumulator[1] += other_acc[1]

        def get_accumulator_type(self):
            """Return the declared accumulator type: ROW<cnt BIGINT, sum BIGINT>."""
            return DataTypes.ROW(
                [DataTypes.FIELD("cnt", DataTypes.BIGINT()),
                 DataTypes.FIELD("sum", DataTypes.BIGINT())])

        def get_result_type(self):
            """Return the declared result type: ROW<cnt BIGINT, sum BIGINT>."""
            return DataTypes.ROW(
                [DataTypes.FIELD("cnt", DataTypes.BIGINT()),
                 DataTypes.FIELD("sum", DataTypes.BIGINT())])

    count_sum = udaf(CountAndSumAggregateFunction())
    parsed_table.group_by(col('country')) \
        .aggregate(count_sum.alias("cnt", "sum")) \
        .select(col('country'), col('cnt'), col('sum')) \
        .execute().print()
    # +----+--------------------------------+----------------------+----------------------+
    # | op | country | cnt | sum |
    # +----+--------------------------------+----------------------+----------------------+
    # | +I | China | 3 | 291 |
    # | +I | Germany | 1 | 123 |
    # +----+--------------------------------+----------------------+----------------------+

    # flat_aggregate operation
    class Top2(TableAggregateFunction):
        """Track the two largest `tel` values and emit each of them as a row."""

        def emit_value(self, accumulator):
            """Yield one Row per slot of the accumulator that has been filled."""
            for v in accumulator:
                # Compare against None explicitly: a plain truthiness test
                # would also discard a legitimate value of 0.
                if v is not None:
                    yield Row(v)

        def create_accumulator(self):
            """Return two empty slots: [largest, second largest]."""
            return [None, None]

        def accumulate(self, accumulator, input_row):
            """Insert the row's `tel` into the top-two slots if it is large enough."""
            tel = int(input_row.tel)
            if accumulator[0] is None or tel > accumulator[0]:
                # new maximum: the previous maximum becomes the runner-up
                accumulator[1] = accumulator[0]
                accumulator[0] = tel
            elif accumulator[1] is None or tel > accumulator[1]:
                accumulator[1] = tel

        def get_accumulator_type(self):
            """Return the declared accumulator type: ARRAY<BIGINT>."""
            return DataTypes.ARRAY(DataTypes.BIGINT())

        def get_result_type(self):
            """Return the declared result type: ROW<tel BIGINT>."""
            return DataTypes.ROW(
                [DataTypes.FIELD("tel", DataTypes.BIGINT())])

    top2 = udtaf(Top2())
    parsed_table.group_by(col('country')) \
        .flat_aggregate(top2) \
        .select(col('country'), col('tel')) \
        .execute().print()
    # +----+--------------------------------+----------------------+
    # | op | country | tel |
    # +----+--------------------------------+----------------------+
    # | +I | China | 135 |
    # | +I | China | 124 |
    # | +I | Germany | 123 |
    # +----+--------------------------------+----------------------+
if __name__ == '__main__':
    # send log records to stdout, message text only
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s")

    # run every example, in the order they are defined in this file
    for run_example in (basic_operations, sql_operations, column_operations, row_operations):
        run_example()