################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import calendar
import ctypes
import datetime
import decimal
import sys
import time
from array import array
from copy import copy
from enum import Enum
from functools import reduce
from threading import RLock
from py4j.java_gateway import get_java_class
from typing import List, Union
from pyflink.common.types import _create_row
from pyflink.util.java_utils import to_jarray, is_instance_of
from pyflink.java_gateway import get_gateway
from pyflink.common import Row, RowKind
__all__ = ['DataTypes', 'UserDefinedType', 'Row', 'RowKind']
class DataType(object):
    """
    Describes the data type of a value in the table ecosystem. Instances of this class can be
    used to declare the input and/or output types of operations.

    A :class:`DataType` serves two purposes: it declares a logical type, and it can carry hints
    about the physical representation of data for the optimizer. The logical type is mandatory;
    hints are optional but useful at the edges to other APIs.

    The logical type is independent of any physical representation and is close to the
    "data type" terminology of the SQL standard.

    Physical hints are required at the edges of the table ecosystem. They indicate the data
    format that an implementation expects.

    :param nullable: boolean, whether the type can be null (None) or not.
    """

    def __init__(self, nullable=True):
        self._nullable = nullable
        # Conversion-class hint recorded by bridged_to(); the empty string means "no hint".
        self._conversion_cls = ''

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return '%s(%s)' % (self.__class__.__name__, nullability)

    def __str__(self, *args, **kwargs):
        return self.__class__.type_name()

    def __hash__(self):
        # Hash by the type name (see __str__).
        return hash(str(self))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def not_null(self):
        """
        Returns a shallow copy of this type that does not accept null (None).
        """
        non_nullable = copy(self)
        non_nullable._nullable = False
        return non_nullable

    def nullable(self):
        """
        Returns a shallow copy of this type that accepts null (None).
        """
        nullable_copy = copy(self)
        nullable_copy._nullable = True
        return nullable_copy

    @classmethod
    def type_name(cls) -> str:
        """
        Returns the upper-cased class name without its last four characters (the ``Type``
        suffix of the concrete type classes), e.g. ``CharType`` -> ``CHAR``.
        """
        suffix_length = len('Type')
        return cls.__name__[:-suffix_length].upper()

    def bridged_to(self, conversion_cls) -> 'DataType':
        """
        Records a hint that data should be represented using the given class when entering or
        leaving the table ecosystem. This instance is modified in place and returned.

        :param conversion_cls: the string representation of the conversion class
        """
        self._conversion_cls = conversion_cls
        return self

    def need_conversion(self) -> bool:
        """
        Whether values of this type must be converted between Python objects and internal SQL
        objects. Composite types (array/multiset/map/row) use it to skip needless per-element
        conversion. The base type never needs conversion.
        """
        return False

    def to_sql_type(self, obj):
        """
        Converts a Python object into an internal SQL object. The base type returns it as is.
        """
        return obj

    def from_sql_type(self, obj):
        """
        Converts an internal SQL object into a native Python object. The base type returns it
        as is.
        """
        return obj
class AtomicType(DataType):
    """
    Internal base class for every type that is not a composite one (array, row or map).
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class NullType(AtomicType):
    """
    Null type.

    The type of the value None. Instances are always nullable.
    """

    def __init__(self):
        super().__init__(nullable=True)
class NumericType(AtomicType):
    """
    Common base class of the numeric data types.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class IntegralType(NumericType):
    """
    Common base class of the integral (whole-number) data types.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class FractionalType(NumericType):
    """
    Common base class of the fractional data types.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class CharType(AtomicType):
    """
    Fixed-length character string data type. SQL CHAR(n)

    Serialized as ``char(n)``, where ``n`` (default: 1) is the number of code points and must
    lie between 1 and 2147483647 (0x7fffffff), both inclusive.

    :param length: int, the number of code points.
    :param nullable: boolean, whether the type can be null (None) or not.
    """

    def __init__(self, length=1, nullable=True):
        super().__init__(nullable)
        self.length = length

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return 'CharType(%d, %s)' % (self.length, nullability)
class VarCharType(AtomicType):
    """
    Variable-length character string data type. SQL VARCHAR(n)

    Serialized as ``varchar(n)``, where ``n`` (default: 1) is the maximum number of code points
    and must lie between 1 and 2147483647 (0x7fffffff), both inclusive.

    :param length: int, the maximum number of code points.
    :param nullable: boolean, whether the type can be null (None) or not.
    """

    def __init__(self, length=1, nullable=True):
        super().__init__(nullable)
        self.length = length

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return "VarCharType(%d, %s)" % (self.length, nullability)
class BinaryType(AtomicType):
    """
    Fixed-length binary (byte array) data type. SQL BINARY(n)

    Serialized as ``binary(n)``, where ``n`` (default: 1) is the number of bytes and must lie
    between 1 and 2147483647 (0x7fffffff), both inclusive.

    :param length: int, the number of bytes.
    :param nullable: boolean, whether the type can be null (None) or not.
    """

    def __init__(self, length=1, nullable=True):
        super().__init__(nullable)
        self.length = length

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return "BinaryType(%d, %s)" % (self.length, nullability)
class VarBinaryType(AtomicType):
    """
    Variable-length binary (byte array) data type. SQL VARBINARY(n)

    Serialized as ``varbinary(n)``, where ``n`` (default: 1) is the maximum number of bytes and
    must lie between 1 and 2147483647 (0x7fffffff), both inclusive.

    :param length: int, the maximum number of bytes.
    :param nullable: boolean, whether the type can be null (None) or not.
    """

    def __init__(self, length=1, nullable=True):
        super().__init__(nullable)
        self.length = length

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return "VarBinaryType(%d, %s)" % (self.length, nullability)
class BooleanType(AtomicType):
    """
    Boolean data type. SQL BOOLEAN

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class TinyIntType(IntegralType):
    """
    8-bit integer (byte) data type. SQL TINYINT

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class SmallIntType(IntegralType):
    """
    16-bit integer (short) data type. SQL SMALLINT

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class IntType(IntegralType):
    """
    32-bit integer data type. SQL INT

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class BigIntType(IntegralType):
    """
    64-bit integer (long) data type. SQL BIGINT

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class FloatType(FractionalType):
    """
    Float data type. SQL FLOAT

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class DoubleType(FractionalType):
    """
    Double data type. SQL DOUBLE

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, nullable=True):
        super().__init__(nullable)
class DecimalType(FractionalType):
    """
    Decimal (decimal.Decimal) data type.

    A DecimalType has a fixed precision (the maximum total number of digits) and a fixed scale
    (the number of digits to the right of the decimal point). For example, (5, 2) supports
    values from -999.99 to 999.99.

    The precision must lie in [1, 38] and the scale in [0, precision]. A newly created
    DecimalType defaults to (10, 0). When a schema is inferred from decimal.Decimal objects,
    DecimalType(38, 18) is used.

    :param precision: the total number of digits (default: 10)
    :param scale: the number of digits right of the decimal point (default: 0)
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, precision=10, scale=0, nullable=True):
        super().__init__(nullable)
        assert 1 <= precision <= 38
        assert 0 <= scale <= precision
        self.precision = precision
        self.scale = scale
        self.has_precision_info = True  # this is public API

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return "DecimalType(%d, %d, %s)" % (self.precision, self.scale, nullability)
class DateType(AtomicType):
    """
    Date data type. SQL DATE

    Python values are :class:`datetime.date` objects; internally a date is exchanged as the
    number of days since 1970-01-01.

    :param nullable: boolean, whether the field can be null (None) or not.
    """

    # Proleptic Gregorian ordinal of 1970-01-01, the zero point of the internal representation.
    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

    def __init__(self, nullable=True):
        super().__init__(nullable)

    def need_conversion(self):
        return True

    def to_sql_type(self, d):
        if d is None:
            return None
        return d.toordinal() - self.EPOCH_ORDINAL

    def from_sql_type(self, v):
        if v is None:
            return None
        return datetime.date.fromordinal(self.EPOCH_ORDINAL + v)
class TimeType(AtomicType):
    """
    Time data type. SQL TIME

    The precision must be greater than or equal to 0 and less than or equal to 9.

    Python values are :class:`datetime.time` objects; internally a time is exchanged as a
    number of microseconds. Naive times are shifted by the local UTC offset (EPOCH_ORDINAL),
    and tz-aware times by their own UTC offset.

    :param precision: int, the number of digits of fractional seconds (default: 0)
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    # Local UTC offset at the epoch, in microseconds: the local struct_time of timestamp 0
    # re-interpreted as UTC by calendar.timegm.
    EPOCH_ORDINAL = calendar.timegm(time.localtime(0)) * 10 ** 6

    def __init__(self, precision=0, nullable=True):
        super(TimeType, self).__init__(nullable)
        assert 0 <= precision <= 9
        self.precision = precision

    def __repr__(self):
        return "TimeType(%s, %s)" % (self.precision, str(self._nullable).lower())

    def need_conversion(self):
        return True

    def to_sql_type(self, t):
        if t is not None:
            if t.tzinfo is not None:
                offset = t.utcoffset()
                offset = offset if offset else datetime.timedelta()
                offset_microseconds =\
                    (offset.days * 86400 + offset.seconds) * 10 ** 6 + offset.microseconds
            else:
                offset_microseconds = self.EPOCH_ORDINAL
            minutes = t.hour * 60 + t.minute
            seconds = minutes * 60 + t.second
            return seconds * 10 ** 6 + t.microsecond - offset_microseconds

    def from_sql_type(self, t):
        if t is not None:
            # Undoing the local offset can leave the value outside a single day (e.g. for
            # values produced from tz-aware times). datetime.time() rejects hours outside
            # 0..23, so wrap the value back into [0, 24h). In-range values are unaffected.
            micros_of_day = (t + self.EPOCH_ORDINAL) % (86400 * 10 ** 6)
            seconds, microseconds = divmod(micros_of_day, 10 ** 6)
            minutes, seconds = divmod(seconds, 60)
            hours, minutes = divmod(minutes, 60)
            return datetime.time(hours, minutes, seconds, microseconds)
class TimestampType(AtomicType):
    """
    Timestamp data type. SQL TIMESTAMP WITHOUT TIME ZONE.

    Consists of ``year-month-day hour:minute:second[.fractional]`` with up to nanosecond
    precision and values ranging from ``0000-01-01 00:00:00.000000000`` to
    ``9999-12-31 23:59:59.999999999``. Unlike the SQL standard, leap seconds (23:59:60 and
    23:59:61) are not supported.

    This type does not store or represent a time zone. It describes a date, as used for
    birthdays, combined with the local time as seen on a wall clock. It cannot represent an
    instant on the time-line without additional information such as an offset or time zone.

    Internally values are exchanged as microseconds since the epoch. Naive datetimes are
    interpreted in the local time zone; tz-aware ones are converted through UTC.

    The precision must be greater than or equal to 0 and less than or equal to 9.

    :param precision: int, the number of digits of fractional seconds (default: 6)
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, precision=6, nullable=True):
        super().__init__(nullable)
        assert 0 <= precision <= 9
        self.precision = precision

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return "TimestampType(%s, %s)" % (self.precision, nullability)

    def need_conversion(self):
        return True

    def to_sql_type(self, dt):
        if dt is None:
            return None
        if dt.tzinfo:
            epoch_seconds = calendar.timegm(dt.utctimetuple())
        else:
            epoch_seconds = time.mktime(dt.timetuple())
        return int(epoch_seconds) * 10 ** 6 + dt.microsecond

    def from_sql_type(self, ts):
        if ts is None:
            return None
        whole_seconds, micros = divmod(ts, 10 ** 6)
        return datetime.datetime.fromtimestamp(whole_seconds).replace(microsecond=micros)
class LocalZonedTimestampType(AtomicType):
    """
    Timestamp data type. SQL TIMESTAMP WITH LOCAL TIME ZONE.

    Consists of ``year-month-day hour:minute:second[.fractional] zone`` with up to nanosecond
    precision and values ranging from ``0000-01-01 00:00:00.000000000 +14:59`` to
    ``9999-12-31 23:59:59.999999999 -14:59``. Unlike the SQL standard, leap seconds (23:59:60
    and 23:59:61) are not supported.

    The value is stored internally as a long value that holds all date and time fields, to a
    precision of nanoseconds, as well as the offset from UTC/Greenwich. In this Python mapping
    it is microseconds since the epoch shifted by the local UTC offset at the epoch
    (EPOCH_ORDINAL).

    The precision must be greater than or equal to 0 and less than or equal to 9.

    :param precision: int, the number of digits of fractional seconds (default: 6)
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    # Local UTC offset at the epoch, in microseconds.
    EPOCH_ORDINAL = calendar.timegm(time.localtime(0)) * 10 ** 6

    def __init__(self, precision=6, nullable=True):
        super().__init__(nullable)
        assert 0 <= precision <= 9
        self.precision = precision

    def __repr__(self):
        nullability = str(self._nullable).lower()
        return "LocalZonedTimestampType(%s, %s)" % (self.precision, nullability)

    def need_conversion(self):
        return True

    def to_sql_type(self, dt):
        if dt is None:
            return None
        if dt.tzinfo:
            epoch_seconds = calendar.timegm(dt.utctimetuple())
        else:
            epoch_seconds = time.mktime(dt.timetuple())
        return int(epoch_seconds) * 10 ** 6 + dt.microsecond + self.EPOCH_ORDINAL

    def from_sql_type(self, ts):
        if ts is None:
            return None
        whole_seconds, micros = divmod(ts - self.EPOCH_ORDINAL, 10 ** 6)
        return datetime.datetime.fromtimestamp(whole_seconds).replace(microsecond=micros)
class ZonedTimestampType(AtomicType):
    """
    Timestamp data type with time zone. SQL TIMESTAMP WITH TIME ZONE.

    Consisting of ``year-month-day hour:minute:second[.fractional] zone`` with up to nanosecond
    precision and values ranging from ``0000-01-01 00:00:00.000000000 +14:59`` to
    ``9999-12-31 23:59:59.999999999 -14:59``. Compared to the SQL standard, leap seconds (23:59:60
    and 23:59:61) are not supported.

    The value is stored internally with all date and time fields, to a precision of
    nanoseconds, together with a time zone whose offset is used to handle ambiguous local
    date-times. In this Python mapping the internal value is a
    ``(wall_clock_microseconds, utc_offset_seconds)`` pair.

    The precision must be greater than or equal to 0 and less than or equal to 9.

    :param precision: int, the number of digits of fractional seconds (default: 6)
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, precision=6, nullable=True):
        super(ZonedTimestampType, self).__init__(nullable)
        assert 0 <= precision <= 9
        self.precision = precision

    def __repr__(self):
        return "ZonedTimestampType(%s, %s)" % (self.precision, str(self._nullable).lower())

    def need_conversion(self):
        return True

    def to_sql_type(self, dt):
        """
        Converts a :class:`datetime.datetime` into ``(wall_clock_microseconds, offset)``.

        ``offset`` is the UTC offset of ``dt`` in seconds. ``wall_clock_microseconds`` is the
        UTC epoch time of ``dt`` plus that offset, in microseconds. Naive datetimes are
        interpreted in the local time zone.
        """
        if dt is not None:
            if dt.tzinfo:
                seconds = calendar.timegm(dt.utctimetuple())
                utc_offset = dt.tzinfo.utcoffset(dt)
            else:
                seconds = time.mktime(dt.timetuple())
                # astimezone() on a naive datetime attaches the local offset in effect at dt
                # itself (DST-aware). The offset in effect *now* can differ by the DST shift,
                # which would misreport the wall-clock time.
                utc_offset = dt.astimezone().utcoffset()
            offset = int(utc_offset.total_seconds())
            return int(seconds + offset) * 10 ** 6 + dt.microsecond, offset

    def from_sql_type(self, zoned_ts):
        """
        Converts a ``(wall_clock_microseconds, offset_seconds)`` pair back into a tz-aware
        :class:`datetime.datetime` carrying a fixed ``dateutil`` offset.
        """
        if zoned_ts is not None:
            from dateutil import tz
            ts = zoned_ts[0] - zoned_ts[1] * 10 ** 6
            tzinfo = tz.tzoffset(None, zoned_ts[1])
            return datetime.datetime.fromtimestamp(ts // 10 ** 6, tz=tzinfo).replace(
                microsecond=ts % 10 ** 6)
class Resolution(object):
    """
    Helper describing the resolution of an interval: a unit plus an optional precision.

    :param unit: a member of :class:`IntervalUnit`.
    :param precision: the number of digits of years (year precision), of days (day precision)
                      or of fractional seconds (fractional precision), depending on the unit.
                      Defaults to -1.
    """

    class IntervalUnit(Enum):
        SECOND = 0
        MINUTE = 1
        HOUR = 2
        DAY = 3
        MONTH = 4
        YEAR = 5

    def __init__(self, unit, precision=-1):
        self._unit = unit
        self._precision = precision

    @property
    def unit(self):
        return self._unit

    @property
    def precision(self):
        return self._precision

    def __str__(self):
        return '%s(%s)' % (self._unit, self._precision)
class YearMonthIntervalType(AtomicType):
    """
    Year-month interval type. It must be parameterized with one of these resolutions: interval
    of years, interval of years to months, or interval of months.

    A year-month interval consists of ``+years-months`` with values ranging from ``-9999-11``
    to ``+9999-11``. The value representation is the same for every resolution. For example,
    an interval of 50 months is always represented in interval-of-years-to-months format (with
    the default year precision): ``+04-02``.

    :param resolution: one of the constants of :class:`YearMonthResolution`.
    :param precision: int, the number of digits of years, between 1 and 4 (both inclusive),
                      default 2. An interval of months only accepts the default.
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    class YearMonthResolution(object):
        """
        Supported resolutions of :class:`YearMonthIntervalType`.
        """
        YEAR = 1
        MONTH = 2
        YEAR_TO_MONTH = 3

    DEFAULT_PRECISION = 2

    def __init__(self, resolution, precision=DEFAULT_PRECISION, nullable=True):
        kinds = YearMonthIntervalType.YearMonthResolution
        assert any(resolution == kind
                   for kind in (kinds.YEAR, kinds.MONTH, kinds.YEAR_TO_MONTH))
        # Only the default year precision is accepted for an interval of months.
        assert resolution != kinds.MONTH or precision == self.DEFAULT_PRECISION
        assert 1 <= precision <= 4
        self._resolution = resolution
        self._precision = precision
        super().__init__(nullable)

    @property
    def resolution(self):
        return self._resolution

    @property
    def precision(self):
        return self._precision
class DayTimeIntervalType(AtomicType):
    """
    Day-time interval type with up to nanosecond precision. It must be parameterized with one
    of these resolutions: interval of days, days to hours, days to minutes, days to seconds,
    hours, hours to minutes, hours to seconds, minutes, minutes to seconds, or seconds.

    A day-time interval consists of ``+days hours:minutes:seconds.fractional`` with values
    ranging from ``-999999 23:59:59.999999999`` to ``+999999 23:59:59.999999999``. The value
    representation is the same for every resolution. For example, an interval of 70 seconds
    is always represented in interval-of-days-to-seconds format (with default precisions):
    ``+00 00:01:10.000000``.

    Python values are :class:`datetime.timedelta` objects, exchanged internally as a number of
    microseconds.

    :param resolution: a member of :class:`DayTimeResolution`.
    :param day_precision: the number of digits of days, between 1 and 6 (both inclusive),
                          default 2. Resolutions without a day field only accept the default.
    :param fractional_precision: the number of digits of fractional seconds, between 0 and 9
                                 (both inclusive), default 6. Resolutions without a second
                                 field only accept the default.
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    class DayTimeResolution(Enum):
        """
        Supported resolutions of :class:`DayTimeIntervalType`.
        """
        DAY = 1
        DAY_TO_HOUR = 2
        DAY_TO_MINUTE = 3
        DAY_TO_SECOND = 4
        HOUR = 5
        HOUR_TO_MINUTE = 6
        HOUR_TO_SECOND = 7
        MINUTE = 8
        MINUTE_TO_SECOND = 9
        SECOND = 10

    DEFAULT_DAY_PRECISION = 2
    DEFAULT_FRACTIONAL_PRECISION = 6

    def __init__(self, resolution, day_precision=DEFAULT_DAY_PRECISION,
                 fractional_precision=DEFAULT_FRACTIONAL_PRECISION, nullable=True):
        # The resolution must be one of the DayTimeResolution members.
        assert any(resolution == member for member in DayTimeIntervalType.DayTimeResolution)
        assert not self._needs_default_day_precision(resolution) \
            or day_precision == self.DEFAULT_DAY_PRECISION
        assert not self._needs_default_fractional_precision(resolution) \
            or fractional_precision == self.DEFAULT_FRACTIONAL_PRECISION
        assert 1 <= day_precision <= 6
        assert 0 <= fractional_precision <= 9
        self._resolution = resolution
        self._day_precision = day_precision
        self._fractional_precision = fractional_precision
        super().__init__(nullable)

    def need_conversion(self):
        return True

    def to_sql_type(self, timedelta):
        if timedelta is None:
            return None
        whole_seconds = timedelta.days * 86400 + timedelta.seconds
        return whole_seconds * 10 ** 6 + timedelta.microseconds

    def from_sql_type(self, ts):
        if ts is None:
            return None
        return datetime.timedelta(microseconds=ts)

    @property
    def resolution(self) -> 'DayTimeIntervalType.DayTimeResolution':
        return self._resolution

    @property
    def day_precision(self) -> int:
        return self._day_precision

    @property
    def fractional_precision(self) -> int:
        return self._fractional_precision

    @staticmethod
    def _needs_default_day_precision(resolution) -> bool:
        # Resolutions that start below the day field.
        kinds = DayTimeIntervalType.DayTimeResolution
        return any(resolution == kind for kind in (
            kinds.HOUR, kinds.HOUR_TO_MINUTE, kinds.HOUR_TO_SECOND,
            kinds.MINUTE, kinds.MINUTE_TO_SECOND, kinds.SECOND))

    @staticmethod
    def _needs_default_fractional_precision(resolution) -> bool:
        # Resolutions that stop above the second field.
        kinds = DayTimeIntervalType.DayTimeResolution
        return any(resolution == kind for kind in (
            kinds.DAY, kinds.DAY_TO_HOUR, kinds.DAY_TO_MINUTE,
            kinds.HOUR, kinds.HOUR_TO_MINUTE, kinds.MINUTE))
# Maps (upper interval unit, lower interval unit or None) to a factory building the matching
# interval type. Each factory receives p1, the precision of the upper resolution, and p2, the
# precision of the lower resolution (-1 when there is no lower resolution, as passed by
# _from_resolution). p1 is a year precision, a day precision, or (for SECOND alone) a
# fractional precision. p2 is only used when the lower unit is SECOND (fractional precision).
# Resolutions whose constructor requires default precisions ignore the corresponding argument.
_resolution_mappings = {
    (Resolution.IntervalUnit.YEAR, None):
        lambda p1, p2: YearMonthIntervalType(
            YearMonthIntervalType.YearMonthResolution.YEAR, p1),
    (Resolution.IntervalUnit.MONTH, None):
        lambda p1, p2: YearMonthIntervalType(
            YearMonthIntervalType.YearMonthResolution.MONTH),
    # Keep the year precision of the upper resolution, e.g. YEAR(4) TO MONTH.
    (Resolution.IntervalUnit.YEAR, Resolution.IntervalUnit.MONTH):
        lambda p1, p2: YearMonthIntervalType(
            YearMonthIntervalType.YearMonthResolution.YEAR_TO_MONTH, p1),
    (Resolution.IntervalUnit.DAY, None):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.DAY,
            p1,
            DayTimeIntervalType.DEFAULT_FRACTIONAL_PRECISION),
    (Resolution.IntervalUnit.DAY, Resolution.IntervalUnit.HOUR):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.DAY_TO_HOUR,
            p1,
            DayTimeIntervalType.DEFAULT_FRACTIONAL_PRECISION),
    (Resolution.IntervalUnit.DAY, Resolution.IntervalUnit.MINUTE):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.DAY_TO_MINUTE,
            p1,
            DayTimeIntervalType.DEFAULT_FRACTIONAL_PRECISION),
    (Resolution.IntervalUnit.DAY, Resolution.IntervalUnit.SECOND):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.DAY_TO_SECOND, p1, p2),
    (Resolution.IntervalUnit.HOUR, None):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.HOUR),
    (Resolution.IntervalUnit.HOUR, Resolution.IntervalUnit.MINUTE):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.HOUR_TO_MINUTE),
    (Resolution.IntervalUnit.HOUR, Resolution.IntervalUnit.SECOND):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.HOUR_TO_SECOND,
            DayTimeIntervalType.DEFAULT_DAY_PRECISION,
            p2),
    (Resolution.IntervalUnit.MINUTE, None):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.MINUTE),
    (Resolution.IntervalUnit.MINUTE, Resolution.IntervalUnit.SECOND):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.MINUTE_TO_SECOND,
            DayTimeIntervalType.DEFAULT_DAY_PRECISION,
            p2),
    # For SECOND alone, the upper precision is the fractional precision.
    (Resolution.IntervalUnit.SECOND, None):
        lambda p1, p2: DayTimeIntervalType(
            DayTimeIntervalType.DayTimeResolution.SECOND,
            DayTimeIntervalType.DEFAULT_DAY_PRECISION,
            p1)
}
def _from_resolution(upper_resolution: Resolution, lower_resolution: Resolution = None):
    """
    Creates an interval type (YearMonthIntervalType or DayTimeIntervalType) from the
    upper_resolution and lower_resolution.

    :param upper_resolution: the leading resolution of the interval.
    :param lower_resolution: the trailing resolution of the interval, or None for a
                             single-unit interval.
    :raises ValueError: if the combination of units is not a supported interval.
    """
    lower_unit = None if lower_resolution is None else lower_resolution.unit
    lower_precision = -1 if lower_resolution is None else lower_resolution.precision
    # Use .get() so an unsupported combination reaches the ValueError below instead of
    # escaping as a bare KeyError from the dict lookup.
    interval_type_provider = _resolution_mappings.get((upper_resolution.unit, lower_unit))
    if interval_type_provider is None:
        raise ValueError(
            "Unsupported interval definition '%s TO %s'. Please check the documentation for "
            "supported combinations for year-month and day-time intervals."
            % (upper_resolution, lower_resolution))
    return interval_type_provider(upper_resolution.precision, lower_precision)
def _from_java_interval_type(j_interval_type):
    """
    Builds the Python counterpart of a Java interval type.

    A Java YearMonthIntervalType yields a :class:`YearMonthIntervalType`. Any other input is
    treated as a Java DayTimeIntervalType and yields a :class:`DayTimeIntervalType`. Resolutions
    are matched by comparing against the Java enum constants. Anything unmatched falls back to
    YEAR_TO_MONTH or SECOND respectively.

    :param j_interval_type: the Java interval type.
    :return: :class:`YearMonthIntervalType` or :class:`DayTimeIntervalType`.
    """
    gateway = get_gateway()

    if is_instance_of(j_interval_type, gateway.jvm.YearMonthIntervalType):
        j_resolution = j_interval_type.getResolution()
        year_precision = j_interval_type.getYearPrecision()
        j_year_month = gateway.jvm.YearMonthIntervalType.YearMonthResolution
        py_year_month = YearMonthIntervalType.YearMonthResolution
        if j_resolution == j_year_month.YEAR:
            py_resolution = py_year_month.YEAR
        elif j_resolution == j_year_month.MONTH:
            py_resolution = py_year_month.MONTH
        else:
            py_resolution = py_year_month.YEAR_TO_MONTH
        return YearMonthIntervalType(py_resolution, year_precision)

    j_resolution = j_interval_type.getResolution()
    day_precision = j_interval_type.getDayPrecision()
    fractional_precision = j_interval_type.getFractionalPrecision()
    j_day_time = gateway.jvm.DayTimeIntervalType.DayTimeResolution
    # Java and Python enum constants share names; SECOND is the fall-through case.
    py_resolution = DayTimeIntervalType.DayTimeResolution.SECOND
    for name in ('DAY', 'DAY_TO_HOUR', 'DAY_TO_MINUTE', 'DAY_TO_SECOND', 'HOUR',
                 'HOUR_TO_MINUTE', 'HOUR_TO_SECOND', 'MINUTE', 'MINUTE_TO_SECOND'):
        if j_resolution == getattr(j_day_time, name):
            py_resolution = DayTimeIntervalType.DayTimeResolution[name]
            break
    return DayTimeIntervalType(py_resolution, day_precision, fractional_precision)
# Boxed Java wrapper class name -> JVM type descriptor of the matching primitive array class.
_boxed_to_primitive_array_map = {
    'java.lang.Integer': '[I',
    'java.lang.Long': '[J',
    'java.lang.Byte': '[B',
    'java.lang.Short': '[S',
    'java.lang.Character': '[C',
    'java.lang.Boolean': '[Z',
    'java.lang.Float': '[F',
    'java.lang.Double': '[D',
}
class ArrayType(DataType):
    """
    Array data type.

    Python values are lists; elements are converted with the element type's conversions.

    :param element_type: :class:`DataType` of each element in the array.
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, element_type, nullable=True):
        """
        >>> ArrayType(VarCharType(100)) == ArrayType(VarCharType(100))
        True
        >>> ArrayType(VarCharType(100)) == ArrayType(BigIntType())
        False
        """
        assert isinstance(element_type, DataType), \
            "element_type %s should be an instance of %s" % (element_type, DataType)
        super(ArrayType, self).__init__(nullable)
        self.element_type = element_type

    def __repr__(self):
        return "ArrayType(%s, %s)" % (repr(self.element_type), str(self._nullable).lower())

    def need_conversion(self):
        return self.element_type.need_conversion()

    def to_sql_type(self, obj):
        # None and empty sequences are returned unchanged by the `obj and ...` short-circuit.
        if not self.need_conversion():
            return obj
        return obj and [self.element_type.to_sql_type(v) for v in obj]

    def from_sql_type(self, obj):
        # Decoding must apply the element type's from_sql_type (the inverse of to_sql_type);
        # using to_sql_type here would re-encode elements such as dates and timestamps.
        if not self.need_conversion():
            return obj
        return obj and [self.element_type.from_sql_type(v) for v in obj]
class ListViewType(DataType):
    """
    List view data type, only usable in the accumulator type declaration of an
    AggregateFunction. It is always non-nullable, and converting values of it raises.

    :param element_type: :class:`DataType` of each element in the list view.
    """

    def __init__(self, element_type):
        assert isinstance(element_type, DataType), \
            "element_type %s should be an instance of %s" % (element_type, DataType)
        super().__init__(False)
        self._element_type = element_type

    def __repr__(self):
        return "ListViewType(%s)" % repr(self._element_type)

    def _raise_unsupported(self):
        raise Exception("ListViewType can only be used in accumulator type declaration of "
                        "AggregateFunction.")

    def to_sql_type(self, obj):
        self._raise_unsupported()

    def from_sql_type(self, obj):
        self._raise_unsupported()
class MapType(DataType):
    """
    Map data type.

    Python values are dicts; keys and values are converted with the respective types'
    conversions.

    :param key_type: :class:`DataType` of the keys in the map.
    :param value_type: :class:`DataType` of the values in the map.
    :param nullable: boolean, whether the field can be null (None) or not.

    Keys in a map data type are not allowed to be null (None).
    """

    def __init__(self, key_type, value_type, nullable=True):
        """
        >>> (MapType(VarCharType(100, nullable=False), IntType())
        ...        == MapType(VarCharType(100, nullable=False), IntType()))
        True
        >>> (MapType(VarCharType(100, nullable=False), IntType())
        ...        == MapType(VarCharType(100, nullable=False), FloatType()))
        False
        """
        assert isinstance(key_type, DataType), \
            "key_type %s should be an instance of %s" % (key_type, DataType)
        assert isinstance(value_type, DataType), \
            "value_type %s should be an instance of %s" % (value_type, DataType)
        super().__init__(nullable)
        self.key_type = key_type
        self.value_type = value_type

    def __repr__(self):
        parts = (repr(self.key_type), repr(self.value_type), str(self._nullable).lower())
        return "MapType(%s, %s, %s)" % parts

    def need_conversion(self):
        return self.key_type.need_conversion() or self.value_type.need_conversion()

    def to_sql_type(self, obj):
        # Falsy maps (None, {}) are handed back untouched.
        if not self.need_conversion() or not obj:
            return obj
        convert_key = self.key_type.to_sql_type
        convert_value = self.value_type.to_sql_type
        return {convert_key(k): convert_value(v) for k, v in obj.items()}

    def from_sql_type(self, obj):
        if not self.need_conversion() or not obj:
            return obj
        convert_key = self.key_type.from_sql_type
        convert_value = self.value_type.from_sql_type
        return {convert_key(k): convert_value(v) for k, v in obj.items()}
class MapViewType(DataType):
    """
    Map view data type, only usable in the accumulator type declaration of an
    AggregateFunction. It is always non-nullable, and converting values of it raises.

    :param key_type: :class:`DataType` of the keys in the map view.
    :param value_type: :class:`DataType` of the values in the map view.
    """

    def __init__(self, key_type, value_type):
        # The messages name the argument that actually failed (they previously both said
        # "element_type"), matching MapType.
        assert isinstance(key_type, DataType), \
            "key_type %s should be an instance of %s" % (key_type, DataType)
        assert isinstance(value_type, DataType), \
            "value_type %s should be an instance of %s" % (value_type, DataType)
        super(MapViewType, self).__init__(False)
        self._key_type = key_type
        self._value_type = value_type

    def __repr__(self):
        return "MapViewType(%s, %s)" % (repr(self._key_type), repr(self._value_type))

    def to_sql_type(self, obj):
        raise Exception("MapViewType can only be used in accumulator type declaration of "
                        "AggregateFunction.")

    def from_sql_type(self, obj):
        raise Exception("MapViewType can only be used in accumulator type declaration of "
                        "AggregateFunction.")
class MultisetType(DataType):
    """
    MultisetType data type.

    :param element_type: :class:`DataType` of each element in the multiset.
    :param nullable: boolean, whether the field can be null (None) or not.
    """

    def __init__(self, element_type, nullable=True):
        """
        >>> MultisetType(VarCharType(100)) == MultisetType(VarCharType(100))
        True
        >>> MultisetType(VarCharType(100)) == MultisetType(BigIntType())
        False
        """
        assert isinstance(element_type, DataType), \
            "element_type %s should be an instance of %s" % (element_type, DataType)
        super(MultisetType, self).__init__(nullable)
        self.element_type = element_type

    def __repr__(self):
        return "MultisetType(%s, %s)" % (repr(self.element_type), str(self._nullable).lower())

    def need_conversion(self):
        return self.element_type.need_conversion()

    def to_sql_type(self, obj):
        """
        Converts each element of the iterable ``obj`` to its internal form and returns
        them as a list. ``obj`` is returned unchanged when elements need no conversion
        or when it is falsy (None / empty).
        """
        if not self.need_conversion():
            return obj
        return obj and [self.element_type.to_sql_type(v) for v in obj]

    def from_sql_type(self, obj):
        """
        Inverse of :meth:`to_sql_type`: converts each internal element back to its
        Python form and returns them as a list. ``obj`` is returned unchanged when
        elements need no conversion or when it is falsy (None / empty).
        """
        if not self.need_conversion():
            return obj
        # Must use from_sql_type here: this is the internal -> Python direction.
        return obj and [self.element_type.from_sql_type(v) for v in obj]
class RowField(object):
    """
    A field in :class:`RowType`.

    :param name: string, name of the field.
    :param data_type: :class:`DataType` of the field.
    :param description: string, description of the field. When omitted, the
                        placeholder ``'...'`` is stored.
    """

    def __init__(self, name, data_type, description=None):
        """
        >>> (RowField("f1", VarCharType(100)) == RowField("f1", VarCharType(100)))
        True
        >>> (RowField("f1", VarCharType(100)) == RowField("f2", VarCharType(100)))
        False
        """
        assert isinstance(data_type, DataType), \
            "data_type %s should be an instance of %s" % (data_type, DataType)
        assert isinstance(name, str), "field name %s should be string" % name
        if description is not None:
            assert isinstance(description, str), \
                "description %s should be string" % description

        self.name = name
        self.data_type = data_type
        self.description = '...' if description is None else description

    def __repr__(self):
        return "RowField(%s, %s, %s)" % (self.name, repr(self.data_type), self.description)

    def __str__(self, *args, **kwargs):
        return "RowField(%s, %s)" % (self.name, self.data_type)

    def __eq__(self, other):
        # Fields are equal when they are of the same class and all attributes match
        # (name, data type and description).
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def need_conversion(self):
        return self.data_type.need_conversion()

    def to_sql_type(self, obj):
        return self.data_type.to_sql_type(obj)

    def from_sql_type(self, obj):
        return self.data_type.from_sql_type(obj)
class RowType(DataType):
    """
    Row type, consisting of a list of :class:`RowField`.

    This is the data type representing a :class:`Row`.

    Iterating a :class:`RowType` will iterate its :class:`RowField`\\s.
    A contained :class:`RowField` can be accessed by name or position.

    >>> row1 = RowType([RowField("f1", VarCharType(100))])
    >>> row1["f1"]
    RowField(f1, VarCharType(100))
    >>> row1[0]
    RowField(f1, VarCharType(100))
    """

    def __init__(self, fields=None, nullable=True):
        """
        >>> row1 = RowType([RowField("f1", VarCharType(100))])
        >>> row2 = RowType([RowField("f1", VarCharType(100))])
        >>> row1 == row2
        True
        >>> row1 = RowType([RowField("f1", VarCharType(100))])
        >>> row2 = RowType([RowField("f1", VarCharType(100)), RowField("f2", IntType())])
        >>> row1 == row2
        False

        :param fields: iterable of :class:`RowField`. It is copied into a new list, so
                       later calls to :meth:`add` never modify the caller's collection.
        :param nullable: boolean, whether the row can be null (None) or not.
        """
        super(RowType, self).__init__(nullable)
        self.fields = list(fields) if fields else []
        assert all(isinstance(f, RowField) for f in self.fields), \
            "fields should be a list of RowField"
        self.names = [f.name for f in self.fields]
        self._refresh_conversion_cache()

    def _refresh_conversion_cache(self):
        # Precalculated list of fields that need conversion with
        # from_sql_type/to_sql_type functions
        self._need_conversion = [f.need_conversion() for f in self]
        self._need_serialize_any_field = any(self._need_conversion)

    def add(self, field, data_type=None):
        """
        Constructs a RowType by adding new elements to it to define the schema. The method
        accepts either:

            a) A single parameter which is a RowField object.
            b) 2 parameters as (name, data_type). The data_type parameter must be a
               DataType object.

        The RowType is modified in place and returned, so calls can be chained.

        >>> row1 = RowType().add("f1", VarCharType(100)).add("f2", VarCharType(100))
        >>> row2 = RowType([RowField("f1", VarCharType(100)), RowField("f2", VarCharType(100))])
        >>> row1 == row2
        True
        >>> row1 = RowType().add(RowField("f1", VarCharType(100)))
        >>> row2 = RowType([RowField("f1", VarCharType(100))])
        >>> row1 == row2
        True

        :param field: Either the name of the field or a RowField object
        :param data_type: If present, the DataType of the RowField to create
        :return: this RowType, updated
        :raises ValueError: if ``field`` is a name and ``data_type`` is None.
        """
        if isinstance(field, RowField):
            self.fields.append(field)
            self.names.append(field.name)
        else:
            if isinstance(field, str) and data_type is None:
                raise ValueError("Must specify DataType if passing name of row_field to create.")
            self.fields.append(RowField(field, data_type))
            self.names.append(field)
        self._refresh_conversion_cache()
        return self

    def __iter__(self):
        """
        Iterate the fields.
        """
        return iter(self.fields)

    def __len__(self):
        """
        Returns the number of fields.
        """
        return len(self.fields)

    def __getitem__(self, key):
        """
        Accesses fields by name, position or slice.

        A slice returns a new RowType built from the selected fields (created with the
        default ``nullable``).

        :raises KeyError: if no field has the given name.
        :raises IndexError: if the position is out of range.
        :raises TypeError: if ``key`` is not a string, integer or slice.
        """
        if isinstance(key, str):
            for field in self:
                if field.name == key:
                    return field
            raise KeyError('No RowField named {0}'.format(key))
        elif isinstance(key, int):
            try:
                return self.fields[key]
            except IndexError:
                raise IndexError('RowType index out of range')
        elif isinstance(key, slice):
            return RowType(self.fields[key])
        else:
            raise TypeError('RowType keys should be strings, integers or slices')

    def __repr__(self):
        return "RowType(%s)" % ",".join(repr(field) for field in self)

    def field_names(self):
        """
        Returns all field names in a list.

        >>> row = RowType([RowField("f1", VarCharType(100))])
        >>> row.field_names()
        ['f1']
        """
        return list(self.names)

    def field_types(self):
        """
        Returns all field types in a list.

        .. versionadded:: 1.11.0
        """
        return [f.data_type for f in self.fields]

    def need_conversion(self):
        # We need convert Row()/namedtuple into tuple()
        return True

    def to_sql_type(self, obj):
        """
        Converts ``obj`` into the internal tuple form: the :class:`RowKind` value first,
        followed by the field values in field order.

        Dicts, named Rows and plain objects are read by field name; positional Rows,
        tuples and lists are read by position. Rows contribute their own row kind, all
        other inputs are tagged ``RowKind.INSERT``. Only fields that need conversion are
        passed through their ``to_sql_type``.

        :raises ValueError: if ``obj`` is none of the supported shapes.
        """
        if obj is None:
            return

        if self._need_serialize_any_field:
            # Only calling to_sql_type function for fields that need conversion
            if isinstance(obj, dict):
                return (RowKind.INSERT.value,) + tuple(
                    f.to_sql_type(obj.get(n)) if c else obj.get(n)
                    for n, f, c in zip(self.names, self.fields, self._need_conversion))
            elif isinstance(obj, Row) and hasattr(obj, "_fields"):
                return (obj.get_row_kind().value,) + tuple(
                    f.to_sql_type(obj[n]) if c else obj[n]
                    for n, f, c in zip(self.names, self.fields, self._need_conversion))
            elif isinstance(obj, Row):
                return (obj.get_row_kind().value, ) + tuple(
                    f.to_sql_type(v) if c else v
                    for f, v, c in zip(self.fields, obj, self._need_conversion))
            elif isinstance(obj, (tuple, list)):
                return (RowKind.INSERT.value,) + tuple(
                    f.to_sql_type(v) if c else v
                    for f, v, c in zip(self.fields, obj, self._need_conversion))
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return (RowKind.INSERT.value,) + tuple(
                    f.to_sql_type(d.get(n)) if c else d.get(n)
                    for n, f, c in zip(self.names, self.fields, self._need_conversion))
            else:
                raise ValueError("Unexpected tuple %r with RowType" % obj)
        else:
            if isinstance(obj, dict):
                return (RowKind.INSERT.value,) + tuple(obj.get(n) for n in self.names)
            elif isinstance(obj, Row) and hasattr(obj, "_fields"):
                return (obj.get_row_kind().value,) + tuple(obj[n] for n in self.names)
            elif isinstance(obj, Row):
                return (obj.get_row_kind().value,) + tuple(obj)
            elif isinstance(obj, (list, tuple)):
                return (RowKind.INSERT.value,) + tuple(obj)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                return (RowKind.INSERT.value,) + tuple(d.get(n) for n in self.names)
            else:
                raise ValueError("Unexpected tuple %r with RowType" % obj)

    def from_sql_type(self, obj):
        """
        Converts an internal sequence of field values into a :class:`Row` named after
        this type's fields. Row instances are returned unchanged.
        """
        if obj is None:
            return
        if isinstance(obj, Row):
            # it's already converted by pickler
            return obj
        if self._need_serialize_any_field:
            # Only calling from_sql_type function for fields that need conversion
            values = [f.from_sql_type(v) if c else v
                      for f, v, c in zip(self.fields, obj, self._need_conversion)]
        else:
            values = obj
        return _create_row(self.names, values)
class RawType(DataType):
    """
    Logical type of a pickled byte array.

    Internal values of this type are pickled bytes; :meth:`from_sql_type` unpickles them
    back into Python objects.
    """

    def from_sql_type(self, obj):
        # NOTE(review): pickle.loads can execute arbitrary code. This assumes obj comes
        # from the Flink runtime and never from untrusted input -- confirm.
        from pickle import loads as unpickle
        return unpickle(obj)
class UserDefinedType(DataType):
    """
    User-defined type (UDT).

    A subclass declares the underlying SQL storage type via :meth:`sql_type` and converts
    between user objects and datums of that type via :meth:`serialize` and
    :meth:`deserialize`.

    .. note:: WARN: Flink Internal Use Only
    """

    def __eq__(self, other):
        # Two UDT instances are equal exactly when they are of the same class.
        return type(self) == type(other)

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None and makes instances unhashable;
        # hash by class so that equal instances hash equally.
        return hash(type(self))

    @classmethod
    def type_name(cls):
        return cls.__name__.lower()

    @classmethod
    def sql_type(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sql_type().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def java_udt(cls):
        """
        The class name of the paired Java UDT (could be '', if there
        is no corresponding one).
        """
        return ''

    def need_conversion(self):
        return True

    @classmethod
    def _cached_sql_type(cls):
        """
        Caches the sql_type() on the class, because it's heavily used in `to_sql_type`.

        The cache is read from the class's own namespace only, so a subclass never
        reuses the value cached for one of its parents.
        """
        # A single-underscore attribute is used on purpose: a double-underscore name
        # would be mangled on assignment but not inside a string lookup, so the cache
        # would never be found.
        cached = vars(cls).get("_cached_sql_type_value")
        if cached is None:
            cached = cls.sql_type()
            cls._cached_sql_type_value = cached
        return cached

    def to_sql_type(self, obj):
        if obj is not None:
            return self._cached_sql_type().to_sql_type(self.serialize(obj))

    def from_sql_type(self, obj):
        v = self._cached_sql_type().from_sql_type(obj)
        if v is not None:
            return self.deserialize(v)

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement serialize().")

    def deserialize(self, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement deserialize().")
# Mapping Python types to Flink SQL types
# Keyed by the exact Python type of a value; the values are shared DataType instances.
# NOTE(review): only `bytearray` is mapped for binary data, `bytes` is not -- confirm intended.
_type_mappings = {
    bool: BooleanType(),
    int: BigIntType(),
    float: DoubleType(),
    str: VarCharType(0x7fffffff),
    bytearray: VarBinaryType(0x7fffffff),
    decimal.Decimal: DecimalType(38, 18),
    datetime.date: DateType(),
    datetime.datetime: LocalZonedTimestampType(),
    datetime.time: TimeType(),
}

# Mapping Python array types to Flink SQL types
# We should be careful here. The size of these types in Python depends on the C
# implementation. We need to make sure that this conversion does not lose any
# precision. Also, the JVM only supports signed types; when converting unsigned types,
# keep in mind that they require 1 more bit when stored as signed types.
#
# Reference for C integer size, see:
# ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types <limits.h>.
# Reference for python array typecode, see:
# https://docs.python.org/3/library/array.html
# Reference for JVM's supported integral types:
# http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.1

# Signed array typecode -> ctypes type whose size gives that typecode's bit width.
_array_signed_int_typecode_ctype_mappings = {
    'b': ctypes.c_byte,
    'h': ctypes.c_short,
    'i': ctypes.c_int,
    'l': ctypes.c_long,
}

# Unsigned array typecode -> ctypes type whose size gives that typecode's bit width.
_array_unsigned_int_typecode_ctype_mappings = {
    'B': ctypes.c_ubyte,
    'H': ctypes.c_ushort,
    'I': ctypes.c_uint,
    'L': ctypes.c_ulong
}
def _int_size_to_type(size):
    """
    Returns the narrowest integral data type that can hold ``size`` bits.

    :param size: number of bits required.
    :return: TinyIntType, SmallIntType, IntType or BigIntType, or None when more than
             64 bits are required.
    """
    for max_bits, type_cls in ((8, TinyIntType), (16, SmallIntType),
                               (32, IntType), (64, BigIntType)):
        if size <= max_bits:
            return type_cls()
    return None
# The list of all supported array typecodes is stored here
# (array typecode -> element DataType).
_array_type_mappings = {
    # Warning: Actual properties for float and double in C is not specified in C.
    # On almost every system supported by both python and JVM, they are IEEE 754
    # single-precision binary floating-point format and IEEE 754 double-precision
    # binary floating-point format. And we do assume the same thing here for now.
    'f': FloatType(),
    'd': DoubleType()
}

# compute array typecode mappings for signed integer types
# Typecodes wider than 64 bits get no entry (_int_size_to_type returns None).
# Note: the loop leaves `_typecode`, `size` and `dt` behind as module-level names.
for _typecode in _array_signed_int_typecode_ctype_mappings.keys():
    size = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode]) * 8
    dt = _int_size_to_type(size)
    if dt is not None:
        _array_type_mappings[_typecode] = dt

# compute array typecode mappings for unsigned integer types
for _typecode in _array_unsigned_int_typecode_ctype_mappings.keys():
    # JVM does not have unsigned types, so use signed types that are at least 1
    # bit larger to store. A 64-bit unsigned typecode needs 65 bits and is therefore
    # left unmapped.
    size = ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode]) * 8 + 1
    dt = _int_size_to_type(size)
    if dt is not None:
        _array_type_mappings[_typecode] = dt

# Type code 'u' in Python's array is deprecated since version 3.3, and will be
# removed in version 4.0. See: https://docs.python.org/3/library/array.html
if sys.version_info[0] < 4:
    # it can be 16 bits or 32 bits depending on the platform
    # NOTE(review): CharType receives sizeof(c_wchar), a byte count (2 or 4), as its
    # length -- confirm that this length is intended.
    _array_type_mappings['u'] = CharType(ctypes.sizeof(ctypes.c_wchar))  # type: ignore
def _infer_type(obj):
    """
    Infers the data type from obj.

    - None yields :class:`NullType`.
    - Objects with a ``__UDT__`` attribute yield that UDT.
    - Scalars are looked up by exact type in ``_type_mappings``.
    - dicts yield a :class:`MapType` inferred from the first entry whose key and value are
      both non-None (``MapType(NullType(), NullType())`` if there is none).
    - lists yield an :class:`ArrayType` inferred from the first non-None element
      (``ArrayType(NullType())`` if there is none).
    - ``array.array`` values are mapped via ``_array_type_mappings``.
    - Anything else is treated as a row via ``_infer_schema``.

    :raises TypeError: if no data type can be inferred.
    """
    if obj is None:
        return NullType()

    if hasattr(obj, '__UDT__'):
        return obj.__UDT__

    data_type = _type_mappings.get(type(obj))
    if data_type is not None:
        return data_type

    if isinstance(obj, dict):
        for key, value in obj.items():
            if key is not None and value is not None:
                return MapType(_infer_type(key).not_null(), _infer_type(value))
        return MapType(NullType(), NullType())
    elif isinstance(obj, list):
        for v in obj:
            if v is not None:
                # Infer from the non-None element found, not from obj[0], which may be None.
                return ArrayType(_infer_type(v))
        return ArrayType(NullType())
    elif isinstance(obj, array):
        if obj.typecode in _array_type_mappings:
            return ArrayType(_array_type_mappings[obj.typecode].not_null())
        else:
            raise TypeError("not supported type: array(%s)" % obj.typecode)
    else:
        try:
            return _infer_schema(obj)
        except TypeError:
            raise TypeError("not supported type: %s" % type(obj))
def _infer_schema(row, names=None):
    """
    Infers the schema from dict/row/namedtuple/object.

    - dict: fields are the sorted keys.
    - Row / namedtuple (anything with ``_fields``): fields are ``_fields``.
    - other tuple/list: fields are ``names``. Missing trailing names are filled with
      ``_<position>`` (1-based).
    - other objects: fields are the sorted entries of ``__dict__``.

    :param row: the value to infer a :class:`RowType` from.
    :param names: optional column names for positional rows. The caller's sequence is
                  never modified.
    :return: :class:`RowType`
    :raises TypeError: if no schema can be inferred for ``row``.
    """
    if isinstance(row, dict):  # dict
        items = sorted(row.items())

    elif isinstance(row, (Row, tuple, list)):
        if hasattr(row, "_fields"):  # namedtuple and Row
            items = zip(row._fields, tuple(row))
        else:
            if names is None:
                names = ['_%d' % i for i in range(1, len(row) + 1)]
            elif len(names) < len(row):
                # Build a new list rather than extending the caller's in place: the same
                # ``names`` object is passed in for every row of a data set.
                names = list(names) + ['_%d' % i for i in range(len(names) + 1, len(row) + 1)]
            items = zip(names, row)

    elif hasattr(row, "__dict__"):  # object
        items = sorted(row.__dict__.items())

    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))

    fields = [RowField(k, _infer_type(v)) for k, v in items]
    return RowType(fields)
def _has_nulltype(dt):
    """
    Tells whether ``dt`` is a :class:`NullType` or contains one anywhere inside its row
    fields, collection element type, or map key/value types.
    """
    if isinstance(dt, RowType):
        return any(_has_nulltype(field.data_type) for field in dt.fields)
    if isinstance(dt, (ArrayType, MultisetType)):
        return _has_nulltype(dt.element_type)
    if isinstance(dt, MapType):
        return any(_has_nulltype(part) for part in (dt.key_type, dt.value_type))
    return isinstance(dt, NullType)
def _merge_type(a, b, name=None):
    """
    Merges two inferred data types into one that covers both.

    :class:`NullType` means "unknown" and yields to the other side. Two RowType values are
    merged field by field. A field missing on one side is treated as NullType there, and
    fields present only in ``b`` are appended. Arrays, multisets and maps are merged on
    their element / key / value types. For any other pair of the same class, ``a`` is
    returned.

    :param a: first data type.
    :param b: second data type.
    :param name: optional description of the value being merged, used in error messages.
    :return: the merged data type.
    :raises TypeError: if ``a`` and ``b`` are of different classes and neither is NullType.
    """
    if name is None:
        def new_msg(msg):
            return msg

        def new_name(n):
            return "field %s" % n
    else:
        def new_msg(msg):
            return "%s: %s" % (name, msg)

        def new_name(n):
            return "field %s in %s" % (n, name)

    if isinstance(a, NullType):
        return b
    elif isinstance(b, NullType):
        return a
    elif type(a) is not type(b):
        # TODO: type cast (such as int -> long)
        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))

    # same type
    if isinstance(a, RowType):
        b_field_types = {f.name: f.data_type for f in b.fields}
        # A field absent from ``b`` is unknown there (NullType), so a's type is kept
        # instead of trying to merge against None.
        fields = [RowField(f.name, _merge_type(f.data_type,
                                               b_field_types.get(f.name, NullType()),
                                               name=new_name(f.name)))
                  for f in a.fields]
        names = {f.name for f in fields}
        for n, data_type in b_field_types.items():
            if n not in names:
                fields.append(RowField(n, data_type))
        return RowType(fields)

    elif isinstance(a, ArrayType):
        return ArrayType(_merge_type(a.element_type, b.element_type,
                                     name='element in array %s' % name))

    elif isinstance(a, MultisetType):
        return MultisetType(_merge_type(a.element_type, b.element_type,
                                        name='element in multiset %s' % name))

    elif isinstance(a, MapType):
        return MapType(_merge_type(a.key_type, b.key_type, name='key of map %s' % name),
                       _merge_type(a.value_type, b.value_type, name='value of map %s' % name))
    else:
        return a
def _infer_schema_from_data(elements, names=None) -> RowType:
    """
    Infers one schema covering every row in ``elements``.

    Each row's schema is inferred separately and the results are merged pairwise, in order.

    :param elements: list of Row or tuple
    :param names: list of column names
    :return: :class:`RowType`
    :raises ValueError: if ``elements`` is empty, or if some column type is still unknown
                        (contains NullType) after merging.
    """
    if not elements:
        raise ValueError("can not infer schema from empty data set")
    per_row_schemas = map(lambda row: _infer_schema(row, names), elements)
    merged = reduce(_merge_type, per_row_schemas)
    if not _has_nulltype(merged):
        return merged
    raise ValueError("Some column types cannot be determined after inferring")
def _need_converter(data_type):
    """
    Tells whether values of ``data_type`` need a converter from ``_create_converter``.

    A converter is needed for rows and NullType, and for arrays, multisets and maps
    whose nested types need one.
    """
    if isinstance(data_type, RowType):
        return True
    if isinstance(data_type, (ArrayType, MultisetType)):
        return _need_converter(data_type.element_type)
    if isinstance(data_type, MapType):
        return any(_need_converter(part) for part in (data_type.key_type, data_type.value_type))
    return isinstance(data_type, NullType)
def _create_converter(data_type):
    """
    Creates a converter to drop the names of fields in obj.

    The returned callable converts a value of ``data_type`` as follows:

    - Row values (dicts, objects with ``__dict__``, tuples or lists) become tuples in
      the RowType's field order.
    - Arrays, multisets and maps are converted element by element.
    - NullType values become None.

    Types for which ``_need_converter`` is False get the identity function.
    None is passed through for rows, arrays, multisets and maps.
    """
    if not _need_converter(data_type):
        return lambda x: x

    if isinstance(data_type, ArrayType) or isinstance(data_type, MultisetType):
        conv = _create_converter(data_type.element_type)
        # A nullable collection may hold None: pass it through instead of iterating it.
        return lambda row: None if row is None else [conv(v) for v in row]
    elif isinstance(data_type, MapType):
        kconv = _create_converter(data_type.key_type)
        vconv = _create_converter(data_type.value_type)
        # A nullable map may hold None: pass it through instead of calling .items() on it.
        return lambda row: None if row is None else {kconv(k): vconv(v) for k, v in row.items()}
    elif isinstance(data_type, NullType):
        return lambda x: None
    elif not isinstance(data_type, RowType):
        return lambda x: x

    # dataType must be RowType
    names = [f.name for f in data_type.fields]
    converters = [_create_converter(f.data_type) for f in data_type.fields]
    convert_fields = any(_need_converter(f.data_type) for f in data_type.fields)

    def convert_row(obj):
        if obj is None:
            return

        if isinstance(obj, (tuple, list)):
            # Positional input: values are already in field order.
            if convert_fields:
                return tuple(conv(v) for v, conv in zip(obj, converters))
            else:
                return tuple(obj)

        if isinstance(obj, dict):
            d = obj
        elif hasattr(obj, "__dict__"):  # object
            d = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))

        # Named input: look each field up by name; missing names become None.
        if convert_fields:
            return tuple(conv(d.get(name)) for name, conv in zip(names, converters))
        else:
            return tuple(d.get(name) for name in names)

    return convert_row
# Module-level mapping between Python and Java types; starts as None.
# NOTE(review): presumably filled in lazily elsewhere in this module while holding the
# lock below -- confirm against its users.
_python_java_types_mapping = None
# Re-entrant lock paired with _python_java_types_mapping.
_python_java_types_mapping_lock = RLock()
# Python DataType classes treated as primitive array element types.
# NOTE(review): the consumer of this set is outside this view; confirm how it is used.
_primitive_array_element_types = {BooleanType, TinyIntType, SmallIntType, IntType, BigIntType,
                                  FloatType, DoubleType}
def _from_java_data_type(j_data_type):
    """
    Converts Java DataType to Python DataType.

    Dispatches first on the kind of Java ``DataType``:

    - atomic;
    - collection (array / multiset);
    - key-value (map);
    - fields (row / list view / map view).

    It then dispatches on the logical type, recursing into nested element, key, value
    and field data types.

    :param j_data_type: a Java ``DataType`` object accessed through the py4j gateway.
    :return: the equivalent Python :class:`DataType`.
    :raises TypeError: if the Java type, or a type nested inside it, is not supported.
    """
    gateway = get_gateway()

    # Atomic Type with parameters.
    if is_instance_of(j_data_type, gateway.jvm.AtomicDataType):
        logical_type = j_data_type.getLogicalType()
        if is_instance_of(logical_type, gateway.jvm.CharType):
            data_type = DataTypes.CHAR(logical_type.getLength(), logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.VarCharType):
            data_type = DataTypes.VARCHAR(logical_type.getLength(), logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.BinaryType):
            data_type = DataTypes.BINARY(logical_type.getLength(), logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.VarBinaryType):
            data_type = DataTypes.VARBINARY(logical_type.getLength(), logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.DecimalType):
            data_type = DataTypes.DECIMAL(logical_type.getPrecision(),
                                          logical_type.getScale(),
                                          logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.DateType):
            data_type = DataTypes.DATE(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.TimeType):
            data_type = DataTypes.TIME(logical_type.getPrecision(), logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.TimestampType):
            # NOTE(review): the Java precision is not read here; TIMESTAMP is always
            # mapped with precision 3 -- confirm this is intended.
            data_type = DataTypes.TIMESTAMP(precision=3, nullable=logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.BooleanType):
            data_type = DataTypes.BOOLEAN(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.TinyIntType):
            data_type = DataTypes.TINYINT(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.SmallIntType):
            data_type = DataTypes.SMALLINT(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.IntType):
            data_type = DataTypes.INT(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.BigIntType):
            data_type = DataTypes.BIGINT(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.FloatType):
            data_type = DataTypes.FLOAT(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.DoubleType):
            data_type = DataTypes.DOUBLE(logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.ZonedTimestampType):
            raise \
                TypeError("Unsupported type: %s, ZonedTimestampType is not supported yet."
                          % j_data_type)
        elif is_instance_of(logical_type, gateway.jvm.LocalZonedTimestampType):
            data_type = DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(nullable=logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.DayTimeIntervalType) or \
                is_instance_of(logical_type, gateway.jvm.YearMonthIntervalType):
            data_type = _from_java_interval_type(logical_type)
        elif is_instance_of(logical_type, gateway.jvm.LegacyTypeInformationType):
            # Legacy types only expose a TypeInformation; the known ones are mapped by hand.
            type_info = logical_type.getTypeInformation()
            BasicArrayTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.\
                BasicArrayTypeInfo
            BasicTypeInfo = gateway.jvm.org.apache.flink.api.common.typeinfo.BasicTypeInfo
            if type_info == BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO:
                data_type = DataTypes.ARRAY(DataTypes.STRING())
            elif type_info == BasicTypeInfo.BIG_DEC_TYPE_INFO:
                data_type = DataTypes.DECIMAL(38, 18)
            elif type_info.getClass() == \
                get_java_class(gateway.jvm.org.apache.flink.table.runtime.typeutils
                               .BigDecimalTypeInfo):
                data_type = DataTypes.DECIMAL(type_info.precision(), type_info.scale())
            elif type_info.getClass() == \
                get_java_class(gateway.jvm.org.apache.flink.table.dataview.ListViewTypeInfo):
                data_type = DataTypes.LIST_VIEW(_from_java_data_type(type_info.getElementType()))
            elif type_info.getClass() == \
                get_java_class(gateway.jvm.org.apache.flink.table.dataview.MapViewTypeInfo):
                data_type = DataTypes.MAP_VIEW(
                    _from_java_data_type(type_info.getKeyType()),
                    _from_java_data_type(type_info.getValueType()))
            else:
                raise TypeError("Unsupported type: %s, it is recognized as a legacy type."
                                % type_info)
        elif is_instance_of(logical_type, gateway.jvm.RawType):
            data_type = RawType()
        else:
            raise TypeError("Unsupported type: %s, it is not supported yet in current python type"
                            " system" % j_data_type)

        return data_type

    # Array Type, MultiSet Type.
    elif is_instance_of(j_data_type, gateway.jvm.CollectionDataType):
        logical_type = j_data_type.getLogicalType()
        element_type = j_data_type.getElementDataType()
        if is_instance_of(logical_type, gateway.jvm.ArrayType):
            data_type = DataTypes.ARRAY(_from_java_data_type(element_type),
                                        logical_type.isNullable())
        elif is_instance_of(logical_type, gateway.jvm.MultisetType):
            data_type = DataTypes.MULTISET(_from_java_data_type(element_type),
                                           logical_type.isNullable())
        else:
            raise TypeError("Unsupported collection data type: %s" % j_data_type)

        return data_type

    # Map Type.
    elif is_instance_of(j_data_type, gateway.jvm.KeyValueDataType):
        logical_type = j_data_type.getLogicalType()
        key_type = j_data_type.getKeyDataType()
        value_type = j_data_type.getValueDataType()
        if is_instance_of(logical_type, gateway.jvm.MapType):
            data_type = DataTypes.MAP(
                _from_java_data_type(key_type),
                _from_java_data_type(value_type),
                logical_type.isNullable())
        else:
            raise TypeError("Unsupported map data type: %s" % j_data_type)

        return data_type

    # Row Type.
    elif is_instance_of(j_data_type, gateway.jvm.FieldsDataType):
        logical_type = j_data_type.getLogicalType()
        field_data_types = j_data_type.getChildren()
        if is_instance_of(logical_type, gateway.jvm.RowType):
            fields = [DataTypes.FIELD(name, _from_java_data_type(field_data_types[idx]))
                      for idx, name in enumerate(logical_type.getFieldNames())]
            data_type = DataTypes.ROW(fields, logical_type.isNullable())
        elif j_data_type.getConversionClass().isAssignableFrom(
                gateway.jvm.org.apache.flink.table.api.dataview.ListView._java_lang_class):
            # A list view is structured as a row whose first child holds the backing array.
            array_type = _from_java_data_type(field_data_types[0])
            data_type = DataTypes.LIST_VIEW(array_type.element_type)
        elif j_data_type.getConversionClass().isAssignableFrom(
                gateway.jvm.org.apache.flink.table.api.dataview.MapView._java_lang_class):
            # A map view is structured as a row whose first child holds the backing map.
            map_type = _from_java_data_type(field_data_types[0])
            data_type = DataTypes.MAP_VIEW(map_type.key_type, map_type.value_type)
        else:
            raise TypeError("Unsupported row data type: %s" % j_data_type)

        return data_type

    # Unrecognized type: fail loudly instead of silently returning None.
    else:
        raise TypeError("Unsupported data type: %s" % j_data_type)
def _to_java_data_type(data_type: DataType):
    """
    Converts the specified Python DataType to Java DataType.

    Each supported Python type is mapped to the matching factory on the Java
    ``org.apache.flink.table.api.DataTypes`` class, recursing into element, key, value
    and field types. The result is then marked nullable or not null according to
    ``data_type._nullable``. If ``data_type._conversion_cls`` is set, the result is
    also bridged to the Java class of that name.

    The UserDefinedType, ListViewType and MapViewType branches return directly and so
    skip the nullability and bridging steps.

    :param data_type: the Python :class:`DataType` to convert.
    :return: the corresponding Java ``DataType``. For a UserDefinedType with a paired
             Java UDT, it is the object obtained by instantiating that Java class.
    :raises TypeError: if ``data_type`` is not handled here (e.g. RawType).
    """
    gateway = get_gateway()
    JDataTypes = gateway.jvm.org.apache.flink.table.api.DataTypes

    if isinstance(data_type, BooleanType):
        j_data_type = JDataTypes.BOOLEAN()
    elif isinstance(data_type, TinyIntType):
        j_data_type = JDataTypes.TINYINT()
    elif isinstance(data_type, SmallIntType):
        j_data_type = JDataTypes.SMALLINT()
    elif isinstance(data_type, IntType):
        j_data_type = JDataTypes.INT()
    elif isinstance(data_type, BigIntType):
        j_data_type = JDataTypes.BIGINT()
    elif isinstance(data_type, FloatType):
        j_data_type = JDataTypes.FLOAT()
    elif isinstance(data_type, DoubleType):
        j_data_type = JDataTypes.DOUBLE()
    elif isinstance(data_type, VarCharType):
        j_data_type = JDataTypes.VARCHAR(data_type.length)
    elif isinstance(data_type, CharType):
        j_data_type = JDataTypes.CHAR(data_type.length)
    elif isinstance(data_type, VarBinaryType):
        j_data_type = JDataTypes.VARBINARY(data_type.length)
    elif isinstance(data_type, BinaryType):
        j_data_type = JDataTypes.BINARY(data_type.length)
    elif isinstance(data_type, DecimalType):
        j_data_type = JDataTypes.DECIMAL(data_type.precision, data_type.scale)
    elif isinstance(data_type, DateType):
        j_data_type = JDataTypes.DATE()
    elif isinstance(data_type, TimeType):
        j_data_type = JDataTypes.TIME(data_type.precision)
    elif isinstance(data_type, TimestampType):
        j_data_type = JDataTypes.TIMESTAMP(data_type.precision)
    elif isinstance(data_type, LocalZonedTimestampType):
        j_data_type = JDataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(data_type.precision)
    elif isinstance(data_type, ZonedTimestampType):
        j_data_type = JDataTypes.TIMESTAMP_WITH_TIME_ZONE(data_type.precision)
    elif isinstance(data_type, ArrayType):
        j_data_type = JDataTypes.ARRAY(_to_java_data_type(data_type.element_type))
    elif isinstance(data_type, MapType):
        j_data_type = JDataTypes.MAP(
            _to_java_data_type(data_type.key_type),
            _to_java_data_type(data_type.value_type))
    elif isinstance(data_type, RowType):
        # Only field names and types are forwarded; RowField.description is not.
        fields = [JDataTypes.FIELD(f.name, _to_java_data_type(f.data_type))
                  for f in data_type.fields]
        j_data_type = JDataTypes.ROW(to_jarray(JDataTypes.Field, fields))
    elif isinstance(data_type, UserDefinedType):
        # Early return: nullability and bridging below are not applied to UDTs.
        if data_type.java_udt():
            return gateway.jvm.org.apache.flink.util.InstantiationUtil.instantiate(
                gateway.jvm.Class.forName(
                    data_type.java_udt(),
                    True,
                    gateway.jvm.Thread.currentThread().getContextClassLoader()))
        else:
            return _to_java_data_type(data_type.sql_type())
    elif isinstance(data_type, MultisetType):
        j_data_type = JDataTypes.MULTISET(_to_java_data_type(data_type.element_type))
    elif isinstance(data_type, NullType):
        j_data_type = JDataTypes.NULL()
    elif isinstance(data_type, YearMonthIntervalType):
        # Any resolution other than YEAR or MONTH is treated as YEAR TO MONTH.
        if data_type.resolution == YearMonthIntervalType.YearMonthResolution.YEAR:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.YEAR(data_type.precision))
        elif data_type.resolution == YearMonthIntervalType.YearMonthResolution.MONTH:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.MONTH())
        else:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.YEAR(data_type.precision),
                                              JDataTypes.MONTH())
    elif isinstance(data_type, DayTimeIntervalType):
        # Any resolution not matched explicitly below is treated as SECOND.
        if data_type.resolution == DayTimeIntervalType.DayTimeResolution.DAY:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.DAY(data_type.day_precision))
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.DAY_TO_HOUR:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.DAY(data_type.day_precision),
                                              JDataTypes.HOUR())
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.DAY_TO_MINUTE:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.DAY(data_type.day_precision),
                                              JDataTypes.MINUTE())
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.DAY_TO_SECOND:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.DAY(data_type.day_precision),
                                              JDataTypes.SECOND(data_type.fractional_precision))
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.HOUR:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.HOUR())
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.HOUR_TO_MINUTE:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.HOUR(), JDataTypes.MINUTE())
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.HOUR_TO_SECOND:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.HOUR(),
                                              JDataTypes.SECOND(data_type.fractional_precision))
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.MINUTE:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.MINUTE())
        elif data_type.resolution == DayTimeIntervalType.DayTimeResolution.MINUTE_TO_SECOND:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.MINUTE(),
                                              JDataTypes.SECOND(data_type.fractional_precision))
        else:
            j_data_type = JDataTypes.INTERVAL(JDataTypes.SECOND(data_type.fractional_precision))
    elif isinstance(data_type, ListViewType):
        # Early return: the Java ListView builds its own data type.
        return gateway.jvm.org.apache.flink.table.api.dataview.ListView.newListViewDataType(
            _to_java_data_type(data_type._element_type))
    elif isinstance(data_type, MapViewType):
        # Early return: the Java MapView builds its own data type.
        return gateway.jvm.org.apache.flink.table.api.dataview.MapView.newMapViewDataType(
            _to_java_data_type(data_type._key_type), _to_java_data_type(data_type._value_type))
    else:
        raise TypeError("Unsupported data type: %s" % data_type)

    # Apply the Python type's nullability to the Java type.
    if data_type._nullable:
        j_data_type = j_data_type.nullable()
    else:
        j_data_type = j_data_type.notNull()

    # Bridge to an explicit Java conversion class when the Python type names one.
    if data_type._conversion_cls:
        j_data_type = j_data_type.bridgedTo(
            gateway.jvm.org.apache.flink.api.python.shaded.py4j.reflection.ReflectionUtil
            .classForName(data_type._conversion_cls)
        )

    return j_data_type
# DataType class -> tuple of Python value types accepted for it.
# NOTE(review): presumably consulted by _create_type_verifier (whose body continues beyond
# this view) when checking values against a schema -- confirm.
# DateType also accepts datetime.datetime (a subclass of datetime.date). The binary types
# accept only bytearray; bytes is not listed.
_acceptable_types = {
    BooleanType: (bool,),
    TinyIntType: (int,),
    SmallIntType: (int,),
    IntType: (int,),
    BigIntType: (int,),
    FloatType: (float,),
    DoubleType: (float,),
    DecimalType: (decimal.Decimal,),
    CharType: (str,),
    VarCharType: (str,),
    BinaryType: (bytearray,),
    VarBinaryType: (bytearray,),
    DateType: (datetime.date, datetime.datetime),
    TimeType: (datetime.time,),
    TimestampType: (datetime.datetime,),
    DayTimeIntervalType: (datetime.timedelta,),
    LocalZonedTimestampType: (datetime.datetime,),
    ZonedTimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
    RowType: (tuple, list, dict),
}
def _create_type_verifier(data_type: DataType, name: str = None):
    """
    Creates a verifier that checks the type of obj against data_type and raises a TypeError if they
    do not match.

    This verifier also checks the value of obj against data_type and raises a ValueError if it's
    not within the allowed range, e.g. using 128 as TinyIntType will overflow. Note that, Python
    float is not checked, so it will become infinity when cast to Java float if it overflows.

    :param data_type: the :class:`DataType` that verified objects must conform to.
    :param name: optional name of the verified value. When given, it is prefixed to every error
                 message so that a failure inside a nested type can be located.
    :return: a function taking a single object; it returns None when the object conforms to
             data_type and raises TypeError / ValueError otherwise.

    >>> _create_type_verifier(RowType([]))(None)
    >>> _create_type_verifier(VarCharType(100))("")
    >>> _create_type_verifier(BigIntType())(0)
    >>> _create_type_verifier(ArrayType(SmallIntType()))(list(range(3)))
    >>> _create_type_verifier(ArrayType(VarCharType(10)))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _create_type_verifier(MapType(VarCharType(100), IntType()))({})
    >>> _create_type_verifier(RowType([]))(())
    >>> _create_type_verifier(RowType([]))([])
    >>> _create_type_verifier(RowType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check if numeric values are within the allowed range.
    >>> _create_type_verifier(TinyIntType())(12)
    >>> _create_type_verifier(TinyIntType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _create_type_verifier(BigIntType())(2 ** 63) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _create_type_verifier(TinyIntType(False))(None) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _create_type_verifier(
    ...     ArrayType(SmallIntType(False)))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _create_type_verifier(
    ...     MapType(VarCharType(100, False), IntType()))({None: 1}) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> schema = RowType([RowField("a", IntType()), RowField("b", VarCharType(100, False))])
    >>> _create_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """
    # Error messages are prefixed with the name of the verified value (if one was given) so that
    # failures in nested element/key/value/field verifiers can be traced back to their origin.
    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)

    def verify_nullability(obj):
        # True: obj is an accepted None and needs no further checks.
        # False: obj is not None and must still be value-checked.
        # Raises: obj is None but data_type is not nullable.
        if obj is None:
            if data_type._nullable:
                return True
            else:
                raise ValueError(new_msg("This field is not nullable, but got None"))
        else:
            return False

    _type = type(data_type)
    # NOTE(review): this is an ``assert`` and is therefore skipped under ``python -O``; an
    # unsupported type may then only fail later, with a KeyError in verify_acceptable_types.
    assert _type in _acceptable_types or isinstance(data_type, UserDefinedType),\
        new_msg("unknown datatype: %s" % data_type)

    def verify_acceptable_types(obj):
        # subclass of them can not be from_sql_type in JVM
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError(new_msg("%s can not accept object %r in type %s"
                                    % (data_type, obj, type(obj))))

    if isinstance(data_type, CharType):
        def verify_char(obj):
            verify_acceptable_types(obj)
            if len(obj) != data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of CharType is not: %d" % (obj, data_type.length)))

        verify_value = verify_char

    elif isinstance(data_type, VarCharType):
        def verify_varchar(obj):
            verify_acceptable_types(obj)
            if len(obj) > data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of VarCharType exceeds: %d" % (obj, data_type.length)))

        verify_value = verify_varchar

    elif isinstance(data_type, BinaryType):
        def verify_binary(obj):
            verify_acceptable_types(obj)
            if len(obj) != data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of BinaryType is not: %d" % (obj, data_type.length)))

        verify_value = verify_binary

    elif isinstance(data_type, VarBinaryType):
        def verify_varbinary(obj):
            verify_acceptable_types(obj)
            if len(obj) > data_type.length:
                raise ValueError(new_msg(
                    "length of object (%s) of VarBinaryType exceeds: %d"
                    % (obj, data_type.length)))

        verify_value = verify_varbinary

    elif isinstance(data_type, UserDefinedType):
        # A UDT is verified by converting it to its underlying SQL type and verifying that.
        sql_type = data_type.sql_type()
        verifier = _create_type_verifier(sql_type, name=name)

        def verify_udf(obj):
            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == data_type):
                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, data_type)))
            data = data_type.to_sql_type(obj)
            if isinstance(sql_type, RowType):
                # remove the RowKind value in the first position.
                data = data[1:]
            verifier(data)

        verify_value = verify_udf

    elif isinstance(data_type, TinyIntType):
        def verify_tiny_int(obj):
            verify_acceptable_types(obj)
            if obj < -128 or obj > 127:
                raise ValueError(new_msg("object of TinyIntType out of range, got: %s" % obj))

        verify_value = verify_tiny_int

    elif isinstance(data_type, SmallIntType):
        def verify_small_int(obj):
            verify_acceptable_types(obj)
            if obj < -32768 or obj > 32767:
                raise ValueError(new_msg("object of SmallIntType out of range, got: %s" % obj))

        verify_value = verify_small_int

    elif isinstance(data_type, IntType):
        def verify_integer(obj):
            verify_acceptable_types(obj)
            if obj < -2147483648 or obj > 2147483647:
                raise ValueError(
                    new_msg("object of IntType out of range, got: %s" % obj))

        verify_value = verify_integer

    elif isinstance(data_type, BigIntType):
        # BIGINT is a signed 8-byte integer; Python ints are unbounded, so reject values that
        # would overflow instead of letting them fail (or wrap) on the Java side.
        def verify_big_int(obj):
            verify_acceptable_types(obj)
            if obj < -9223372036854775808 or obj > 9223372036854775807:
                raise ValueError(
                    new_msg("object of BigIntType out of range, got: %s" % obj))

        verify_value = verify_big_int

    elif isinstance(data_type, ArrayType):
        element_verifier = _create_type_verifier(
            data_type.element_type, name="element in array %s" % name)

        def verify_array(obj):
            verify_acceptable_types(obj)
            for i in obj:
                element_verifier(i)

        verify_value = verify_array

    elif isinstance(data_type, MapType):
        key_verifier = _create_type_verifier(data_type.key_type, name="key of map %s" % name)
        value_verifier = _create_type_verifier(data_type.value_type, name="value of map %s" % name)

        def verify_map(obj):
            verify_acceptable_types(obj)
            for k, v in obj.items():
                key_verifier(k)
                value_verifier(v)

        verify_value = verify_map

    elif isinstance(data_type, RowType):
        verifiers = []
        for f in data_type.fields:
            verifier = _create_type_verifier(f.data_type, name=new_name(f.name))
            verifiers.append((f.name, verifier))

        def verify_row_field(obj):
            if isinstance(obj, dict):
                for f, verifier in verifiers:
                    verifier(obj.get(f))
            elif isinstance(obj, Row) and getattr(obj, "_from_dict", False):
                # the order in obj could be different than dataType.fields
                for f, verifier in verifiers:
                    verifier(obj[f])
            elif isinstance(obj, (tuple, list)):
                if len(obj) != len(verifiers):
                    raise ValueError(
                        new_msg("Length of object (%d) does not match with "
                                "length of fields (%d)" % (len(obj), len(verifiers))))
                for v, (_, verifier) in zip(obj, verifiers):
                    verifier(v)
            elif hasattr(obj, "__dict__"):
                # Arbitrary objects are verified attribute-by-attribute via their __dict__.
                d = obj.__dict__
                for f, verifier in verifiers:
                    verifier(d.get(f))
            else:
                raise TypeError(new_msg("RowType can not accept object %r in type %s"
                                        % (obj, type(obj))))

        verify_value = verify_row_field

    else:
        def verify_default(obj):
            verify_acceptable_types(obj)

        verify_value = verify_default

    def verify(obj):
        # Value checks only run for non-None objects; None is handled by nullability alone.
        if not verify_nullability(obj):
            verify_value(obj)

    return verify
def create_arrow_schema(field_names: List[str], field_types: List[DataType]):
    """
    Builds a pyarrow schema from parallel lists of field names and Flink field types.

    Every Flink type is translated with :func:`to_arrow_type`, and its nullability is carried
    over to the Arrow field. Names and types are paired with ``zip``, so surplus entries in the
    longer list are ignored.

    :param field_names: the names of the schema fields.
    :param field_types: the Flink :class:`DataType` of each field, in the same order as the names.
    :return: the resulting ``pyarrow.Schema``.
    """
    import pyarrow as pa
    arrow_fields = []
    for column_name, column_type in zip(field_names, field_types):
        arrow_fields.append(
            pa.field(column_name, to_arrow_type(column_type), column_type._nullable))
    return pa.schema(arrow_fields)
def from_arrow_type(arrow_type, nullable: bool = True) -> DataType:
    """
    Convert Arrow type to Flink data type.

    :param arrow_type: the pyarrow ``DataType`` to convert.
    :param nullable: whether the resulting Flink data type is nullable (default: True).
    :return: the corresponding Flink :class:`DataType`.
    :raises TypeError: if the Arrow type has no Flink counterpart, or if it is a struct that
                       contains a nested struct field.
    """
    from pyarrow import types
    if types.is_boolean(arrow_type):
        return BooleanType(nullable)
    elif types.is_int8(arrow_type):
        return TinyIntType(nullable)
    elif types.is_int16(arrow_type):
        return SmallIntType(nullable)
    elif types.is_int32(arrow_type):
        return IntType(nullable)
    elif types.is_int64(arrow_type):
        return BigIntType(nullable)
    elif types.is_float32(arrow_type):
        return FloatType(nullable)
    elif types.is_float64(arrow_type):
        return DoubleType(nullable)
    elif types.is_decimal(arrow_type):
        return DecimalType(arrow_type.precision, arrow_type.scale, nullable)
    elif types.is_string(arrow_type):
        return VarCharType(0x7fffffff, nullable)
    elif types.is_fixed_size_binary(arrow_type):
        # checked before is_binary so fixed-width binaries keep their byte width
        return BinaryType(arrow_type.byte_width, nullable)
    elif types.is_binary(arrow_type):
        return VarBinaryType(0x7fffffff, nullable)
    elif types.is_date32(arrow_type):
        return DateType(nullable)
    elif types.is_time32(arrow_type):
        if str(arrow_type) == 'time32[s]':
            return TimeType(0, nullable)
        else:
            return TimeType(3, nullable)
    elif types.is_time64(arrow_type):
        if str(arrow_type) == 'time64[us]':
            return TimeType(6, nullable)
        else:
            return TimeType(9, nullable)
    elif types.is_timestamp(arrow_type):
        if arrow_type.unit == 's':
            return TimestampType(0, nullable)
        elif arrow_type.unit == 'ms':
            return TimestampType(3, nullable)
        elif arrow_type.unit == 'us':
            return TimestampType(6, nullable)
        else:
            return TimestampType(9, nullable)
    elif types.is_map(arrow_type):
        return MapType(from_arrow_type(arrow_type.key_type),
                       from_arrow_type(arrow_type.item_type),
                       nullable)
    elif types.is_list(arrow_type):
        return ArrayType(from_arrow_type(arrow_type.value_type), nullable)
    elif types.is_struct(arrow_type):
        if any(types.is_struct(field.type) for field in arrow_type):
            raise TypeError("Nested RowType is not supported in conversion from Arrow: " +
                            str(arrow_type))
        # Each field keeps its own Arrow nullability; the row itself honours ``nullable``.
        return RowType([RowField(field.name, from_arrow_type(field.type, field.nullable))
                        for field in arrow_type],
                       nullable)
    elif types.is_null(arrow_type):
        return NullType()
    else:
        # Report the Arrow type that was actually passed in (previously referenced an
        # undefined name and raised NameError instead of TypeError).
        raise TypeError("Unsupported data type to convert from Arrow type: " + str(arrow_type))
def to_arrow_type(data_type: DataType):
    """
    Translates a Flink data type into the matching pyarrow data type.

    :param data_type: the Flink :class:`DataType` to translate.
    :return: the corresponding pyarrow ``DataType``.
    :raises ValueError: if the type is not supported, or it is an ArrayType whose element type is
                        exactly LocalZonedTimestampType or RowType.
    :raises TypeError: if it is a RowType with a field whose type is exactly
                       LocalZonedTimestampType or RowType.
    """
    import pyarrow as pa

    def fractional_unit(precision):
        # Coarsest Arrow unit able to hold `precision` fractional-second digits; any value
        # outside 0..6 falls through to nanoseconds.
        if precision == 0:
            return 's'
        if 1 <= precision <= 3:
            return 'ms'
        if 4 <= precision <= 6:
            return 'us'
        return 'ns'

    # Exact classes (compared by type(), not isinstance) that may not appear as ArrayType
    # elements or RowType fields in this conversion.
    unsupported_nested = [LocalZonedTimestampType, RowType]

    if isinstance(data_type, TinyIntType):
        return pa.int8()
    if isinstance(data_type, SmallIntType):
        return pa.int16()
    if isinstance(data_type, IntType):
        return pa.int32()
    if isinstance(data_type, BigIntType):
        return pa.int64()
    if isinstance(data_type, BooleanType):
        return pa.bool_()
    if isinstance(data_type, FloatType):
        return pa.float32()
    if isinstance(data_type, DoubleType):
        return pa.float64()
    if isinstance(data_type, (CharType, VarCharType)):
        return pa.utf8()
    if isinstance(data_type, BinaryType):
        return pa.binary(data_type.length)
    if isinstance(data_type, VarBinaryType):
        return pa.binary()
    if isinstance(data_type, DecimalType):
        return pa.decimal128(data_type.precision, data_type.scale)
    if isinstance(data_type, DateType):
        return pa.date32()
    if isinstance(data_type, TimeType):
        unit = fractional_unit(data_type.precision)
        # Arrow keeps second/millisecond times in 32 bits and finer units in 64 bits.
        return pa.time32(unit) if unit in ('s', 'ms') else pa.time64(unit)
    if isinstance(data_type, (LocalZonedTimestampType, TimestampType)):
        return pa.timestamp(fractional_unit(data_type.precision))
    if isinstance(data_type, MapType):
        key_arrow_type = to_arrow_type(data_type.key_type)
        value_arrow_type = to_arrow_type(data_type.value_type)
        return pa.map_(key_arrow_type, value_arrow_type)
    if isinstance(data_type, ArrayType):
        element_type = data_type.element_type
        if type(element_type) in unsupported_nested:
            raise ValueError("%s is not supported to be used as the element type of ArrayType." %
                             element_type)
        return pa.list_(to_arrow_type(element_type))
    if isinstance(data_type, RowType):
        for row_field in data_type:
            if type(row_field.data_type) in unsupported_nested:
                raise TypeError("%s is not supported to be used as the field type of RowType" %
                                row_field.data_type)
        struct_fields = [
            pa.field(row_field.name, to_arrow_type(row_field.data_type),
                     row_field.data_type._nullable)
            for row_field in data_type
        ]
        return pa.struct(struct_fields)
    if isinstance(data_type, NullType):
        return pa.null()
    raise ValueError("field_type %s is not supported." % data_type)
class DataTypes(object):
    """
    A :class:`DataType` can be used to declare input and/or output types of operations.
    This class enumerates all supported data types of the Table & SQL API.
    """

    @staticmethod
    def NULL() -> NullType:
        """
        Data type for representing untyped null (None) values. A null type has no
        other value except null (None), thus, it can be cast to any nullable type.

        This type helps in representing unknown types in API calls that use a null
        (None) literal as well as bridging to formats such as JSON or Avro that
        define such a type as well.

        The null type is an extension to the SQL standard.

        .. note:: `NullType` is still not supported yet.
        """
        return NullType()

    @staticmethod
    def CHAR(length: int, nullable: bool = True) -> CharType:
        """
        Data type of a fixed-length character string.

        :param length: int, the string representation length. It must have a value
                       between 1 and 2147483647(0x7fffffff) (both inclusive).
        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return CharType(length, nullable)

    @staticmethod
    def VARCHAR(length: int, nullable: bool = True) -> VarCharType:
        """
        Data type of a variable-length character string.

        :param length: int, the maximum string representation length. It must have a
                       value between 1 and 2147483647(0x7fffffff) (both inclusive).
        :param nullable: boolean, whether the type can be null (None) or not.

        .. note:: The length limit must be 0x7fffffff(2147483647) currently.

        .. seealso:: :func:`~DataTypes.STRING`
        """
        return VarCharType(length, nullable)

    @staticmethod
    def STRING(nullable: bool = True) -> VarCharType:
        """
        Data type of a variable-length character string with defined maximum length.
        This is a shortcut for ``DataTypes.VARCHAR(2147483647)``.

        :param nullable: boolean, whether the type can be null (None) or not.

        .. seealso:: :func:`~DataTypes.VARCHAR`
        """
        return DataTypes.VARCHAR(0x7fffffff, nullable)

    @staticmethod
    def BOOLEAN(nullable: bool = True) -> BooleanType:
        """
        Data type of a boolean with a (possibly) three-valued logic of
        TRUE, FALSE, UNKNOWN.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return BooleanType(nullable)

    @staticmethod
    def BINARY(length: int, nullable: bool = True) -> BinaryType:
        """
        Data type of a fixed-length binary string (=a sequence of bytes).

        :param length: int, the number of bytes. It must have a value between
                       1 and 2147483647(0x7fffffff) (both inclusive).
        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return BinaryType(length, nullable)

    @staticmethod
    def VARBINARY(length: int, nullable: bool = True) -> VarBinaryType:
        """
        Data type of a variable-length binary string (=a sequence of bytes).

        :param length: int, the maximum number of bytes. It must have a value
                       between 1 and 2147483647(0x7fffffff) (both inclusive).
        :param nullable: boolean, whether the type can be null (None) or not.

        .. seealso:: :func:`~DataTypes.BYTES`
        """
        return VarBinaryType(length, nullable)

    @staticmethod
    def BYTES(nullable: bool = True) -> VarBinaryType:
        """
        Data type of a variable-length binary string (=a sequence of bytes) with
        defined maximum length. This is a shortcut for ``DataTypes.VARBINARY(2147483647)``.

        :param nullable: boolean, whether the type can be null (None) or not.

        .. seealso:: :func:`~DataTypes.VARBINARY`
        """
        return DataTypes.VARBINARY(0x7fffffff, nullable)

    @staticmethod
    def DECIMAL(precision: int, scale: int, nullable: bool = True) -> DecimalType:
        """
        Data type of a decimal number with fixed precision and scale.

        :param precision: the number of digits in a number. It must have a value
                          between 1 and 38 (both inclusive).
        :param scale: the number of digits on right side of dot. It must have
                      a value between 0 and precision (both inclusive).
        :param nullable: boolean, whether the type can be null (None) or not.

        .. note:: The precision must be 38 and the scale must be 18 currently.
        """
        return DecimalType(precision, scale, nullable)

    @staticmethod
    def TINYINT(nullable: bool = True) -> TinyIntType:
        """
        Data type of a 1-byte signed integer with values from -128 to 127.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return TinyIntType(nullable)

    @staticmethod
    def SMALLINT(nullable: bool = True) -> SmallIntType:
        """
        Data type of a 2-byte signed integer with values from -32,768 to 32,767.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return SmallIntType(nullable)

    @staticmethod
    def INT(nullable: bool = True) -> IntType:
        """
        Data type of a 4-byte signed integer with values from -2,147,483,648
        to 2,147,483,647.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return IntType(nullable)

    @staticmethod
    def BIGINT(nullable: bool = True) -> BigIntType:
        """
        Data type of an 8-byte signed integer with values from
        -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return BigIntType(nullable)

    @staticmethod
    def FLOAT(nullable: bool = True) -> FloatType:
        """
        Data type of a 4-byte single precision floating point number.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return FloatType(nullable)

    @staticmethod
    def DOUBLE(nullable: bool = True) -> DoubleType:
        """
        Data type of an 8-byte double precision floating point number.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return DoubleType(nullable)

    @staticmethod
    def DATE(nullable: bool = True) -> DateType:
        """
        Data type of a date consisting of year-month-day with values ranging
        from ``0000-01-01`` to ``9999-12-31``.

        Compared to the SQL standard, the range starts at year 0000.

        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return DateType(nullable)

    @staticmethod
    def TIME(precision: int = 0, nullable: bool = True) -> TimeType:
        """
        Data type of a time WITHOUT time zone.

        An instance consists of hour:minute:second[.fractional] with up to nanosecond
        precision and values ranging from ``00:00:00.000000000`` to ``23:59:59.999999999``.

        Compared to the SQL standard, leap seconds (23:59:60 and 23:59:61)
        are not supported.

        :param precision: int, the number of digits of fractional seconds. It must
                          have a value between 0 and 9 (both inclusive).
        :param nullable: boolean, whether the type can be null (None) or not.

        .. note:: The precision must be 0 currently.
        """
        return TimeType(precision, nullable)

    @staticmethod
    def TIMESTAMP(precision: int = 6, nullable: bool = True) -> TimestampType:
        """
        Data type of a timestamp WITHOUT time zone.

        An instance consists of year-month-day hour:minute:second[.fractional]
        with up to nanosecond precision and values ranging from
        ``0000-01-01 00:00:00.000000000`` to ``9999-12-31 23:59:59.999999999``.

        Compared to the SQL standard, leap seconds (``23:59:60`` and ``23:59:61``)
        are not supported.

        This class does not store or represent a time-zone. Instead, it is a description of
        the date, as used for birthdays, combined with the local time as seen on a wall clock.
        It cannot represent an instant on the time-line without additional information
        such as an offset or time-zone.

        :param precision: int, the number of digits of fractional seconds.
                          It must have a value between 0 and 9 (both inclusive). (default: 6)
        :param nullable: boolean, whether the type can be null (None) or not.

        .. note:: The precision must be 3 currently.
        """
        return TimestampType(precision, nullable)

    @staticmethod
    def TIMESTAMP_WITH_LOCAL_TIME_ZONE(precision: int = 6, nullable: bool = True) \
            -> LocalZonedTimestampType:
        """
        Data type of a timestamp WITH LOCAL time zone.

        An instance consists of year-month-day hour:minute:second[.fractional]
        with up to nanosecond precision and values ranging from
        ``0000-01-01 00:00:00.000000000 +14:59`` to ``9999-12-31 23:59:59.999999999 -14:59``.

        Compared to the SQL standard, leap seconds (``23:59:60`` and ``23:59:61``)
        are not supported.

        The value will be stored internally as a long value which stores all date and time
        fields, to a precision of nanoseconds, as well as the offset from UTC/Greenwich.

        :param precision: int, the number of digits of fractional seconds.
                          It must have a value between 0 and 9 (both inclusive). (default: 6)
        :param nullable: boolean, whether the type can be null (None) or not.

        .. note:: `LocalZonedTimestampType` only supports precision of 3 currently.
        """
        return LocalZonedTimestampType(precision, nullable)

    @staticmethod
    def TIMESTAMP_LTZ(precision: int = 6, nullable: bool = True) \
            -> LocalZonedTimestampType:
        """
        Data type of a timestamp WITH LOCAL time zone.
        This is a shortcut for ``DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(precision, nullable)``.

        :param precision: int, the number of digits of fractional seconds.
                          It must have a value between 0 and 9 (both inclusive). (default: 6, only
                          supports 3 when bridged to DataStream)
        :param nullable: boolean, whether the type can be null (None) or not.

        .. seealso:: :func:`~DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE`
        """
        return LocalZonedTimestampType(precision, nullable)

    @staticmethod
    def ARRAY(element_type: DataType, nullable: bool = True) -> ArrayType:
        """
        Data type of an array of elements with same subtype.

        Compared to the SQL standard, the maximum cardinality of an array cannot
        be specified but is fixed at 2147483647(0x7fffffff). Also, any valid
        type is supported as a subtype.

        :param element_type: :class:`DataType` of each element in the array.
        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return ArrayType(element_type, nullable)

    @staticmethod
    def LIST_VIEW(element_type: DataType) -> ListViewType:
        """
        Data type of a :class:`pyflink.table.data_view.ListView`.

        It can only be used in accumulator type declaration of an Aggregate Function.

        :param element_type: :class:`DataType` of each element in the list view.
        """
        return ListViewType(element_type)

    @staticmethod
    def MAP(key_type: DataType, value_type: DataType, nullable: bool = True) -> MapType:
        """
        Data type of an associative array that maps keys to values. A map
        cannot contain duplicate keys; each key can map to at most one value.

        There is no restriction of key types; it is the responsibility of the
        user to ensure uniqueness. The map type is an extension to the SQL standard.

        :param key_type: :class:`DataType` of the keys in the map.
        :param value_type: :class:`DataType` of the values in the map.
        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return MapType(key_type, value_type, nullable)

    @staticmethod
    def MAP_VIEW(key_type: DataType, value_type: DataType) -> MapViewType:
        """
        Data type of a :class:`pyflink.table.data_view.MapView`.

        It can only be used in accumulator type declaration of an Aggregate Function.

        :param key_type: :class:`DataType` of the keys in the map view.
        :param value_type: :class:`DataType` of the values in the map view.
        """
        return MapViewType(key_type, value_type)

    @staticmethod
    def MULTISET(element_type: DataType, nullable: bool = True) -> MultisetType:
        """
        Data type of a multiset (=bag). Unlike a set, it allows for multiple
        instances for each of its elements with a common subtype. Each unique
        value is mapped to some multiplicity.

        There is no restriction of element types; it is the responsibility
        of the user to ensure uniqueness.

        :param element_type: :class:`DataType` of each element in the multiset.
        :param nullable: boolean, whether the type can be null (None) or not.
        """
        return MultisetType(element_type, nullable)

    @staticmethod
    def ROW(row_fields: List = None, nullable: bool = True) -> RowType:
        """
        Data type of a sequence of fields. A field consists of a field name,
        field type, and an optional description. The most specific type of
        a row of a table is a row type. In this case, each column of the row
        corresponds to the field of the row type that has the same ordinal
        position as the column.

        Compared to the SQL standard, an optional field description simplifies
        the handling with complex structures.

        :param row_fields: a list of row field types which can be created via
                           :func:`DataTypes.FIELD`. Defaults to an empty list (no fields).
        :param nullable: boolean, whether the type can be null (None) or not.
        """
        # Build a fresh list per call: a shared ``[]`` default could be aliased (and later
        # mutated) by every RowType created without explicit fields.
        if row_fields is None:
            row_fields = []
        return RowType(row_fields, nullable)

    @staticmethod
    def FIELD(name: str, data_type: DataType, description: str = None) -> RowField:
        """
        Field definition with field name, data type, and a description.

        :param name: string, name of the field.
        :param data_type: :class:`DataType` of the field.
        :param description: string, description of the field.
        """
        return RowField(name, data_type, description)

    @staticmethod
    def SECOND(precision: int = DayTimeIntervalType.DEFAULT_FRACTIONAL_PRECISION) -> Resolution:
        """
        Resolution in seconds and (possibly) fractional seconds.

        :param precision: int, the number of digits of fractional seconds. It must have a value
                          between 0 and 9 (both inclusive), (default: 6).
        :return: the specified :class:`Resolution`.

        .. note:: the precision must be 3 currently.

        .. seealso:: :func:`~pyflink.table.DataTypes.INTERVAL`
        """
        return Resolution(Resolution.IntervalUnit.SECOND, precision)

    @staticmethod
    def MINUTE() -> Resolution:
        """
        Resolution in minutes.

        :return: the specified :class:`Resolution`.

        .. seealso:: :func:`~pyflink.table.DataTypes.INTERVAL`
        """
        return Resolution(Resolution.IntervalUnit.MINUTE)

    @staticmethod
    def HOUR() -> Resolution:
        """
        Resolution in hours.

        :return: :class:`Resolution`

        .. seealso:: :func:`~pyflink.table.DataTypes.INTERVAL`
        """
        return Resolution(Resolution.IntervalUnit.HOUR)

    @staticmethod
    def DAY(precision: int = DayTimeIntervalType.DEFAULT_DAY_PRECISION) -> Resolution:
        """
        Resolution in days.

        :param precision: int, the number of digits of days. It must have a value between 1 and
                          6 (both inclusive), (default: 2).
        :return: the specified :class:`Resolution`.

        .. seealso:: :func:`~pyflink.table.DataTypes.INTERVAL`
        """
        return Resolution(Resolution.IntervalUnit.DAY, precision)

    @staticmethod
    def MONTH() -> Resolution:
        """
        Resolution in months.

        :return: the specified :class:`Resolution`.

        .. seealso:: :func:`~pyflink.table.DataTypes.INTERVAL`
        """
        return Resolution(Resolution.IntervalUnit.MONTH)

    @staticmethod
    def YEAR(precision: int = YearMonthIntervalType.DEFAULT_PRECISION) -> Resolution:
        """
        Resolution in years with 2 digits for the number of years by default.

        :param precision: the number of digits of years. It must have a value between 1 and
                          4 (both inclusive), (default 2).
        :return: the specified :class:`Resolution`.

        .. seealso:: :func:`~pyflink.table.DataTypes.INTERVAL`
        """
        return Resolution(Resolution.IntervalUnit.YEAR, precision)

    @staticmethod
    def INTERVAL(upper_resolution: Resolution, lower_resolution: Resolution = None) \
            -> Union[DayTimeIntervalType, YearMonthIntervalType]:
        """
        Data type of a temporal interval. There are two types of temporal intervals: day-time
        intervals with up to nanosecond granularity or year-month intervals with up to month
        granularity.

        An interval of day-time consists of ``+days hours:minutes:seconds.fractional`` with values
        ranging from ``-999999 23:59:59.999999999`` to ``+999999 23:59:59.999999999``. The type
        must be parameterized to one of the following resolutions: interval of days, interval of
        days to hours, interval of days to minutes, interval of days to seconds, interval of hours,
        interval of hours to minutes, interval of hours to seconds, interval of minutes,
        interval of minutes to seconds, or interval of seconds. The value representation is the
        same for all types of resolutions. For example, an interval of seconds of 70 is always
        represented in an interval-of-days-to-seconds format (with default precisions):
        ``+00 00:01:10.000000``.

        An interval of year-month consists of ``+years-months`` with values ranging from
        ``-9999-11`` to ``+9999-11``. The type must be parameterized to one of the following
        resolutions: interval of years, interval of years to months, or interval of months. The
        value representation is the same for all types of resolutions. For example, an interval
        of months of 50 is always represented in an interval-of-years-to-months format (with
        default year precision): ``+04-02``.

        Examples: ``INTERVAL(DAY(2), SECOND(9))`` for a day-time interval or
        ``INTERVAL(YEAR(4), MONTH())`` for a year-month interval.

        :param upper_resolution: :class:`Resolution`, the upper resolution of the interval.
        :param lower_resolution: :class:`Resolution`, the lower resolution of the interval.

        .. note:: the upper_resolution must be `MONTH` for `YearMonthIntervalType`, `SECOND` for
                  `DayTimeIntervalType` and the lower_resolution must be None currently.

        .. seealso:: :func:`~pyflink.table.DataTypes.SECOND`
        .. seealso:: :func:`~pyflink.table.DataTypes.MINUTE`
        .. seealso:: :func:`~pyflink.table.DataTypes.HOUR`
        .. seealso:: :func:`~pyflink.table.DataTypes.DAY`
        .. seealso:: :func:`~pyflink.table.DataTypes.MONTH`
        .. seealso:: :func:`~pyflink.table.DataTypes.YEAR`
        """
        return _from_resolution(upper_resolution, lower_resolution)